first commit

This commit is contained in:
LinZhuoChen
2026-04-16 09:51:30 +08:00
commit f9b3ae457a
44 changed files with 11994 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
from .stream import AggregatorStream
from .base import AggregatorBase

View File

@@ -0,0 +1,608 @@
"""
AggregatorBase - Base class for all Aggregator implementations.
Provides shared functionality:
- Patch embedding (DINOv2)
- Special tokens (camera, register, scale)
- Block building
- Common forward pass structure
Subclasses implement mode-specific attention logic.
"""
import logging
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional, Tuple, List
from lingbot_map.layers import PatchEmbed
from lingbot_map.layers.block import Block
from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
logger = logging.getLogger(__name__)
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
def slice_expand_and_flatten(token, B, S, first_num_frame=1):
    """
    Broadcast a two-slot special token over a batch of frame sequences.

    Slot 0 of ``token`` (dim 1) is given to the leading frames of every
    sequence and slot 1 to all remaining frames.

    Args:
        token: Token tensor [1, 2, N, C] where index 0 along dim 1 is the
            variant for the first frames and index 1 the variant for the rest
        B: Batch size
        S: Sequence length
        first_num_frame: Number of leading frames that receive slot 0.
            The value is clamped to ``[1, S]``: values below 1 behave like 1
            (as before), and values above ``S`` assign slot 0 to every frame.

    Returns:
        Flattened tokens [B*S, N, C]
    """
    num_tokens, channels = token.shape[-2], token.shape[-1]

    # Clamp so that the "rest" count can never become negative. A count of
    # exactly -1 would be interpreted by Tensor.expand as "keep this size"
    # and silently produce a tensor with S + 2 frames.
    num_first = min(max(first_num_frame, 1), S)
    num_rest = S - num_first

    token_first = token[:, :1].expand(B, num_first, -1, -1)  # [B, num_first, N, C]
    token_rest = token[:, 1:].expand(B, num_rest, -1, -1)    # [B, num_rest, N, C]
    token_expanded = torch.cat([token_first, token_rest], dim=1)  # [B, S, N, C]

    # Explicit N (instead of -1) so a frame-count mismatch raises rather than
    # being folded into the token dimension.
    return token_expanded.reshape(B * S, num_tokens, channels)
class AggregatorBase(nn.Module, ABC):
"""
Base class for all Aggregator implementations.
Handles shared components:
- Patch embedding (DINOv2 or conv)
- Special tokens (camera, register, optionally scale)
- Block creation (frame + global)
- RoPE (2D rotary position embeddings)
- Common forward pass scaffolding
Subclasses must implement:
- _process_global_attention(): Mode-specific cross-frame attention logic
"""
def __init__(
    self,
    # Architecture parameters
    img_size=518,
    patch_size=14,
    embed_dim=1024,
    depth=24,
    num_heads=16,
    mlp_ratio=4.0,
    num_register_tokens=4,
    # Block configuration
    block_fn=Block,
    qkv_bias=True,
    proj_bias=True,
    ffn_bias=True,
    qk_norm=True,
    init_values=0.01,
    # Patch embedding
    patch_embed="dinov2_vitl14_reg",
    pretrained_path=None,
    # Attention pattern
    aa_order=["frame", "global"],
    aa_block_size=1,
    # RoPE
    rope_freq=100,
    disable_global_rope=False,
    # Gradient checkpointing
    use_reentrant: bool = False,
    use_gradient_checkpoint: bool = True,
):
    """
    Store the configuration and build all shared sub-modules.

    Construction order matters: the patch embedding and RoPE are created
    first, then the subclass hooks ``_build_blocks`` and
    ``_setup_special_tokens`` run, and finally the blocks are optionally
    initialised from the DINOv2 checkpoint loaded by ``_build_patch_embed``.

    Args:
        img_size: Image size forwarded to the patch embedding
        patch_size: Patch size forwarded to the patch embedding
        embed_dim: Token embedding dimension
        depth: Number of blocks; must be divisible by ``aa_block_size``
        num_heads: Number of attention heads forwarded to ``_build_blocks``
        mlp_ratio: MLP ratio forwarded to ``_build_blocks``
        num_register_tokens: Number of register tokens
        block_fn: Block class forwarded to ``_build_blocks``
        qkv_bias, proj_bias, ffn_bias, qk_norm, init_values: Block options
            forwarded to ``_build_blocks``
        patch_embed: Patch embedding type (see ``_build_patch_embed``)
        pretrained_path: Optional checkpoint path for the patch embedding
        aa_order: Attention types run per block group, in order
        aa_block_size: Number of blocks processed per attention type
        rope_freq: RoPE base frequency; RoPE is disabled when <= 0
        disable_global_rope: Stored for subclasses building global blocks
        use_reentrant: ``use_reentrant`` flag for gradient checkpointing
        use_gradient_checkpoint: Checkpoint frame blocks during training

    Raises:
        ValueError: If ``depth`` is not divisible by ``aa_block_size``.
    """
    super().__init__()

    # Store configuration
    self.img_size = img_size
    self.patch_size = patch_size
    self.embed_dim = embed_dim
    self.depth = depth
    self.num_heads = num_heads
    self.mlp_ratio = mlp_ratio
    self.num_register_tokens = num_register_tokens
    # Copy: the default is a list literal shared by every call, so storing
    # it by reference would let one instance's mutation leak into others.
    self.aa_order = list(aa_order)
    self.aa_block_size = aa_block_size
    self.disable_global_rope = disable_global_rope
    self.use_reentrant = use_reentrant
    self.use_gradient_checkpoint = use_gradient_checkpoint
    self.pretrained_path = pretrained_path
    self.enable_ulysses_cp = False  # CP disabled
    logger.info("pretrained_path: %s", self.pretrained_path)

    # Validate depth
    if self.depth % self.aa_block_size != 0:
        raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
    self.aa_block_num = self.depth // self.aa_block_size

    # Build patch embedding (also sets self._dino_checkpoint)
    self._build_patch_embed(
        patch_embed=patch_embed,
        img_size=img_size,
        patch_size=patch_size,
        num_register_tokens=num_register_tokens,
        embed_dim=embed_dim,
        pretrained_path=pretrained_path,
    )

    # Initialize RoPE
    self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
    self.position_getter = PositionGetter() if self.rope is not None else None

    # Build blocks (frame + global)
    self._build_blocks(
        block_fn=block_fn,
        depth=depth,
        embed_dim=embed_dim,
        num_heads=num_heads,
        mlp_ratio=mlp_ratio,
        qkv_bias=qkv_bias,
        proj_bias=proj_bias,
        ffn_bias=ffn_bias,
        init_values=init_values,
        qk_norm=qk_norm,
    )

    # Setup special tokens (camera, register, optionally scale)
    self._setup_special_tokens()

    # Register normalization constants, shaped to broadcast over [B, S, 3, H, W]
    for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
        self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)

    # Initialize from DINO checkpoint if available
    if getattr(self, '_dino_checkpoint', None) is not None:
        self._init_blocks_from_dino(self._dino_checkpoint)
        del self._dino_checkpoint  # Free memory
