""" AggregatorBase - Base class for all Aggregator implementations. Provides shared functionality: - Patch embedding (DINOv2) - Special tokens (camera, register, scale) - Block building - Common forward pass structure Subclasses implement mode-specific attention logic. """ import logging import torch import torch.nn as nn from abc import ABC, abstractmethod from typing import Optional, Tuple, List from lingbot_map.layers import PatchEmbed from lingbot_map.layers.block import Block from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 logger = logging.getLogger(__name__) _RESNET_MEAN = [0.485, 0.456, 0.406] _RESNET_STD = [0.229, 0.224, 0.225] def slice_expand_and_flatten(token, B, S, first_num_frame=1): """ Helper function to slice, expand and flatten tokens. Args: token: Token tensor [1, 2, N, C] where first index is for first frames B: Batch size S: Sequence length first_num_frame: Number of frames to use first token for Returns: Flattened tokens [B*S, N, C] """ # token shape: [1, 2, N, C] # Expand to [B, S, N, C] if first_num_frame > 1: # Use first token for first first_num_frame frames, second for rest token_first = token[:, :1].expand(B, first_num_frame, -1, -1) # [B, first_num_frame, N, C] token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1) # [B, S-first_num_frame, N, C] token_expanded = torch.cat([token_first, token_rest], dim=1) # [B, S, N, C] else: # Use first token for first frame, second for rest token_first = token[:, :1].expand(B, 1, -1, -1) # [B, 1, N, C] token_rest = token[:, 1:].expand(B, S - 1, -1, -1) # [B, S-1, N, C] token_expanded = torch.cat([token_first, token_rest], dim=1) # [B, S, N, C] # Flatten to [B*S, N, C] return token_expanded.reshape(B * S, -1, token.shape[-1]) class AggregatorBase(nn.Module, ABC): """ Base class for all Aggregator implementations. 
    """
    # token shape: [1, 2, N, C] -> expand to [B, S, N, C]
    if first_num_frame > 1:
        # Use the first token for the first `first_num_frame` frames, the second for the rest
        token_first = token[:, :1].expand(B, first_num_frame, -1, -1)      # [B, first_num_frame, N, C]
        token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1)   # [B, S - first_num_frame, N, C]
    else:
        # Use the first token for the first frame, the second for the rest
        token_first = token[:, :1].expand(B, 1, -1, -1)                    # [B, 1, N, C]
        token_rest = token[:, 1:].expand(B, S - 1, -1, -1)                 # [B, S - 1, N, C]
    token_expanded = torch.cat([token_first, token_rest], dim=1)           # [B, S, N, C]

    # Flatten to [B*S, N, C]
    return token_expanded.reshape(B * S, -1, token.shape[-1])


class AggregatorBase(nn.Module, ABC):
    """
    Base class for all Aggregator implementations.

    Handles shared components:
    - Patch embedding (DINOv2 or conv)
    - Special tokens (camera, register, optionally scale)
    - Block creation (frame + global)
    - RoPE (2D rotary position embeddings)
    - Common forward pass scaffolding

    Subclasses must implement:
    - _process_global_attention(): Mode-specific cross-frame attention logic
    """

    def __init__(
        self,
        # Architecture parameters
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        # Block configuration
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        qk_norm=True,
        init_values=0.01,
        # Patch embedding
        patch_embed="dinov2_vitl14_reg",
        pretrained_path=None,
        # Attention pattern
        aa_order=["frame", "global"],
        aa_block_size=1,
        # RoPE
        rope_freq=100,
        disable_global_rope=False,
        # Gradient checkpointing
        use_reentrant: bool = False,
        use_gradient_checkpoint: bool = True,
    ):
        super().__init__()

        # Store configuration
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.num_register_tokens = num_register_tokens
        self.aa_order = aa_order
        self.aa_block_size = aa_block_size
        self.disable_global_rope = disable_global_rope
        self.use_reentrant = use_reentrant
        self.use_gradient_checkpoint = use_gradient_checkpoint
        self.pretrained_path = pretrained_path
        self.enable_ulysses_cp = False  # Context parallelism disabled
        logger.info(f"pretrained_path: {self.pretrained_path}")

        # Validate depth
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
        self.aa_block_num = self.depth // self.aa_block_size

        # Build patch embedding
        self._build_patch_embed(
            patch_embed=patch_embed,
            img_size=img_size,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            embed_dim=embed_dim,
            pretrained_path=pretrained_path,
        )

        # Initialize RoPE
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        # Build blocks (frame + global)
        self._build_blocks(
            block_fn=block_fn,
            depth=depth,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            init_values=init_values,
            qk_norm=qk_norm,
        )

        # Setup special tokens (camera, register, optionally scale)
        self._setup_special_tokens()

        # Register normalization constants
        for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
            self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)

        # Initialize from DINO checkpoint if available
        if hasattr(self, "_dino_checkpoint") and self._dino_checkpoint is not None:
            self._init_blocks_from_dino(self._dino_checkpoint)
            del self._dino_checkpoint  # Free memory

    def _build_patch_embed(
        self,
        patch_embed: str,
        img_size: int,
        patch_size: int,
        num_register_tokens: int,
        embed_dim: int,
        pretrained_path: str,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
    ):
        """
        Build patch embedding layer.

        Supports:
        - "conv": Simple convolutional patch embedding
        - "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2)
        """
        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim
            )
            self._dino_checkpoint = None
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }
            if patch_embed not in vit_models:
                raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}")

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

            # Load pretrained weights
            try:
                ckpt = torch.load(pretrained_path)
                del ckpt["pos_embed"]
                logger.info("Loading pretrained weights for DINOv2")
                missing, unexpected = self.patch_embed.load_state_dict(ckpt, strict=False)
                logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
                # Store checkpoint for block initialization
                self._dino_checkpoint = ckpt
            except Exception as e:
                logger.warning(f"Failed to load pretrained weights: {e}")
                self._dino_checkpoint = None

        # Disable gradients for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    @abstractmethod
    def _build_blocks(
        self,
        block_fn,
        depth: int,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        proj_bias: bool,
        ffn_bias: bool,
        init_values: float,
        qk_norm: bool,
    ):
        """
        Build frame_blocks and global_blocks.

        Subclasses implement mode-specific block creation.

        Must create:
        - self.frame_blocks: nn.ModuleList of frame attention blocks
        - self.global_blocks: nn.ModuleList of global attention blocks
        """
        pass

    @abstractmethod
    def _setup_special_tokens(self):
        """
        Setup camera token, register tokens, and optionally scale token.

        Subclasses implement mode-specific token initialization.

        Must create:
        - self.camera_token
        - self.register_token
        - self.scale_token (optional, for causal mode)
        - self.patch_start_idx
        - self.num_special_tokens
        """
        pass

    def _init_blocks_from_dino(self, dino_ckpt: dict):
        """
        Initialize frame_blocks and global_blocks from DINOv2 pretrained weights.

        Args:
            dino_ckpt: Checkpoint dictionary from DINOv2 model
        """
        logger.info("Initializing blocks from DINOv2 pretrained weights")

        # Extract block keys
        dino_block_keys = [k for k in dino_ckpt.keys() if k.startswith("blocks.")]
        if not dino_block_keys:
            logger.warning("No 'blocks' found in DINO checkpoint")
            return

        # Get block indices
        block_indices = set()
        for key in dino_block_keys:
            parts = key.split(".")
            if len(parts) > 1 and parts[1].isdigit():
                block_indices.add(int(parts[1]))
        num_dino_blocks = len(block_indices)
        logger.info(f"Found {num_dino_blocks} blocks in DINO checkpoint")

        # Initialize frame_blocks
        for i, block in enumerate(self.frame_blocks):
            dino_block_idx = i % num_dino_blocks
            block_state_dict = {}
            prefix = f"blocks.{dino_block_idx}."
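            # Remap keys relative to this block, e.g. "blocks.3.attn.qkv.weight"
            # -> "attn.qkv.weight", so they match the standalone block's state_dict.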
            for key, value in dino_ckpt.items():
                if key.startswith(prefix):
                    new_key = key[len(prefix):]
                    block_state_dict[new_key] = value

            if block_state_dict:
                missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
                if i == 0:  # Only log for first block to avoid spam
                    logger.info(f"Frame block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

        # Initialize global_blocks
        for i, block in enumerate(self.global_blocks):
            dino_block_idx = i % num_dino_blocks
            block_state_dict = {}
            prefix = f"blocks.{dino_block_idx}."
            for key, value in dino_ckpt.items():
                if key.startswith(prefix):
                    new_key = key[len(prefix):]
                    block_state_dict[new_key] = value

            if block_state_dict:
                missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
                if i == 0:  # Only log for first block to avoid spam
                    logger.info(f"Global block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

        logger.info("Successfully initialized blocks from DINOv2 weights")

    def _embed_images(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
    ) -> Tuple[torch.Tensor, int, int, int, int, int]:
        """
        Embed images and prepare for attention processing.

        Handles:
        - Image normalization
        - Patch embedding
        - Special token concatenation

        Args:
            images: Input images [B, S, 3, H, W] in range [0, 1]
            num_frame_for_scale: Number of frames for scale estimation (passed to special tokens)

        Returns:
            (tokens, B, S_local, S_global, P, C):
                tokens: Embedded tokens [B*S, P, C]
                B: Batch size
                S_local: Local sequence length
                S_global: Global sequence length (equal to S_local; no CP slicing)
                P: Number of tokens per frame (patches + special tokens)
                C: Embedding dimension
        """
        B, S, C_in, H, W = images.shape
        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images
        images = (images - self._resnet_mean) / self._resnet_std

        # No CP slicing: S_local == S_global
        S_local = S
        S_global = S

        # Reshape for patch embedding [B*S, C, H, W]
        images = images.view(B * S, C_in, H, W)

        # Patch embedding
        patch_tokens = self.patch_embed(images)
        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]
        _, P_patch, C = patch_tokens.shape

        # Prepare special tokens
        special_tokens = self._prepare_special_tokens(
            B, S_local, S_global, C, num_frame_for_scale=num_frame_for_scale
        )

        # Concatenate special tokens + patch tokens
        tokens = torch.cat([special_tokens, patch_tokens], dim=1)
        _, P, C = tokens.shape

        return tokens, B, S_local, S_global, P, C

    @abstractmethod
    def _prepare_special_tokens(self, B: int, S_local: int, S_global: int, C: int, **kwargs) -> torch.Tensor:
        """
        Prepare special tokens (camera, register, optionally scale).

        Subclasses implement mode-specific token preparation.

        Args:
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            C: Embedding dimension
            **kwargs: Mode-specific parameters (e.g., num_frame_for_scale for causal mode)

        Returns:
            Special tokens [B*S, N_special, C]
        """
        pass

    def _get_positions(self, B: int, S: int, H: int, W: int, device) -> Optional[torch.Tensor]:
        """
        Get 2D position embeddings for RoPE.
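
        Special tokens are placed at position (0, 0) and patch positions are
        shifted by +1 so the two never collide. For example (assuming
        PositionGetter yields 0-based (row, col) grid coordinates), a 28x28
        image with patch_size=14 gives patch positions (1, 1), (1, 2), (2, 1),
        (2, 2), preceded by `patch_start_idx` rows of zeros.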

        Args:
            B: Batch size
            S: Sequence length
            H: Image height
            W: Image width
            device: Device to create positions on

        Returns:
            Position tensor [B*S, P, 2], or None if RoPE is disabled
        """
        if self.rope is None:
            return None

        # Get patch positions
        pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)

        # Shift patch positions by +1 and prepend zeros for the special tokens
        if self.patch_start_idx > 0:
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2, dtype=pos.dtype, device=device)
            pos = torch.cat([pos_special, pos], dim=1)

        return pos

    def _process_frame_attention(
        self,
        tokens: torch.Tensor,
        B: int,
        S: int,
        P: int,
        C: int,
        frame_idx: int,
        pos: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process frame attention blocks.

        Frame attention operates independently per frame (no cross-frame communication).
        Tokens stay in shape [B*S, P, C].

        Args:
            tokens: Input tokens [B*S, P, C]
            B: Batch size
            S: Sequence length
            P: Tokens per frame
            C: Embedding dimension
            frame_idx: Current frame block index
            pos: Position embeddings [B*S, P, 2]

        Returns:
            (tokens, frame_idx, intermediates):
                tokens: Output tokens [B*S, P, C]
                frame_idx: Updated frame block index
                intermediates: List of intermediate outputs [B, S, P, C]
        """
        # Ensure correct shape
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B * S, P, C)
        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B * S, P, 2)

        intermediates = []

        # Process blocks
        for _ in range(self.aa_block_size):
            if self.training and self.use_gradient_checkpoint:
                tokens = checkpoint(
                    self.frame_blocks[frame_idx],
                    tokens,
                    pos,
                    False,  # enable_ulysses_cp (always False)
                    use_reentrant=self.use_reentrant,
                )
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos, enable_ulysses_cp=False)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    @abstractmethod
    def _process_global_attention(
        self,
        tokens: torch.Tensor,
        B: int,
        S_local: int,
        S_global: int,
        P: int,
        C: int,
        global_idx: int,
        pos: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process global (cross-frame) attention blocks.

        Subclasses implement mode-specific attention logic.

        Args:
            tokens: Input tokens
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            P: Tokens per frame
            C: Embedding dimension
            global_idx: Current global block index
            pos: Position embeddings
            **kwargs: Mode-specific parameters

        Returns:
            (tokens, global_idx, intermediates):
                tokens: Output tokens
                global_idx: Updated global block index
                intermediates: List of intermediate outputs
        """
        pass

    def forward(
        self,
        images: torch.Tensor,
        selected_idx: Optional[List[int]] = None,
        # Mode-specific parameters
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Forward pass.
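
        Runs `aa_block_num` groups of alternating attention. With the defaults
        (depth=24, aa_block_size=1, aa_order=["frame", "global"]) this is 24
        groups, each one frame block followed by one global block; every
        selected group contributes one [B, S, P, 2*C] tensor to the output list.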

        Args:
            images: Input images [B, S, 3, H, W] in range [0, 1]
            selected_idx: Which block-group indices to output (None = all)
            num_frame_for_scale: Number of frames for scale estimation (causal mode)
            sliding_window_size: Sliding window size in blocks (causal mode)
            num_frame_per_block: Number of frames per processing block (causal mode)

        Returns:
            (output_list, patch_start_idx):
                output_list: List of block outputs [B, S, P, 2C]
                patch_start_idx: Index where patch tokens start
        """
        B, S_input, _, H, W = images.shape

        # Embed images
        tokens, B, S_local, S_global, P, C = self._embed_images(
            images,
            num_frame_for_scale=num_frame_for_scale,
        )

        # Get position embeddings
        pos_local = self._get_positions(B, S_local, H, W, device=images.device)
        pos_global = self._get_positions(B, S_global, H, W, device=images.device)

        # Alternating attention
        frame_idx = 0
        global_idx = 0
        output_list = []

        for block_group_idx in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S_local, P, C, frame_idx, pos=pos_local
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S_local, S_global, P, C, global_idx,
                        pos=pos_global,
                        num_frame_for_scale=num_frame_for_scale,
                        sliding_window_size=sliding_window_size,
                        num_frame_per_block=num_frame_per_block,
                        image_height=H,
                        image_width=W,
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            # Collect outputs
            if selected_idx is None or block_group_idx in selected_idx:
                for i in range(len(frame_intermediates)):
                    # Concatenate frame and global intermediates [B, S, P, 2C]
                    concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                    output_list.append(concat_inter)

        return output_list, self.patch_start_idx
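

# ---------------------------------------------------------------------------
# Illustrative sketch (not used anywhere in the library): roughly what a
# concrete subclass of AggregatorBase is expected to provide. The special-token
# shapes, the keyword arguments passed to `block_fn`, and the "flatten all
# frames into one sequence" global-attention layout are assumptions inferred
# from the call sites above; the real subclasses may differ.
# ---------------------------------------------------------------------------
class _SketchAggregator(AggregatorBase):
    """Minimal, non-authoritative example of the AggregatorBase contract."""

    def _build_blocks(self, block_fn, depth, embed_dim, num_heads, mlp_ratio,
                      qkv_bias, proj_bias, ffn_bias, init_values, qk_norm):
        # One frame block and one global block per layer. The kwarg names below
        # are assumed to match `Block`; adjust to the actual signature.
        def make_block():
            return block_fn(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias,
                init_values=init_values, qk_norm=qk_norm, rope=self.rope,
            )

        self.frame_blocks = nn.ModuleList([make_block() for _ in range(depth)])
        self.global_blocks = nn.ModuleList([make_block() for _ in range(depth)])

    def _setup_special_tokens(self):
        # Two variants per token (first frame vs. the rest), consumed by
        # slice_expand_and_flatten() in _prepare_special_tokens().
        self.camera_token = nn.Parameter(torch.randn(1, 2, 1, self.embed_dim) * 1e-2)
        self.register_token = nn.Parameter(
            torch.randn(1, 2, self.num_register_tokens, self.embed_dim) * 1e-2
        )
        self.num_special_tokens = 1 + self.num_register_tokens
        self.patch_start_idx = self.num_special_tokens

    def _prepare_special_tokens(self, B, S_local, S_global, C, **kwargs):
        camera = slice_expand_and_flatten(self.camera_token, B, S_local)      # [B*S, 1, C]
        register = slice_expand_and_flatten(self.register_token, B, S_local)  # [B*S, R, C]
        return torch.cat([camera, register], dim=1)                           # [B*S, 1+R, C]

    def _process_global_attention(self, tokens, B, S_local, S_global, P, C,
                                  global_idx, pos=None, **kwargs):
        # Global attention: flatten all frames of a sequence into one token
        # stream [B, S*P, C] so every frame can attend to every other frame.
        tokens = tokens.view(B, S_global * P, C)
        if pos is not None:
            pos = pos.view(B, S_global * P, 2)

        intermediates = []
        for _ in range(self.aa_block_size):
            tokens = self.global_blocks[global_idx](tokens, pos=pos, enable_ulysses_cp=False)
            global_idx += 1
            intermediates.append(tokens.view(B, S_global, P, C))

        # Return to the per-frame layout expected by the frame blocks.
        tokens = tokens.view(B * S_global, P, C)
        return tokens, global_idx, intermediates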