first commit
This commit is contained in:
531
lingbot_map/aggregator/stream.py
Normal file
531
lingbot_map/aggregator/stream.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""
|
||||
AggregatorStream - Streaming causal aggregator with FlashInfer KV cache.
|
||||
|
||||
Provides:
|
||||
- Temporal causal attention
|
||||
- Sliding window support
|
||||
- Scale token for scale estimation frames
|
||||
- Streaming inference with FlashInfer paged KV cache
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
from lingbot_map.layers.block import Block, FlashInferBlock, SDPABlock
|
||||
from lingbot_map.layers.rope import WanRotaryPosEmbed
|
||||
from lingbot_map.aggregator.base import AggregatorBase, slice_expand_and_flatten
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AggregatorStream(AggregatorBase):
|
||||
"""
|
||||
Streaming causal aggregator with FlashInfer paged KV cache.
|
||||
|
||||
Features:
|
||||
- Temporal causal attention (each frame only attends to past frames)
|
||||
- Sliding window support to limit attention scope
|
||||
- Scale token for scale estimation frames
|
||||
- Streaming inference with FlashInfer KV cache
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# Causal-specific parameters
|
||||
sliding_window_size: int = -1,
|
||||
num_frame_for_scale: int = 1,
|
||||
num_random_frames: int = 0,
|
||||
attend_to_special_tokens: bool = False,
|
||||
attend_to_scale_frames: bool = False,
|
||||
enable_3d_rope: bool = False,
|
||||
max_frame_num: int = 1024,
|
||||
# KV cache parameters
|
||||
kv_cache_sliding_window: int = 64,
|
||||
kv_cache_scale_frames: int = 8,
|
||||
kv_cache_cross_frame_special: bool = True,
|
||||
kv_cache_include_scale_frames: bool = True,
|
||||
kv_cache_camera_only: bool = False,
|
||||
# Base class parameters via **kwargs
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initialize AggregatorStream.
|
||||
|
||||
Args:
|
||||
sliding_window_size: Sliding window size in blocks (-1 for full causal)
|
||||
num_frame_for_scale: Number of scale estimation frames
|
||||
num_random_frames: Number of random frames for long-range dependencies
|
||||
attend_to_special_tokens: Enable cross-frame special token attention
|
||||
attend_to_scale_frames: Include scale frames in attention
|
||||
enable_3d_rope: Enable 3D RoPE for temporal dimension in KV cache
|
||||
max_frame_num: Maximum number of frames for 3D RoPE
|
||||
kv_cache_sliding_window: Sliding window size for KV cache eviction
|
||||
kv_cache_scale_frames: Number of scale frames to keep in KV cache
|
||||
kv_cache_cross_frame_special: Keep special tokens from evicted frames
|
||||
kv_cache_include_scale_frames: Include scale frames in KV cache
|
||||
kv_cache_camera_only: Only keep camera tokens from evicted frames
|
||||
**kwargs: Base class parameters
|
||||
"""
|
||||
self.sliding_window_size = sliding_window_size
|
||||
self.num_frame_for_scale = num_frame_for_scale
|
||||
self.num_random_frames = num_random_frames
|
||||
self.attend_to_special_tokens = attend_to_special_tokens
|
||||
self.attend_to_scale_frames = attend_to_scale_frames
|
||||
self.enable_3d_rope = enable_3d_rope
|
||||
self.max_frame_num = max_frame_num
|
||||
# KV cache parameters
|
||||
self.kv_cache_sliding_window = kv_cache_sliding_window
|
||||
self.kv_cache_scale_frames = kv_cache_scale_frames
|
||||
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
|
||||
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
||||
self.kv_cache_camera_only = kv_cache_camera_only
|
||||
|
||||
# Pop kwargs that are passed but not needed by base class
|
||||
kwargs.pop('enable_stream_inference', None)
|
||||
use_flashinfer = kwargs.pop('use_flashinfer', True)
|
||||
kwargs.pop('use_flexflash', None)
|
||||
use_sdpa = kwargs.pop('use_sdpa', False)
|
||||
|
||||
# Backend selection: SDPA (no extra deps) or FlashInfer (paged KV cache)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.use_flashinfer = not use_sdpa # FlashInfer is default unless SDPA requested
|
||||
|
||||
# Call parent __init__
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Initialize KV cache
|
||||
self._init_kv_cache()
|
||||
|
||||
# Initialize 3D RoPE if enabled
|
||||
if self.enable_3d_rope:
|
||||
self._init_3d_rope()
|
||||
|
||||
def _build_blocks(
|
||||
self,
|
||||
block_fn,
|
||||
depth: int,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float,
|
||||
qkv_bias: bool,
|
||||
proj_bias: bool,
|
||||
ffn_bias: bool,
|
||||
init_values: float,
|
||||
qk_norm: bool,
|
||||
):
|
||||
"""Build frame and global blocks for streaming causal mode."""
|
||||
block_params = dict(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_bias=proj_bias,
|
||||
ffn_bias=ffn_bias,
|
||||
init_values=init_values,
|
||||
qk_norm=qk_norm,
|
||||
)
|
||||
|
||||
# Frame blocks: Standard Block + RoPE
|
||||
self.frame_blocks = nn.ModuleList([
|
||||
block_fn(**block_params, rope=self.rope)
|
||||
for _ in range(depth)
|
||||
])
|
||||
|
||||
# Global blocks: FlashInferBlock (default) or SDPABlock (fallback)
|
||||
GlobalBlockCls = SDPABlock if self.use_sdpa else FlashInferBlock
|
||||
self.global_blocks = nn.ModuleList([
|
||||
GlobalBlockCls(
|
||||
**block_params,
|
||||
rope=self.rope if not self.disable_global_rope else None,
|
||||
kv_cache_sliding_window=self.kv_cache_sliding_window,
|
||||
kv_cache_scale_frames=self.kv_cache_scale_frames,
|
||||
kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
|
||||
kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
|
||||
kv_cache_camera_only=self.kv_cache_camera_only,
|
||||
)
|
||||
for _ in range(depth)
|
||||
])
|
||||
|
||||
def _setup_special_tokens(self):
|
||||
"""Setup camera, register, and scale tokens for causal mode."""
|
||||
# Camera token
|
||||
self.camera_token = nn.Parameter(
|
||||
torch.randn(1, 2, 1, self.embed_dim)
|
||||
)
|
||||
|
||||
# Register tokens
|
||||
if self.num_register_tokens > 0:
|
||||
self.register_token = nn.Parameter(
|
||||
torch.randn(1, 2, self.num_register_tokens, self.embed_dim)
|
||||
)
|
||||
|
||||
# Scale token (causal mode specific)
|
||||
self.scale_token = nn.Parameter(
|
||||
torch.ones(1, 2, 1, self.embed_dim)
|
||||
)
|
||||
|
||||
# Initialize
|
||||
nn.init.normal_(self.camera_token, std=1e-6)
|
||||
if self.num_register_tokens > 0:
|
||||
nn.init.normal_(self.register_token, std=1e-6)
|
||||
nn.init.normal_(self.scale_token, std=1e-6)
|
||||
|
||||
# Token indexing (includes scale token)
|
||||
self.patch_start_idx = 1 + self.num_register_tokens + 1 # camera + register + scale
|
||||
self.num_special_tokens = 1 + self.num_register_tokens + 1
|
||||
|
||||
def _init_kv_cache(self):
|
||||
"""Initialize KV cache for streaming inference."""
|
||||
self.kv_cache_manager = None # FlashInfer (lazy-initialized)
|
||||
self.kv_cache = {} # Dict-based cache for SDPA
|
||||
self.total_frames_processed = 0
|
||||
self._cached_pos3d = None
|
||||
|
||||
if self.use_sdpa:
|
||||
# Dict-based KV cache for SDPA
|
||||
if hasattr(self, 'depth'):
|
||||
for i in range(self.depth):
|
||||
self.kv_cache[f"k_{i}"] = None
|
||||
self.kv_cache[f"v_{i}"] = None
|
||||
self.kv_cache[f"k_{i}_special"] = None
|
||||
self.kv_cache[f"v_{i}_special"] = None
|
||||
logger.info(f"SDPA KV cache initialized with {self.depth} blocks")
|
||||
else:
|
||||
logger.info("FlashInfer KV cache will be lazily initialized on first forward")
|
||||
|
||||
def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None):
|
||||
"""Lazily initialize FlashInferKVCacheManager on first use.
|
||||
|
||||
Args:
|
||||
device: Device for cache tensors.
|
||||
dtype: Data type for cache tensors.
|
||||
tokens_per_frame: Actual number of tokens per frame (patches + specials).
|
||||
If None, falls back to assuming square images of self.img_size.
|
||||
"""
|
||||
if self.kv_cache_manager is None:
|
||||
from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager
|
||||
num_heads = self.embed_dim // 64 # head_dim = 64 for ViT-L
|
||||
head_dim = 64
|
||||
if tokens_per_frame is None:
|
||||
tokens_per_frame = (self.img_size // self.patch_size) ** 2 + self.num_special_tokens
|
||||
# max_num_frames: scale + window + headroom
|
||||
max_num_frames = self.kv_cache_scale_frames + self.kv_cache_sliding_window + 16
|
||||
self.kv_cache_manager = FlashInferKVCacheManager(
|
||||
num_blocks=self.depth,
|
||||
max_num_frames=max_num_frames,
|
||||
tokens_per_frame=tokens_per_frame,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
num_special_tokens=self.num_special_tokens,
|
||||
scale_frames=self.kv_cache_scale_frames,
|
||||
sliding_window=self.kv_cache_sliding_window,
|
||||
max_total_frames=self.max_frame_num + 100,
|
||||
force_fp32=getattr(self, 'kv_cache_force_fp32', False),
|
||||
fa3=getattr(self, 'kv_cache_fa3', False),
|
||||
)
|
||||
logger.info(
|
||||
f"FlashInfer KV cache manager initialized: {self.depth} blocks, "
|
||||
f"max_frames={max_num_frames}, tokens_per_frame={tokens_per_frame}"
|
||||
)
|
||||
return self.kv_cache_manager
|
||||
|
||||
def clean_kv_cache(self):
|
||||
"""Clean KV cache (call this when starting a new sequence)."""
|
||||
if self.kv_cache_manager is not None:
|
||||
self.kv_cache_manager.reset()
|
||||
if self.kv_cache:
|
||||
for key in list(self.kv_cache.keys()):
|
||||
if key == "_skip_append":
|
||||
self.kv_cache[key] = False
|
||||
else:
|
||||
self.kv_cache[key] = None
|
||||
self.total_frames_processed = 0
|
||||
self._cached_pos3d = None
|
||||
logger.info("KV cache cleaned")
|
||||
|
||||
def _init_3d_rope(self):
|
||||
"""Initialize 3D RoPE for streaming inference."""
|
||||
if not self.enable_3d_rope:
|
||||
self.rope3d = None
|
||||
return
|
||||
|
||||
num_heads = 16
|
||||
head_dim = self.embed_dim // num_heads
|
||||
|
||||
self.rope3d = WanRotaryPosEmbed(
|
||||
attention_head_dim=head_dim,
|
||||
patch_size=(1, self.patch_size, self.patch_size),
|
||||
max_seq_len=self.max_frame_num,
|
||||
)
|
||||
logger.info(f"3D RoPE initialized for max {self.max_frame_num} frames, head_dim={head_dim}")
|
||||
|
||||
def _get_3d_positions_streaming(self, num_frames, H, W, device, f_start, f_end):
|
||||
"""
|
||||
Generate 3D RoPE positions for streaming mode with correct global frame indices.
|
||||
|
||||
Args:
|
||||
num_frames: Number of frames in current batch
|
||||
H, W: Image height and width
|
||||
device: Device to create positions on
|
||||
f_start: Global start frame index
|
||||
f_end: Global end frame index
|
||||
|
||||
Returns:
|
||||
pos3d: [1, 1, num_frames * P, head_dim//2] complex tensor
|
||||
"""
|
||||
if self.rope3d is None:
|
||||
return None
|
||||
|
||||
pph = H // self.patch_size
|
||||
ppw = W // self.patch_size
|
||||
|
||||
pos3d = self.rope3d(
|
||||
ppf=num_frames,
|
||||
pph=pph,
|
||||
ppw=ppw,
|
||||
patch_start_idx=self.num_special_tokens,
|
||||
device=device,
|
||||
f_start=f_start,
|
||||
f_end=f_end
|
||||
)
|
||||
return pos3d
|
||||
|
||||
def _prepare_special_tokens(
|
||||
self,
|
||||
B: int,
|
||||
S_local: int,
|
||||
S_global: int,
|
||||
C: int,
|
||||
num_frame_for_scale: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Prepare camera, register, and scale tokens.
|
||||
|
||||
Args:
|
||||
B: Batch size
|
||||
S_local: Local sequence length
|
||||
S_global: Global sequence length
|
||||
C: Embedding dimension
|
||||
num_frame_for_scale: Number of frames for scale estimation
|
||||
|
||||
Returns:
|
||||
Special tokens [B*S_global, N_special, C]
|
||||
"""
|
||||
# Get effective num_frame_for_scale
|
||||
scale_frames = self.num_frame_for_scale if num_frame_for_scale is None else num_frame_for_scale
|
||||
|
||||
# Check cache state for both backends
|
||||
has_flashinfer_cache = self.kv_cache_manager is not None and self.kv_cache_manager.num_frames > 0
|
||||
has_sdpa_cache = self.kv_cache is not None and self.kv_cache.get("k_0") is not None
|
||||
|
||||
# Determine if we're in causal inference mode based on KV cache state
|
||||
causal_inference = True
|
||||
|
||||
if causal_inference and has_flashinfer_cache:
|
||||
S_cached = self.kv_cache_manager.num_frames
|
||||
S_true = S_cached + S_global
|
||||
elif causal_inference and has_sdpa_cache:
|
||||
_, _, S_cached, _, _ = self.kv_cache["k_0"].shape
|
||||
S_true = S_cached + S_global
|
||||
else:
|
||||
S_true = S_global
|
||||
|
||||
# Expand tokens based on mode
|
||||
if causal_inference and S_true > S_global:
|
||||
# Streaming mode: expand with S_true, then slice to get current frames
|
||||
effective_scale_frames = min(scale_frames, S_true)
|
||||
|
||||
camera_token_full = slice_expand_and_flatten(self.camera_token, B, S_true)
|
||||
camera_token = camera_token_full[-S_global:, :, :]
|
||||
|
||||
register_token_full = slice_expand_and_flatten(self.register_token, B, S_true)
|
||||
register_token = register_token_full[-S_global:, :, :]
|
||||
scale_token_full = slice_expand_and_flatten(
|
||||
self.scale_token, B, S_true, first_num_frame=effective_scale_frames
|
||||
)
|
||||
scale_token = scale_token_full[-S_global:, :, :]
|
||||
else:
|
||||
# Batch mode or first inference: expand directly
|
||||
effective_scale_frames = min(scale_frames, S_global)
|
||||
|
||||
camera_token = slice_expand_and_flatten(self.camera_token, B, S_global)
|
||||
register_token = slice_expand_and_flatten(self.register_token, B, S_global)
|
||||
scale_token = slice_expand_and_flatten(
|
||||
self.scale_token, B, S_global, first_num_frame=effective_scale_frames
|
||||
)
|
||||
|
||||
special_tokens = torch.cat([camera_token, register_token, scale_token], dim=1)
|
||||
|
||||
# Verify shape
|
||||
expected_shape = (B * S_global, self.num_special_tokens, C)
|
||||
assert special_tokens.shape == expected_shape, \
|
||||
f"Expected {expected_shape}, got {special_tokens.shape}"
|
||||
|
||||
return special_tokens
|
||||
|
||||
def _process_global_attention(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
B: int,
|
||||
S_local: int,
|
||||
S_global: int,
|
||||
P: int,
|
||||
C: int,
|
||||
global_idx: int,
|
||||
pos: Optional[torch.Tensor] = None,
|
||||
# Mode-specific parameters
|
||||
num_frame_for_scale: Optional[int] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
num_frame_per_block: int = 1,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
|
||||
"""
|
||||
Process causal global attention via FlashInfer streaming path.
|
||||
|
||||
Args:
|
||||
tokens: Input tokens
|
||||
B: Batch size
|
||||
S_local: Local sequence length
|
||||
S_global: Global sequence length
|
||||
P: Tokens per frame
|
||||
C: Embedding dimension
|
||||
global_idx: Current global block index
|
||||
pos: Position embeddings
|
||||
num_frame_for_scale: Number of frames for scale estimation
|
||||
sliding_window_size: Sliding window size in blocks
|
||||
num_frame_per_block: Number of frames per processing block
|
||||
|
||||
Returns:
|
||||
(tokens, global_idx, intermediates)
|
||||
"""
|
||||
# Extract image dimensions from kwargs for 3D RoPE
|
||||
image_height = kwargs.get('image_height', self.img_size)
|
||||
image_width = kwargs.get('image_width', self.img_size)
|
||||
|
||||
return self._process_causal_stream(
|
||||
tokens, B, S_local, S_global, P, C, global_idx, pos,
|
||||
num_frame_per_block, sliding_window_size, num_frame_for_scale,
|
||||
image_height=image_height, image_width=image_width
|
||||
)
|
||||
|
||||
def _process_causal_stream(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
B: int,
|
||||
S_local: int,
|
||||
S_global: int,
|
||||
P: int,
|
||||
C: int,
|
||||
global_idx: int,
|
||||
pos: Optional[torch.Tensor] = None,
|
||||
num_frame_per_block: int = 1,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
num_frame_for_scale: Optional[int] = None,
|
||||
image_height: Optional[int] = None,
|
||||
image_width: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Causal attention for streaming inference using FlashInfer KV cache.
|
||||
|
||||
Args:
|
||||
tokens: Input tokens [B*S_local, P, C]
|
||||
B: Batch size
|
||||
S_local: Local sequence length
|
||||
S_global: Global sequence length
|
||||
P: Number of patches per frame (includes special tokens)
|
||||
C: Channel dimension
|
||||
global_idx: Starting block index
|
||||
pos: Position embeddings [B*S_global, P, 2]
|
||||
num_frame_per_block: Number of frames per block
|
||||
sliding_window_size: Sliding window size in blocks
|
||||
num_frame_for_scale: Number of scale frames
|
||||
image_height: Image height for 3D RoPE calculation
|
||||
image_width: Image width for 3D RoPE calculation
|
||||
|
||||
Returns:
|
||||
(tokens, global_idx, intermediates): Updated tokens, next block index, intermediate outputs
|
||||
"""
|
||||
# Get effective parameters
|
||||
scale_frames = num_frame_for_scale if num_frame_for_scale is not None else self.num_frame_for_scale
|
||||
|
||||
# Reshape tokens: [B*S_local, P, C] -> [B, S_local*P, C]
|
||||
if tokens.shape != (B, S_local * P, C):
|
||||
tokens = tokens.view(B, S_local, P, C).view(B, S_local * P, C)
|
||||
|
||||
# Calculate number of frames for block mask
|
||||
num_frames = S_global
|
||||
num_patches = P - self.num_special_tokens
|
||||
|
||||
# Check if this is the first block group
|
||||
is_first_block_group = (global_idx < self.aa_block_size)
|
||||
|
||||
if self.enable_3d_rope and self.rope3d is not None:
|
||||
if is_first_block_group:
|
||||
f_start = self.total_frames_processed
|
||||
f_end = self.total_frames_processed + S_global
|
||||
|
||||
H = image_height if image_height is not None else self.img_size
|
||||
W = image_width if image_width is not None else self.img_size
|
||||
pos3d = self._get_3d_positions_streaming(
|
||||
S_global, H, W, tokens.device, f_start, f_end
|
||||
)
|
||||
self._cached_pos3d = pos3d
|
||||
else:
|
||||
pos3d = self._cached_pos3d
|
||||
pos = pos3d
|
||||
else:
|
||||
# Reshape pos: [B*S_global, P, 2] -> [B, S_global*P, 2]
|
||||
if pos is not None and pos.shape != (B, S_global * P, 2):
|
||||
pos = pos.view(B, S_global, P, 2).view(B, S_global * P, 2)
|
||||
|
||||
intermediates = []
|
||||
|
||||
# Process blocks with KV cache
|
||||
for _ in range(self.aa_block_size):
|
||||
num_patches = P - self.num_special_tokens
|
||||
if self.use_sdpa:
|
||||
# SDPA: dict-based KV cache
|
||||
tokens = self.global_blocks[global_idx](
|
||||
tokens,
|
||||
pos=pos,
|
||||
enable_ulysses_cp=False,
|
||||
num_patches=num_patches,
|
||||
num_special=self.num_special_tokens,
|
||||
num_frames=num_frames,
|
||||
enable_3d_rope=self.enable_3d_rope,
|
||||
kv_cache=self.kv_cache,
|
||||
global_idx=global_idx,
|
||||
num_frame_per_block=num_frame_per_block,
|
||||
num_frame_for_scale=scale_frames,
|
||||
num_register_tokens=self.num_register_tokens,
|
||||
)
|
||||
else:
|
||||
# FlashInfer: paged KV cache manager
|
||||
manager = self._get_flashinfer_manager(tokens.device, tokens.dtype, tokens_per_frame=P)
|
||||
tokens = self.global_blocks[global_idx](
|
||||
tokens,
|
||||
pos=pos,
|
||||
enable_ulysses_cp=False,
|
||||
num_patches=num_patches,
|
||||
num_special=self.num_special_tokens,
|
||||
num_frames=num_frames,
|
||||
enable_3d_rope=self.enable_3d_rope,
|
||||
kv_cache=manager,
|
||||
global_idx=global_idx,
|
||||
num_frame_per_block=num_frame_per_block,
|
||||
num_frame_for_scale=scale_frames,
|
||||
num_register_tokens=self.num_register_tokens,
|
||||
)
|
||||
|
||||
global_idx += 1
|
||||
intermediates.append(tokens.view(B, S_local, P, C))
|
||||
|
||||
# Update total frames processed counter only on the first block group
|
||||
if is_first_block_group and not (isinstance(self.kv_cache, dict) and self.kv_cache.get("_skip_append", False)):
|
||||
self.total_frames_processed += S_global
|
||||
|
||||
return tokens, global_idx, intermediates
|
||||
Reference in New Issue
Block a user