def _build_patch_embed(
    self,
    patch_embed: str,
    img_size: int,
    patch_size: int,
    num_register_tokens: int,
    embed_dim: int,
    pretrained_path: Optional[str],
    interpolate_antialias=True,
    interpolate_offset=0.0,
    block_chunks=0,
    init_values=1.0,
):
    """
    Build patch embedding layer.

    Supports:
    - "conv": Simple convolutional patch embedding
    - "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2)

    Sets ``self.patch_embed`` and ``self._dino_checkpoint``. The latter is
    the loaded state dict (without ``pos_embed``) when pretrained weights
    were loaded successfully, otherwise ``None``.

    Args:
        patch_embed: Patch embedding type
        img_size: Image size passed to the embedding
        patch_size: Patch size passed to the embedding
        num_register_tokens: Register token count passed to the ViT
        embed_dim: Embedding dimension (used by the "conv" variant)
        pretrained_path: Checkpoint path for the ViT, or None to skip loading
        interpolate_antialias, interpolate_offset, block_chunks, init_values:
            Forwarded unchanged to the ViT constructor

    Raises:
        NotImplementedError: If ``patch_embed`` is neither a "conv" variant
            nor one of the known DINOv2 names.
    """
    if "conv" in patch_embed:
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=3,
            embed_dim=embed_dim
        )
        self._dino_checkpoint = None
    else:
        vit_models = {
            "dinov2_vitl14_reg": vit_large,
            "dinov2_vitb14_reg": vit_base,
            "dinov2_vits14_reg": vit_small,
            "dinov2_vitg2_reg": vit_giant2,
        }
        if patch_embed not in vit_models:
            raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}")
        self.patch_embed = vit_models[patch_embed](
            img_size=img_size,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            interpolate_antialias=interpolate_antialias,
            interpolate_offset=interpolate_offset,
            block_chunks=block_chunks,
            init_values=init_values,
        )

        self._dino_checkpoint = None
        if pretrained_path is None:
            # Nothing to load; this is not a failure.
            logger.info("No pretrained_path given; skipping DINOv2 weight loading")
        else:
            # Loading stays best-effort: a bad checkpoint is logged and the
            # randomly initialised embedding is kept.
            try:
                # map_location="cpu" avoids restoring tensors onto whichever
                # device the checkpoint was saved from; load_state_dict
                # copies the values into the existing parameters.
                ckpt = torch.load(pretrained_path, map_location="cpu")
                # pos_embed is dropped when present. A checkpoint without
                # it must not abort the load.
                ckpt.pop('pos_embed', None)
                logger.info("Loading pretrained weights for DINOv2")
                missing, unexpected = self.patch_embed.load_state_dict(ckpt, strict=False)
                logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
                # Store checkpoint for block initialization
                self._dino_checkpoint = ckpt
            except Exception as e:
                logger.warning(f"Failed to load pretrained weights: {e}")
                self._dino_checkpoint = None

    # Disable gradients for mask token
    if hasattr(self.patch_embed, "mask_token"):
        self.patch_embed.mask_token.requires_grad_(False)
@abstractmethod
def _build_blocks(
    self,
    block_fn,
    depth: int,
    embed_dim: int,
    num_heads: int,
    mlp_ratio: float,
    qkv_bias: bool,
    proj_bias: bool,
    ffn_bias: bool,
    init_values: float,
    qk_norm: bool,
):
    """
    Create the attention blocks (subclass hook).

    Called from ``__init__`` after the patch embedding and RoPE exist.
    An implementation is expected to assign:
    - self.frame_blocks: nn.ModuleList of frame attention blocks
    - self.global_blocks: nn.ModuleList of global attention blocks

    Args:
        block_fn: Block class to instantiate
        depth: Number of blocks
        embed_dim: Embedding dimension
        num_heads: Number of attention heads
        mlp_ratio: MLP ratio
        qkv_bias: Block option, as given to ``__init__``
        proj_bias: Block option, as given to ``__init__``
        ffn_bias: Block option, as given to ``__init__``
        init_values: Block option, as given to ``__init__``
        qk_norm: Block option, as given to ``__init__``
    """
    return None
