"""
AggregatorBase - Base class for all Aggregator implementations.
Provides shared functionality:
- Patch embedding (DINOv2)
- Special tokens (camera, register, scale)
- Block building
- Common forward pass structure
Subclasses implement mode-specific attention logic.
"""
import logging
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from abc import ABC, abstractmethod
from typing import Optional, Tuple, List
from lingbot_map.layers import PatchEmbed
from lingbot_map.layers.block import Block
from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
logger = logging.getLogger(__name__)
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
def slice_expand_and_flatten(token, B, S, first_num_frame=1):
"""
Helper function to slice, expand and flatten tokens.
Args:
        token: Token tensor [1, 2, N, C]; index 0 along dim 1 is used for the
            first frame(s), index 1 for all remaining frames
        B: Batch size
        S: Sequence length
        first_num_frame: Number of leading frames that use the first token
Returns:
Flattened tokens [B*S, N, C]
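    Example (a minimal shape sketch; the tensor values are placeholders):
        >>> tok = torch.zeros(1, 2, 5, 16)                 # [1, 2, N, C]
        >>> out = slice_expand_and_flatten(tok, B=3, S=4)
        >>> tuple(out.shape)                               # [B*S, N, C]
        (12, 5, 16)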
"""
    # token shape: [1, 2, N, C] -> expand to [B, S, N, C]:
    # the first `first_num_frame` frames (at least one) use token index 0, the rest use token index 1
    first_num_frame = max(first_num_frame, 1)
    token_first = token[:, :1].expand(B, first_num_frame, -1, -1)  # [B, first_num_frame, N, C]
    token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1)  # [B, S-first_num_frame, N, C]
    token_expanded = torch.cat([token_first, token_rest], dim=1)  # [B, S, N, C]
    # Flatten to [B*S, N, C]
    return token_expanded.reshape(B * S, -1, token.shape[-1])
class AggregatorBase(nn.Module, ABC):
"""
Base class for all Aggregator implementations.
Handles shared components:
- Patch embedding (DINOv2 or conv)
- Special tokens (camera, register, optionally scale)
- Block creation (frame + global)
- RoPE (2D rotary position embeddings)
- Common forward pass scaffolding
Subclasses must implement:
- _process_global_attention(): Mode-specific cross-frame attention logic
"""
def __init__(
self,
# Architecture parameters
img_size=518,
patch_size=14,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4.0,
num_register_tokens=4,
# Block configuration
block_fn=Block,
qkv_bias=True,
proj_bias=True,
ffn_bias=True,
qk_norm=True,
init_values=0.01,
# Patch embedding
patch_embed="dinov2_vitl14_reg",
pretrained_path=None,
# Attention pattern
aa_order=["frame", "global"],
aa_block_size=1,
# RoPE
rope_freq=100,
disable_global_rope=False,
# Gradient checkpointing
use_reentrant: bool = False,
use_gradient_checkpoint: bool = True,
):
super().__init__()
# Store configuration
self.img_size = img_size
self.patch_size = patch_size
self.embed_dim = embed_dim
self.depth = depth
self.num_heads = num_heads
self.mlp_ratio = mlp_ratio
self.num_register_tokens = num_register_tokens
self.aa_order = aa_order
self.aa_block_size = aa_block_size
self.disable_global_rope = disable_global_rope
self.use_reentrant = use_reentrant
self.use_gradient_checkpoint = use_gradient_checkpoint
self.pretrained_path = pretrained_path
        self.enable_ulysses_cp = False  # Ulysses context parallelism (CP) disabled
        logger.info(f"pretrained_path: {self.pretrained_path}")
# Validate depth
if self.depth % self.aa_block_size != 0:
raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
self.aa_block_num = self.depth // self.aa_block_size
# Build patch embedding
self._build_patch_embed(
patch_embed=patch_embed,
img_size=img_size,
patch_size=patch_size,
num_register_tokens=num_register_tokens,
embed_dim=embed_dim,
pretrained_path=pretrained_path
)
# Initialize RoPE
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
self.position_getter = PositionGetter() if self.rope is not None else None
# Build blocks (frame + global)
self._build_blocks(
block_fn=block_fn,
depth=depth,
embed_dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
init_values=init_values,
qk_norm=qk_norm,
)
# Setup special tokens (camera, register, optionally scale)
self._setup_special_tokens()
# Register normalization constants
for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)
# Initialize from DINO checkpoint if available
if hasattr(self, '_dino_checkpoint') and self._dino_checkpoint is not None:
self._init_blocks_from_dino(self._dino_checkpoint)
del self._dino_checkpoint # Free memory
def _build_patch_embed(
self,
patch_embed: str,
img_size: int,
patch_size: int,
num_register_tokens: int,
embed_dim: int,
        pretrained_path: Optional[str],
interpolate_antialias=True,
interpolate_offset=0.0,
block_chunks=0,
init_values=1.0,
):
"""
Build patch embedding layer.
Supports:
- "conv": Simple convolutional patch embedding
- "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2)
"""
if "conv" in patch_embed:
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=3,
embed_dim=embed_dim
)
self._dino_checkpoint = None
else:
vit_models = {
"dinov2_vitl14_reg": vit_large,
"dinov2_vitb14_reg": vit_base,
"dinov2_vits14_reg": vit_small,
"dinov2_vitg2_reg": vit_giant2,
}
if patch_embed not in vit_models:
raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}")
self.patch_embed = vit_models[patch_embed](
img_size=img_size,
patch_size=patch_size,
num_register_tokens=num_register_tokens,
interpolate_antialias=interpolate_antialias,
interpolate_offset=interpolate_offset,
block_chunks=block_chunks,
init_values=init_values,
)
# Load pretrained weights
try:
                ckpt = torch.load(pretrained_path)
                # Drop the pretrained positional embedding; the model keeps its own
                ckpt.pop('pos_embed', None)
logger.info("Loading pretrained weights for DINOv2")
missing, unexpected = self.patch_embed.load_state_dict(ckpt, strict=False)
logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
# Store checkpoint for block initialization
self._dino_checkpoint = ckpt
except Exception as e:
logger.warning(f"Failed to load pretrained weights: {e}")
self._dino_checkpoint = None
# Disable gradients for mask token
if hasattr(self.patch_embed, "mask_token"):
self.patch_embed.mask_token.requires_grad_(False)
@abstractmethod
def _build_blocks(
self,
block_fn,
depth: int,
embed_dim: int,
num_heads: int,
mlp_ratio: float,
qkv_bias: bool,
proj_bias: bool,
ffn_bias: bool,
init_values: float,
qk_norm: bool,
):
"""
Build frame_blocks and global_blocks.
Subclasses implement mode-specific block creation.
Must create:
- self.frame_blocks: nn.ModuleList of frame attention blocks
- self.global_blocks: nn.ModuleList of global attention blocks
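        A minimal sketch, assuming `block_fn` (the default `Block`) accepts the
        usual ViT-style keyword arguments; the real constructor signature may differ:
            self.frame_blocks = nn.ModuleList([
                block_fn(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias,
                         init_values=init_values, qk_norm=qk_norm, rope=self.rope)
                for _ in range(depth)
            ])
            self.global_blocks = nn.ModuleList([...])  # same pattern, global attention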
"""
pass
@abstractmethod
def _setup_special_tokens(self):
"""
Setup camera token, register tokens, and optionally scale token.
Subclasses implement mode-specific token initialization.
Must create:
- self.camera_token
- self.register_token
- self.scale_token (optional, for causal mode)
- self.patch_start_idx
- self.num_special_tokens
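        A minimal sketch (illustrative only; the [1, 2, N, C] shape follows the
        convention expected by slice_expand_and_flatten, but the exact layout is
        subclass-specific):
            self.camera_token = nn.Parameter(torch.randn(1, 2, 1, self.embed_dim))
            self.register_token = nn.Parameter(
                torch.randn(1, 2, self.num_register_tokens, self.embed_dim))
            self.num_special_tokens = 1 + self.num_register_tokens
            self.patch_start_idx = self.num_special_tokens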
"""
pass
def _init_blocks_from_dino(self, dino_ckpt: dict):
"""
Initialize frame_blocks and global_blocks from DINOv2 pretrained weights.
Args:
dino_ckpt: Checkpoint dictionary from DINOv2 model
"""
logger.info("Initializing blocks from DINOv2 pretrained weights")
# Extract block keys
dino_block_keys = [k for k in dino_ckpt.keys() if k.startswith('blocks.')]
if not dino_block_keys:
logger.warning("No 'blocks' found in DINO checkpoint")
return
# Get block indices
block_indices = set()
for key in dino_block_keys:
parts = key.split('.')
if len(parts) > 1 and parts[1].isdigit():
block_indices.add(int(parts[1]))
num_dino_blocks = len(block_indices)
print(f"Found {num_dino_blocks} blocks in DINO checkpoint")
# Initialize frame_blocks
for i, block in enumerate(self.frame_blocks):
dino_block_idx = i % num_dino_blocks
block_state_dict = {}
prefix = f'blocks.{dino_block_idx}.'
for key, value in dino_ckpt.items():
if key.startswith(prefix):
new_key = key[len(prefix):]
block_state_dict[new_key] = value
if block_state_dict:
missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
                if i == 0:  # Only log for the first block to avoid spam
                    logger.info(f"Frame block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
# Initialize global_blocks
for i, block in enumerate(self.global_blocks):
dino_block_idx = i % num_dino_blocks
block_state_dict = {}
prefix = f'blocks.{dino_block_idx}.'
for key, value in dino_ckpt.items():
if key.startswith(prefix):
new_key = key[len(prefix):]
block_state_dict[new_key] = value
if block_state_dict:
missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
                if i == 0:  # Only log for the first block to avoid spam
                    logger.info(f"Global block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
logger.info("Successfully initialized blocks from DINOv2 weights")
def _embed_images(
self,
images: torch.Tensor,
num_frame_for_scale: Optional[int] = None,
) -> Tuple[torch.Tensor, int, int, int, int, int]:
"""
Embed images and prepare for attention processing.
Handles:
- Image normalization
- Patch embedding
- Special token concatenation
- Position embedding
Args:
images: Input images [B, S, 3, H, W] in range [0, 1]
num_frame_for_scale: Number of frames for scale estimation (passed to special tokens)
Returns:
            (tokens, B, S_local, S_global, P, C):
                tokens: Embedded tokens [B*S, P, C]
                B: Batch size
                S_local: Local sequence length (equals S; no CP slicing)
                S_global: Global sequence length (equals S)
                P: Number of tokens per frame (special tokens + patches)
                C: Embedding dimension
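        Example (token-count arithmetic at the default img_size=518, patch_size=14;
        num_special_tokens is defined by the subclass):
            patches per frame = (518 // 14) ** 2 = 37 * 37 = 1369
            P = num_special_tokens + 1369
            tokens.shape == (B * S, P, C)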
"""
B, S, C_in, H, W = images.shape
if C_in != 3:
raise ValueError(f"Expected 3 input channels, got {C_in}")
# Normalize images
images = (images - self._resnet_mean) / self._resnet_std
# No CP slicing: S_local == S_global
S_local = S
S_global = S
# Reshape for patch embedding [B*S, C, H, W]
images = images.view(B * S, C_in, H, W)
# Patch embedding
patch_tokens = self.patch_embed(images)
if isinstance(patch_tokens, dict):
patch_tokens = patch_tokens["x_norm_patchtokens"]
_, P_patch, C = patch_tokens.shape
# Prepare special tokens
special_tokens = self._prepare_special_tokens(
B, S_local, S_global, C,
num_frame_for_scale=num_frame_for_scale
)
# Concatenate special tokens + patch tokens
tokens = torch.cat([special_tokens, patch_tokens], dim=1)
_, P, C = tokens.shape
return tokens, B, S_local, S_global, P, C
@abstractmethod
def _prepare_special_tokens(self, B: int, S_local: int, S_global: int, C: int, **kwargs) -> torch.Tensor:
"""
Prepare special tokens (camera, register, optionally scale).
Subclasses implement mode-specific token preparation.
Args:
B: Batch size
S_local: Local sequence length
S_global: Global sequence length
C: Embedding dimension
**kwargs: Mode-specific parameters (e.g., num_frame_for_scale for causal mode)
Returns:
Special tokens [B*S, N_special, C]
"""
pass
def _get_positions(self, B: int, S: int, H: int, W: int, device) -> Optional[torch.Tensor]:
"""
Get 2D position embeddings for RoPE.
Args:
B: Batch size
S: Sequence length
H: Image height
W: Image width
device: Device to create positions on
Returns:
Position tensor [B*S, P, 2] or None if rope is disabled
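        Layout sketch (follows the code below):
            pos[:, :self.patch_start_idx]  -> zeros; all special tokens share position (0, 0)
            pos[:, self.patch_start_idx:]  -> patch-grid coordinates, shifted by +1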
"""
if self.rope is None:
return None
# Get patch positions
pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
# Add offset for patch tokens (skip special tokens at pos=0)
if self.patch_start_idx > 0:
pos = pos + 1
pos_special = torch.zeros(B * S, self.patch_start_idx, 2, dtype=pos.dtype, device=device)
pos = torch.cat([pos_special, pos], dim=1)
return pos
def _process_frame_attention(
self,
tokens: torch.Tensor,
B: int,
S: int,
P: int,
C: int,
frame_idx: int,
pos: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
"""
Process frame attention blocks.
Frame attention operates independently per frame (no cross-frame communication).
Tokens stay in shape [B*S, P, C].
Args:
tokens: Input tokens [B*S, P, C]
B: Batch size
S: Sequence length
P: Tokens per frame
C: Embedding dimension
frame_idx: Current frame block index
pos: Position embeddings [B*S, P, 2]
Returns:
(tokens, frame_idx, intermediates):
tokens: Output tokens [B*S, P, C]
frame_idx: Updated frame block index
intermediates: List of intermediate outputs [B, S, P, C]
"""
# Ensure correct shape
if tokens.shape != (B * S, P, C):
tokens = tokens.view(B * S, P, C)
if pos is not None and pos.shape != (B * S, P, 2):
pos = pos.view(B * S, P, 2)
intermediates = []
# Process blocks
for i in range(self.aa_block_size):
if self.training and self.use_gradient_checkpoint:
tokens = checkpoint(
self.frame_blocks[frame_idx],
tokens,
pos,
False, # enable_ulysses_cp (always False)
use_reentrant=self.use_reentrant
)
else:
tokens = self.frame_blocks[frame_idx](tokens, pos=pos, enable_ulysses_cp=False)
frame_idx += 1
intermediates.append(tokens.view(B, S, P, C))
return tokens, frame_idx, intermediates
@abstractmethod
def _process_global_attention(
self,
tokens: torch.Tensor,
B: int,
S_local: int,
S_global: int,
P: int,
C: int,
global_idx: int,
pos: Optional[torch.Tensor] = None,
**kwargs
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
"""
Process global (cross-frame) attention blocks.
Subclasses implement mode-specific attention logic.
Args:
tokens: Input tokens
B: Batch size
S_local: Local sequence length
S_global: Global sequence length
P: Tokens per frame
C: Embedding dimension
global_idx: Current global block index
pos: Position embeddings
**kwargs: Mode-specific parameters
Returns:
(tokens, global_idx, intermediates):
tokens: Output tokens
global_idx: Updated global block index
intermediates: List of intermediate outputs
"""
pass
def forward(
self,
images: torch.Tensor,
selected_idx: Optional[List[int]] = None,
# Mode-specific parameters
num_frame_for_scale: Optional[int] = None,
sliding_window_size: Optional[int] = None,
num_frame_per_block: int = 1,
) -> Tuple[List[torch.Tensor], int]:
"""
Forward pass.
Args:
images: Input images [B, S, 3, H, W] in range [0, 1]
selected_idx: Which block indices to output (None = all)
num_frame_for_scale: Number of frames for scale estimation (causal mode)
sliding_window_size: Sliding window size in blocks (causal mode)
num_frame_per_block: Number of frames per processing block (causal mode)
Returns:
(output_list, patch_start_idx):
output_list: List of block outputs [B, S, P, 2C]
patch_start_idx: Index where patch tokens start
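        Example (hypothetical usage through a concrete subclass; `MyAggregator` is a
        placeholder name, not part of this module):
            agg = MyAggregator(img_size=518, patch_size=14, embed_dim=1024)
            images = torch.rand(1, 4, 3, 518, 518)   # [B, S, 3, H, W] in [0, 1]
            outputs, patch_start_idx = agg(images)
            outputs[-1].shape                        # [B, S, P, 2 * embed_dim]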
"""
B, S_input, _, H, W = images.shape
# Embed images
tokens, B, S_local, S_global, P, C = self._embed_images(
images,
num_frame_for_scale=num_frame_for_scale,
)
# Get position embeddings
pos_local = self._get_positions(B, S_local, H, W, device=images.device)
pos_global = self._get_positions(B, S_global, H, W, device=images.device)
# Alternating attention
frame_idx = 0
global_idx = 0
output_list = []
for block_group_idx in range(self.aa_block_num):
for attn_type in self.aa_order:
if attn_type == "frame":
tokens, frame_idx, frame_intermediates = self._process_frame_attention(
tokens, B, S_local, P, C, frame_idx, pos=pos_local
)
elif attn_type == "global":
tokens, global_idx, global_intermediates = self._process_global_attention(
tokens, B, S_local, S_global, P, C, global_idx,
pos=pos_global,
num_frame_for_scale=num_frame_for_scale,
sliding_window_size=sliding_window_size,
num_frame_per_block=num_frame_per_block,
image_height=H,
image_width=W,
)
else:
raise ValueError(f"Unknown attention type: {attn_type}")
# Collect outputs
if selected_idx is None or block_group_idx in selected_idx:
for i in range(len(frame_intermediates)):
# Concatenate frame and global intermediates [B, S, P, 2C]
concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
output_list.append(concat_inter)
return output_list, self.patch_start_idx