@abstractmethod
def _setup_special_tokens(self):
    """
    Create the special tokens (subclass hook).

    Called from ``__init__`` after the blocks are built. An implementation
    is expected to assign:
    - self.camera_token
    - self.register_token
    - self.scale_token (optional, for causal mode)
    - self.patch_start_idx
    - self.num_special_tokens
    """
    return None
def _init_blocks_from_dino(self, dino_ckpt: dict):
    """
    Initialize frame_blocks and global_blocks from DINOv2 pretrained weights.

    Checkpoint entries named ``blocks.<idx>.<param>`` are grouped by block
    index once. Block ``i`` of each list is then loaded (non-strictly) from
    the ``i % n``-th available DINO block, so the DINO blocks are reused
    cyclically when there are fewer of them than target blocks.

    Args:
        dino_ckpt: Checkpoint dictionary from DINOv2 model
    """
    logger.info("Initializing blocks from DINOv2 pretrained weights")

    # Single pass over the checkpoint: {block index: {param name: tensor}}.
    per_block = {}
    found_block_keys = False
    for key, value in dino_ckpt.items():
        if not key.startswith('blocks.'):
            continue
        found_block_keys = True
        parts = key.split('.', 2)
        if len(parts) == 3 and parts[1].isdigit():
            per_block.setdefault(int(parts[1]), {})[parts[2]] = value

    if not found_block_keys:
        logger.warning("No 'blocks' found in DINO checkpoint")
        return
    if not per_block:
        # 'blocks.*' keys exist but none carries a numeric block index;
        # bail out instead of dividing by zero below.
        logger.warning("No numerically indexed blocks found in DINO checkpoint")
        return

    # Sorted indices so that non-contiguous checkpoints still map sensibly.
    dino_indices = sorted(per_block)
    num_dino_blocks = len(dino_indices)
    logger.info(f"Found {num_dino_blocks} blocks in DINO checkpoint")

    def load_into(blocks, label):
        """Load each target block from its cyclically matched DINO block."""
        for i, block in enumerate(blocks):
            block_state_dict = per_block[dino_indices[i % num_dino_blocks]]
            missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
            if i == 0:  # Only log for first block to avoid spam
                logger.info(
                    f"{label} block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}"
                )

    load_into(self.frame_blocks, "Frame")
    load_into(self.global_blocks, "Global")
    logger.info("Successfully initialized blocks from DINOv2 weights")
def _embed_images(
    self,
    images: torch.Tensor,
    num_frame_for_scale: Optional[int] = None,
) -> Tuple[torch.Tensor, int, int, int, int, int]:
    """
    Turn a batch of frame sequences into per-frame token sequences.

    Steps: normalize with the registered mean/std buffers, fold batch and
    sequence into one axis, run the patch embedding, and prepend the special
    tokens produced by ``_prepare_special_tokens``.

    Args:
        images: Input images [B, S, 3, H, W] in range [0, 1]
        num_frame_for_scale: Number of frames for scale estimation (passed to special tokens)

    Returns:
        (tokens, B, S, S, P, C):
            tokens: Embedded tokens [B*S, P, C]
            B: Batch size
            S: Sequence length (local)
            S: Sequence length (global); identical because there is no CP slicing
            P: Number of tokens per frame (patches + special tokens)
            C: Embedding dimension

    Raises:
        ValueError: If the channel dimension is not 3.
    """
    batch, seq_len, channels_in, height, width = images.shape
    if channels_in != 3:
        raise ValueError(f"Expected 3 input channels, got {channels_in}")

    # Normalize, then merge batch and sequence: [B*S, C, H, W]
    normalized = (images - self._resnet_mean) / self._resnet_std
    flat_frames = normalized.view(batch * seq_len, channels_in, height, width)

    # The patch embedding may return a dict; the patch tokens sit under
    # "x_norm_patchtokens" in that case.
    patch_tokens = self.patch_embed(flat_frames)
    if isinstance(patch_tokens, dict):
        patch_tokens = patch_tokens["x_norm_patchtokens"]
    _, _, embed_dim = patch_tokens.shape

    # No CP slicing, so local and global sequence lengths coincide.
    special_tokens = self._prepare_special_tokens(
        batch, seq_len, seq_len, embed_dim,
        num_frame_for_scale=num_frame_for_scale,
    )

    # Special tokens come first, patch tokens after.
    tokens = torch.cat([special_tokens, patch_tokens], dim=1)
    _, tokens_per_frame, embed_dim = tokens.shape
    return tokens, batch, seq_len, seq_len, tokens_per_frame, embed_dim
@abstractmethod
def _prepare_special_tokens(self, B: int, S_local: int, S_global: int, C: int, **kwargs) -> torch.Tensor:
    """
    Build the special tokens that are prepended to every frame (subclass hook).

    Called by ``_embed_images``, which concatenates the result in front of
    the patch tokens along dim 1.

    Args:
        B: Batch size
        S_local: Local sequence length
        S_global: Global sequence length
        C: Embedding dimension
        **kwargs: Mode-specific parameters (e.g., num_frame_for_scale for causal mode)

    Returns:
        Special tokens [B*S, N_special, C]
    """
    return None
def _get_positions(self, B: int, S: int, H: int, W: int, device) -> Optional[torch.Tensor]:
    """
    Build the per-token 2D positions consumed by RoPE.

    Patch positions come from ``self.position_getter`` on an
    (H // patch_size) x (W // patch_size) grid. When special tokens precede
    the patches, the patch positions are shifted by +1 and the special
    tokens are given position 0.

    Args:
        B: Batch size
        S: Sequence length
        H: Image height
        W: Image width
        device: Device to create positions on

    Returns:
        Position tensor [B*S, P, 2] or None if rope is disabled
    """
    if self.rope is None:
        return None

    num_frames = B * S
    grid_h = H // self.patch_size
    grid_w = W // self.patch_size
    patch_pos = self.position_getter(num_frames, grid_h, grid_w, device=device)

    num_special = self.patch_start_idx
    if not num_special > 0:
        return patch_pos

    # Reserve position 0 for the special tokens.
    shifted = patch_pos + 1
    special_pos = torch.zeros(num_frames, num_special, 2, dtype=shifted.dtype, device=device)
    return torch.cat([special_pos, shifted], dim=1)
def _process_frame_attention(
    self,
    tokens: torch.Tensor,
    B: int,
    S: int,
    P: int,
    C: int,
    frame_idx: int,
    pos: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
    """
    Run ``aa_block_size`` consecutive frame blocks starting at ``frame_idx``.

    Tokens are kept in the per-frame layout [B*S, P, C], so each block sees
    every frame as its own sequence. During training with gradient
    checkpointing enabled, each block is wrapped in
    ``torch.utils.checkpoint.checkpoint``.

    Args:
        tokens: Input tokens [B*S, P, C] (viewed into that shape if needed)
        B: Batch size
        S: Sequence length
        P: Tokens per frame
        C: Embedding dimension
        frame_idx: Current frame block index
        pos: Position embeddings [B*S, P, 2] (viewed into that shape if needed)

    Returns:
        (tokens, frame_idx, intermediates):
            tokens: Output tokens [B*S, P, C]
            frame_idx: Updated frame block index
            intermediates: One [B, S, P, C] view per processed block
    """
    per_frame_shape = (B * S, P, C)
    if tokens.shape != per_frame_shape:
        tokens = tokens.view(*per_frame_shape)
    if pos is not None and pos.shape != (B * S, P, 2):
        pos = pos.view(B * S, P, 2)

    use_checkpoint = self.training and self.use_gradient_checkpoint
    if use_checkpoint:
        from torch.utils.checkpoint import checkpoint

    intermediates = []
    next_idx = frame_idx + self.aa_block_size
    for block_idx in range(frame_idx, next_idx):
        frame_block = self.frame_blocks[block_idx]
        if use_checkpoint:
            # Positional args: tokens, pos, enable_ulysses_cp (always False)
            tokens = checkpoint(
                frame_block,
                tokens,
                pos,
                False,
                use_reentrant=self.use_reentrant
            )
        else:
            tokens = frame_block(tokens, pos=pos, enable_ulysses_cp=False)
        intermediates.append(tokens.view(B, S, P, C))
        frame_idx = block_idx + 1

    return tokens, frame_idx, intermediates
@abstractmethod
def _process_global_attention(
    self,
    tokens: torch.Tensor,
    B: int,
    S_local: int,
    S_global: int,
    P: int,
    C: int,
    global_idx: int,
    pos: Optional[torch.Tensor] = None,
    **kwargs
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
    """
    Run the global (cross-frame) attention blocks (subclass hook).

    ``forward`` calls this once per block group whenever the attention
    order contains "global", and concatenates the returned intermediates
    with the frame intermediates.

    Args:
        tokens: Input tokens
        B: Batch size
        S_local: Local sequence length
        S_global: Global sequence length
        P: Tokens per frame
        C: Embedding dimension
        global_idx: Current global block index
        pos: Position embeddings
        **kwargs: Mode-specific parameters

    Returns:
        (tokens, global_idx, intermediates):
            tokens: Output tokens
            global_idx: Updated global block index
            intermediates: List of intermediate outputs
    """
    return None
def forward(
    self,
    images: torch.Tensor,
    selected_idx: Optional[List[int]] = None,
    # Mode-specific parameters
    num_frame_for_scale: Optional[int] = None,
    sliding_window_size: Optional[int] = None,
    num_frame_per_block: int = 1,
) -> Tuple[List[torch.Tensor], int]:
    """
    Embed the images and run the alternating frame/global attention stack.

    For each of the ``aa_block_num`` block groups, the attention types in
    ``aa_order`` are executed in order. After a group finishes, its frame
    and global intermediates are concatenated channel-wise and collected
    (only for groups listed in ``selected_idx`` when that is given).

    Args:
        images: Input images [B, S, 3, H, W] in range [0, 1]
        selected_idx: Which block group indices to output (None = all)
        num_frame_for_scale: Number of frames for scale estimation (causal mode)
        sliding_window_size: Sliding window size in blocks (causal mode)
        num_frame_per_block: Number of frames per processing block (causal mode)

    Returns:
        (output_list, patch_start_idx):
            output_list: List of block outputs [B, S, P, 2C]
            patch_start_idx: Index where patch tokens start

    Raises:
        ValueError: If ``aa_order`` contains an unknown attention type.
    """
    _, _, _, H, W = images.shape

    tokens, B, S_local, S_global, P, C = self._embed_images(
        images,
        num_frame_for_scale=num_frame_for_scale,
    )

    # Positions for the frame (local) and global attention passes
    device = images.device
    pos_local = self._get_positions(B, S_local, H, W, device=device)
    pos_global = self._get_positions(B, S_global, H, W, device=device)

    frame_idx, global_idx = 0, 0
    output_list = []

    for group_idx in range(self.aa_block_num):
        for attn_type in self.aa_order:
            if attn_type == "frame":
                tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                    tokens, B, S_local, P, C, frame_idx, pos=pos_local
                )
                continue
            if attn_type == "global":
                tokens, global_idx, global_intermediates = self._process_global_attention(
                    tokens, B, S_local, S_global, P, C, global_idx,
                    pos=pos_global,
                    num_frame_for_scale=num_frame_for_scale,
                    sliding_window_size=sliding_window_size,
                    num_frame_per_block=num_frame_per_block,
                    image_height=H,
                    image_width=W,
                )
                continue
            raise ValueError(f"Unknown attention type: {attn_type}")

        # Skip groups the caller did not ask for
        if selected_idx is not None and group_idx not in selected_idx:
            continue

        # Pair up frame and global intermediates: [B, S, P, 2C]
        for i, frame_out in enumerate(frame_intermediates):
            output_list.append(torch.cat([frame_out, global_intermediates[i]], dim=-1))

    return output_list, self.patch_start_idx

View File

@@ -0,0 +1,531 @@
"""
AggregatorStream - Streaming causal aggregator with FlashInfer KV cache.
Provides:
- Temporal causal attention
- Sliding window support
- Scale token for scale estimation frames
- Streaming inference with FlashInfer paged KV cache
"""
import logging
import torch
import torch.nn as nn
from typing import Optional, Tuple, List
from lingbot_map.layers.block import Block, FlashInferBlock, SDPABlock
from lingbot_map.layers.rope import WanRotaryPosEmbed
from lingbot_map.aggregator.base import AggregatorBase, slice_expand_and_flatten
logger = logging.getLogger(__name__)
class AggregatorStream(AggregatorBase):
"""
Streaming causal aggregator with FlashInfer paged KV cache.
Features:
- Temporal causal attention (each frame only attends to past frames)
- Sliding window support to limit attention scope
- Scale token for scale estimation frames
- Streaming inference with FlashInfer KV cache
"""
def __init__(
    self,
    # Causal-specific parameters
    sliding_window_size: int = -1,
    num_frame_for_scale: int = 1,
    num_random_frames: int = 0,
    attend_to_special_tokens: bool = False,
    attend_to_scale_frames: bool = False,
    enable_3d_rope: bool = False,
    max_frame_num: int = 1024,
    # KV cache parameters
    kv_cache_sliding_window: int = 64,
    kv_cache_scale_frames: int = 8,
    kv_cache_cross_frame_special: bool = True,
    kv_cache_include_scale_frames: bool = True,
    kv_cache_camera_only: bool = False,
    # Base class parameters via **kwargs
    **kwargs
):
    """
    Initialize AggregatorStream.

    All streaming options are stored on ``self`` *before* the base class
    constructor runs, because the base constructor calls ``_build_blocks``,
    which reads the KV cache settings and the backend flag.

    Args:
        sliding_window_size: Sliding window size in blocks (-1 for full causal)
        num_frame_for_scale: Number of scale estimation frames
        num_random_frames: Number of random frames for long-range dependencies
        attend_to_special_tokens: Enable cross-frame special token attention
        attend_to_scale_frames: Include scale frames in attention
        enable_3d_rope: Enable 3D RoPE for temporal dimension in KV cache
        max_frame_num: Maximum number of frames for 3D RoPE
        kv_cache_sliding_window: Sliding window size for KV cache eviction
        kv_cache_scale_frames: Number of scale frames to keep in KV cache
        kv_cache_cross_frame_special: Keep special tokens from evicted frames
        kv_cache_include_scale_frames: Include scale frames in KV cache
        kv_cache_camera_only: Only keep camera tokens from evicted frames
        **kwargs: Base class parameters. The keys ``enable_stream_inference``,
            ``use_flashinfer`` and ``use_flexflash`` are accepted and
            discarded; ``use_sdpa`` (default False) selects the SDPA backend.
    """
    self.sliding_window_size = sliding_window_size
    self.num_frame_for_scale = num_frame_for_scale
    self.num_random_frames = num_random_frames
    self.attend_to_special_tokens = attend_to_special_tokens
    self.attend_to_scale_frames = attend_to_scale_frames
    self.enable_3d_rope = enable_3d_rope
    self.max_frame_num = max_frame_num

    # KV cache parameters
    self.kv_cache_sliding_window = kv_cache_sliding_window
    self.kv_cache_scale_frames = kv_cache_scale_frames
    self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
    self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
    self.kv_cache_camera_only = kv_cache_camera_only

    # Strip kwargs the base class does not accept. 'use_flashinfer' is
    # tolerated for compatibility, but the backend is decided by
    # 'use_sdpa' alone, so its value is not used.
    kwargs.pop('enable_stream_inference', None)
    kwargs.pop('use_flashinfer', None)
    kwargs.pop('use_flexflash', None)
    use_sdpa = kwargs.pop('use_sdpa', False)

    # Backend selection: SDPA (no extra deps) or FlashInfer (paged KV cache)
    self.use_sdpa = use_sdpa
    self.use_flashinfer = not use_sdpa  # FlashInfer is default unless SDPA requested

    # Call parent __init__ (builds patch embedding, blocks, special tokens)
    super().__init__(**kwargs)

    # Initialize KV cache
    self._init_kv_cache()

    # Always run the 3D RoPE setup: when the feature is disabled it sets
    # self.rope3d = None, so the attribute exists for every later reader.
    self._init_3d_rope()
def _build_blocks(
    self,
    block_fn,
    depth: int,
    embed_dim: int,
    num_heads: int,
    mlp_ratio: float,
    qkv_bias: bool,
    proj_bias: bool,
    ffn_bias: bool,
    init_values: float,
    qk_norm: bool,
):
    """
    Create ``depth`` frame blocks and ``depth`` global blocks.

    Frame blocks are instances of ``block_fn`` with the shared 2D RoPE.
    Global blocks are ``SDPABlock`` when ``self.use_sdpa`` is set and
    ``FlashInferBlock`` otherwise. They receive the KV cache settings
    stored on ``self`` and no RoPE when ``self.disable_global_rope`` is set.
    """
    shared_kwargs = {
        "dim": embed_dim,
        "num_heads": num_heads,
        "mlp_ratio": mlp_ratio,
        "qkv_bias": qkv_bias,
        "proj_bias": proj_bias,
        "ffn_bias": ffn_bias,
        "init_values": init_values,
        "qk_norm": qk_norm,
    }

    # Frame blocks are built first, global blocks second.
    frame_blocks = []
    for _ in range(depth):
        frame_blocks.append(block_fn(**shared_kwargs, rope=self.rope))
    self.frame_blocks = nn.ModuleList(frame_blocks)

    # Backend for the global (cross-frame) blocks
    global_cls = FlashInferBlock
    if self.use_sdpa:
        global_cls = SDPABlock

    global_blocks = []
    for _ in range(depth):
        global_rope = None if self.disable_global_rope else self.rope
        global_blocks.append(
            global_cls(
                **shared_kwargs,
                rope=global_rope,
                kv_cache_sliding_window=self.kv_cache_sliding_window,
                kv_cache_scale_frames=self.kv_cache_scale_frames,
                kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
                kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
                kv_cache_camera_only=self.kv_cache_camera_only,
            )
        )
    self.global_blocks = nn.ModuleList(global_blocks)
def _setup_special_tokens(self):
    """
    Create the camera, register and scale token parameters.

    Every token has two variants along dim 1 (shape [1, 2, N, C]) and is
    re-initialised from a normal distribution with std 1e-6. The register
    token is only created when ``num_register_tokens > 0``. Also sets
    ``patch_start_idx`` and ``num_special_tokens`` to
    ``1 + num_register_tokens + 1`` (camera + register + scale).
    """
    dim = self.embed_dim
    num_reg = self.num_register_tokens
    has_registers = num_reg > 0

    # Allocation order (camera, register, scale) is kept as-is so the random
    # number stream is consumed in the same sequence.
    self.camera_token = nn.Parameter(torch.randn(1, 2, 1, dim))
    if has_registers:
        self.register_token = nn.Parameter(torch.randn(1, 2, num_reg, dim))
    # Scale token (causal mode specific)
    self.scale_token = nn.Parameter(torch.ones(1, 2, 1, dim))

    to_init = [self.camera_token]
    if has_registers:
        to_init.append(self.register_token)
    to_init.append(self.scale_token)
    for param in to_init:
        nn.init.normal_(param, std=1e-6)

    # Token indexing (includes scale token)
    special_count = 1 + num_reg + 1
    self.patch_start_idx = special_count
    self.num_special_tokens = special_count
def _init_kv_cache(self):
    """
    Reset all streaming state and prepare the backend-specific KV cache.

    For the SDPA backend, the dict cache is pre-filled with ``None`` entries
    (``k_i``, ``v_i``, ``k_i_special``, ``v_i_special``) for every block.
    For FlashInfer, the manager stays ``None`` and is created lazily by
    ``_get_flashinfer_manager``.
    """
    self.kv_cache_manager = None  # FlashInfer (lazy-initialized)
    self.kv_cache = {}  # Dict-based cache for SDPA
    self.total_frames_processed = 0
    self._cached_pos3d = None

    if not self.use_sdpa:
        logger.info("FlashInfer KV cache will be lazily initialized on first forward")
        return

    # Dict-based KV cache for SDPA; needs the block count from the base class
    if not hasattr(self, 'depth'):
        return
    for i in range(self.depth):
        for slot in (f"k_{i}", f"v_{i}", f"k_{i}_special", f"v_{i}_special"):
            self.kv_cache[slot] = None
    logger.info(f"SDPA KV cache initialized with {self.depth} blocks")
def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None):
    """Lazily initialize FlashInferKVCacheManager on first use.

    The manager is created once and cached on ``self.kv_cache_manager``;
    later calls return the cached instance and ignore their arguments.

    The head layout is derived from the configured ``self.num_heads`` (the
    same value the attention blocks are built with) rather than assuming a
    head dimension of 64. For the default configuration (embed_dim=1024,
    num_heads=16) both give 16 heads of dimension 64.

    Args:
        device: Device for cache tensors.
        dtype: Data type for cache tensors.
        tokens_per_frame: Actual number of tokens per frame (patches + specials).
            If None, falls back to assuming square images of self.img_size.

    Returns:
        The (possibly newly created) FlashInferKVCacheManager.
    """
    if self.kv_cache_manager is not None:
        return self.kv_cache_manager

    # Imported here so the FlashInfer dependency is only needed when this
    # backend is actually used.
    from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager

    num_heads = self.num_heads
    head_dim = self.embed_dim // num_heads
    if tokens_per_frame is None:
        tokens_per_frame = (self.img_size // self.patch_size) ** 2 + self.num_special_tokens

    # max_num_frames: scale + window + headroom
    max_num_frames = self.kv_cache_scale_frames + self.kv_cache_sliding_window + 16
    self.kv_cache_manager = FlashInferKVCacheManager(
        num_blocks=self.depth,
        max_num_frames=max_num_frames,
        tokens_per_frame=tokens_per_frame,
        num_heads=num_heads,
        head_dim=head_dim,
        dtype=dtype,
        device=device,
        num_special_tokens=self.num_special_tokens,
        scale_frames=self.kv_cache_scale_frames,
        sliding_window=self.kv_cache_sliding_window,
        max_total_frames=self.max_frame_num + 100,
        # Optional switches; default to False when the attribute is absent.
        force_fp32=getattr(self, 'kv_cache_force_fp32', False),
        fa3=getattr(self, 'kv_cache_fa3', False),
    )
    logger.info(
        f"FlashInfer KV cache manager initialized: {self.depth} blocks, "
        f"max_frames={max_num_frames}, tokens_per_frame={tokens_per_frame}"
    )
    return self.kv_cache_manager
def clean_kv_cache(self):
    """
    Drop all cached streaming state; call before starting a new sequence.

    Resets the FlashInfer manager if one exists, clears every entry of the
    dict cache (``"_skip_append"`` becomes False, all other keys None), and
    zeroes the frame counter and the cached 3D positions.
    """
    manager = self.kv_cache_manager
    if manager is not None:
        manager.reset()

    cache = self.kv_cache
    if cache:
        for key in list(cache.keys()):
            cache[key] = False if key == "_skip_append" else None

    self.total_frames_processed = 0
    self._cached_pos3d = None
    logger.info("KV cache cleaned")
def _init_3d_rope(self):
    """
    Initialize 3D RoPE for streaming inference.

    Sets ``self.rope3d`` to a ``WanRotaryPosEmbed`` when ``enable_3d_rope``
    is set and to ``None`` otherwise. The per-head dimension is derived from
    the configured ``self.num_heads`` (the value the attention blocks are
    built with) instead of assuming 16 heads; for the default configuration
    the result is unchanged.
    """
    if not self.enable_3d_rope:
        self.rope3d = None
        return

    head_dim = self.embed_dim // self.num_heads
    self.rope3d = WanRotaryPosEmbed(
        attention_head_dim=head_dim,
        patch_size=(1, self.patch_size, self.patch_size),
        max_seq_len=self.max_frame_num,
    )
    logger.info(f"3D RoPE initialized for max {self.max_frame_num} frames, head_dim={head_dim}")
def _get_3d_positions_streaming(self, num_frames, H, W, device, f_start, f_end):
    """
    Generate 3D RoPE positions for streaming mode with correct global frame indices.

    Args:
        num_frames: Number of frames in current batch
        H, W: Image height and width
        device: Device to create positions on
        f_start: Global start frame index
        f_end: Global end frame index

    Returns:
        The value returned by ``self.rope3d`` for the given frame range
        (documented upstream as a [1, 1, num_frames * P, head_dim//2] complex
        tensor — TODO confirm against WanRotaryPosEmbed), or None when 3D
        RoPE is not initialised.
    """
    # getattr: rope3d may be absent when 3D RoPE was never set up; treat
    # that the same as an explicit None instead of raising AttributeError.
    rope3d = getattr(self, "rope3d", None)
    if rope3d is None:
        return None

    # Patch grid size along height and width
    pph = H // self.patch_size
    ppw = W // self.patch_size
    return rope3d(
        ppf=num_frames,
        pph=pph,
        ppw=ppw,
        patch_start_idx=self.num_special_tokens,
        device=device,
        f_start=f_start,
        f_end=f_end,
    )
def _prepare_special_tokens(
    self,
    B: int,
    S_local: int,
    S_global: int,
    C: int,
    num_frame_for_scale: Optional[int] = None,
) -> torch.Tensor:
    """
    Prepare camera, register, and scale tokens.

    The tokens are laid out as if the whole stream seen so far (cached
    frames + current frames) were one sequence, and the entries for the
    current ``S_global`` frames are then taken from its tail. This way only
    the true first frame(s) of a stream get the "first frame" token variant.
    Without cached frames this reduces to expanding over the current frames.

    The tail is sliced per batch element, so the result is correct for
    B > 1. Register tokens are only included when
    ``num_register_tokens > 0`` (the parameter does not exist otherwise).

    Args:
        B: Batch size
        S_local: Local sequence length
        S_global: Global sequence length
        C: Embedding dimension
        num_frame_for_scale: Number of frames for scale estimation
            (defaults to ``self.num_frame_for_scale``)

    Returns:
        Special tokens [B*S_global, N_special, C]
    """
    # Get effective num_frame_for_scale
    scale_frames = self.num_frame_for_scale if num_frame_for_scale is None else num_frame_for_scale

    # Number of frames already held by whichever cache backend is active
    if self.kv_cache_manager is not None and self.kv_cache_manager.num_frames > 0:
        S_cached = self.kv_cache_manager.num_frames
    elif self.kv_cache is not None and self.kv_cache.get("k_0") is not None:
        # The frame count is read from dim 2 of the 5-D cached key tensor.
        _, _, S_cached, _, _ = self.kv_cache["k_0"].shape
    else:
        S_cached = 0
    S_true = S_cached + S_global
    effective_scale_frames = min(scale_frames, S_true)

    def expand_current(token, first_num_frame=1):
        """Expand over the full stream and keep the current frames only."""
        full = slice_expand_and_flatten(token, B, S_true, first_num_frame=first_num_frame)
        if S_true == S_global:
            return full
        # Un-flatten so the tail is taken per batch element, not from the
        # end of the batch-major flattened tensor.
        full = full.reshape(B, S_true, full.shape[-2], full.shape[-1])
        current = full[:, S_true - S_global:]
        return current.reshape(B * S_global, full.shape[-2], full.shape[-1])

    parts = [expand_current(self.camera_token)]
    if self.num_register_tokens > 0:
        parts.append(expand_current(self.register_token))
    parts.append(expand_current(self.scale_token, first_num_frame=effective_scale_frames))
    special_tokens = torch.cat(parts, dim=1)

    # Verify shape
    expected_shape = (B * S_global, self.num_special_tokens, C)
    assert special_tokens.shape == expected_shape, \
        f"Expected {expected_shape}, got {special_tokens.shape}"
    return special_tokens
def _process_global_attention(
    self,
    tokens: torch.Tensor,
    B: int,
    S_local: int,
    S_global: int,
    P: int,
    C: int,
    global_idx: int,
    pos: Optional[torch.Tensor] = None,
    # Mode-specific parameters
    num_frame_for_scale: Optional[int] = None,
    sliding_window_size: Optional[int] = None,
    num_frame_per_block: int = 1,
    **kwargs,
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
    """
    Global attention entry point; delegates to ``_process_causal_stream``.

    The image size is read from ``kwargs`` (``image_height`` /
    ``image_width``) and falls back to ``self.img_size`` when absent.

    Args:
        tokens: Input tokens
        B: Batch size
        S_local: Local sequence length
        S_global: Global sequence length
        P: Tokens per frame
        C: Embedding dimension
        global_idx: Current global block index
        pos: Position embeddings
        num_frame_for_scale: Number of frames for scale estimation
        sliding_window_size: Sliding window size in blocks
        num_frame_per_block: Number of frames per processing block
        **kwargs: May carry ``image_height`` and ``image_width``

    Returns:
        (tokens, global_idx, intermediates)
    """
    fallback_size = self.img_size
    height = kwargs.get('image_height', fallback_size)
    width = kwargs.get('image_width', fallback_size)

    # Note the callee's positional order: per-block, window, scale.
    return self._process_causal_stream(
        tokens, B, S_local, S_global, P, C, global_idx, pos,
        num_frame_per_block, sliding_window_size, num_frame_for_scale,
        image_height=height,
        image_width=width,
    )
def _process_causal_stream(
    self,
    tokens: torch.Tensor,
    B: int,
    S_local: int,
    S_global: int,
    P: int,
    C: int,
    global_idx: int,
    pos: Optional[torch.Tensor] = None,
    num_frame_per_block: int = 1,
    sliding_window_size: Optional[int] = None,
    num_frame_for_scale: Optional[int] = None,
    image_height: Optional[int] = None,
    image_width: Optional[int] = None,
):
    """
    Run ``aa_block_size`` global blocks over the current frames with a KV cache.

    Tokens are flattened to one sequence per batch element
    ([B, S_local*P, C]). With 3D RoPE enabled, positions are computed once
    per forward pass (in the first block group) from the running frame
    counter and reused from ``self._cached_pos3d`` afterwards; otherwise the
    given 2D positions are flattened to match the tokens. The frame counter
    is advanced once per forward pass, unless the dict cache carries a
    truthy ``"_skip_append"`` entry.

    Args:
        tokens: Input tokens [B*S_local, P, C]
        B: Batch size
        S_local: Local sequence length
        S_global: Global sequence length
        P: Number of tokens per frame (includes special tokens)
        C: Channel dimension
        global_idx: Starting block index
        pos: Position embeddings [B*S_global, P, 2]
        num_frame_per_block: Number of frames per block
        sliding_window_size: Sliding window size in blocks (not used here)
        num_frame_for_scale: Number of scale frames
            (defaults to ``self.num_frame_for_scale``)
        image_height: Image height for 3D RoPE calculation
        image_width: Image width for 3D RoPE calculation

    Returns:
        (tokens, global_idx, intermediates): Updated tokens, next block index, intermediate outputs
    """
    scale_frames = self.num_frame_for_scale if num_frame_for_scale is None else num_frame_for_scale

    # [B*S_local, P, C] -> [B, S_local*P, C]
    flat_shape = (B, S_local * P, C)
    if tokens.shape != flat_shape:
        tokens = tokens.view(B, S_local, P, C).view(*flat_shape)

    num_patches = P - self.num_special_tokens

    # The first block group of a forward pass starts below aa_block_size.
    first_group = global_idx < self.aa_block_size

    if self.enable_3d_rope and self.rope3d is not None:
        if first_group:
            start = self.total_frames_processed
            H = self.img_size if image_height is None else image_height
            W = self.img_size if image_width is None else image_width
            self._cached_pos3d = self._get_3d_positions_streaming(
                S_global, H, W, tokens.device, start, start + S_global
            )
        pos = self._cached_pos3d
    elif pos is not None and pos.shape != (B, S_global * P, 2):
        # [B*S_global, P, 2] -> [B, S_global*P, 2]
        pos = pos.view(B, S_global, P, 2).view(B, S_global * P, 2)

    # Arguments that are identical for every block and both backends
    common_kwargs = dict(
        pos=pos,
        enable_ulysses_cp=False,
        num_patches=num_patches,
        num_special=self.num_special_tokens,
        num_frames=S_global,
        enable_3d_rope=self.enable_3d_rope,
        num_frame_per_block=num_frame_per_block,
        num_frame_for_scale=scale_frames,
        num_register_tokens=self.num_register_tokens,
    )

    intermediates = []
    for _ in range(self.aa_block_size):
        if self.use_sdpa:
            # SDPA: dict-based KV cache
            cache = self.kv_cache
        else:
            # FlashInfer: paged KV cache manager (created on first use)
            cache = self._get_flashinfer_manager(tokens.device, tokens.dtype, tokens_per_frame=P)
        tokens = self.global_blocks[global_idx](
            tokens,
            kv_cache=cache,
            global_idx=global_idx,
            **common_kwargs,
        )
        global_idx += 1
        intermediates.append(tokens.view(B, S_local, P, C))

    # Advance the frame counter once per forward pass
    if first_group:
        skip_append = isinstance(self.kv_cache, dict) and self.kv_cache.get("_skip_append", False)
        if not skip_append:
            self.total_frames_processed += S_global

    return tokens, global_idx, intermediates