first commit

This commit is contained in:
LinZhuoChen
2026-04-16 09:51:30 +08:00
commit f9b3ae457a
44 changed files with 11994 additions and 0 deletions

View File

@@ -0,0 +1,444 @@
"""
GCTStream - Streaming GCT with KV cache for online inference.
Provides streaming inference functionality:
- Temporal causal attention with KV cache
- Sliding window support
- Efficient frame-by-frame processing
- 3D RoPE support for temporal consistency
"""
import logging
import torch
import torch.nn as nn
from typing import Optional, Dict, Any, List
from tqdm.auto import tqdm
from lingbot_map.heads.camera_head import CameraCausalHead
from lingbot_map.models.gct_base import GCTBase
from lingbot_map.aggregator.stream import AggregatorStream
logger = logging.getLogger(__name__)
class GCTStream(GCTBase):
"""
Streaming GCT model with KV cache for efficient online inference.
Features:
- AggregatorStream with KV cache support (FlashInfer backend)
- CameraCausalHead for pose refinement
- Sliding window attention for memory efficiency
- Frame-by-frame streaming inference
"""
def __init__(
self,
# Architecture parameters
img_size: int = 518,
patch_size: int = 14,
embed_dim: int = 1024,
patch_embed: str = 'dinov2_vitl14_reg',
pretrained_path: str = '',
disable_global_rope: bool = False,
# Head configuration
enable_camera: bool = True,
enable_point: bool = True,
enable_local_point: bool = False,
enable_depth: bool = True,
enable_track: bool = False,
# Normalization
enable_normalize: bool = False,
# Prediction normalization
pred_normalization: bool = False,
# Stream-specific parameters
sliding_window_size: int = -1,
num_frame_for_scale: int = 1,
num_random_frames: int = 0,
attend_to_special_tokens: bool = False,
attend_to_scale_frames: bool = False,
enable_stream_inference: bool = True, # Default to True for streaming
enable_3d_rope: bool = False,
max_frame_num: int = 1024,
# Camera head 3D RoPE (separate from aggregator 3D RoPE)
enable_camera_3d_rope: bool = False,
camera_rope_theta: float = 10000.0,
# Scale token configuration (kept for checkpoint compat, ignored)
use_scale_token: bool = True,
# KV cache parameters
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
# Backend selection
use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer
# Gradient checkpointing
use_gradient_checkpoint: bool = True,
):
"""
Initialize GCTStream.
Args:
img_size: Input image size
patch_size: Patch size for embedding
embed_dim: Embedding dimension
patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.)
pretrained_path: Path to pretrained DINOv2 weights
disable_global_rope: Disable RoPE in global attention
enable_camera/point/depth/track: Enable prediction heads
enable_normalize: Enable normalization
sliding_window_size: Sliding window size in blocks (-1 for full causal)
num_frame_for_scale: Number of scale estimation frames
num_random_frames: Number of random frames for long-range dependencies
attend_to_special_tokens: Enable cross-frame special token attention
attend_to_scale_frames: Whether to attend to scale frames
enable_stream_inference: Enable streaming inference with KV cache
enable_3d_rope: Enable 3D RoPE for temporal consistency
max_frame_num: Maximum number of frames for 3D RoPE
use_scale_token: Kept for checkpoint compatibility, ignored
kv_cache_sliding_window: Sliding window size for KV cache eviction
kv_cache_scale_frames: Number of scale frames to keep in KV cache
kv_cache_cross_frame_special: Keep special tokens from evicted frames
kv_cache_include_scale_frames: Include scale frames in KV cache
kv_cache_camera_only: Only keep camera tokens from evicted frames
"""
# Store stream-specific parameters before calling super().__init__()
self.pretrained_path = pretrained_path
self.sliding_window_size = sliding_window_size
self.num_frame_for_scale = num_frame_for_scale
self.num_random_frames = num_random_frames
self.attend_to_special_tokens = attend_to_special_tokens
self.attend_to_scale_frames = attend_to_scale_frames
self.enable_stream_inference = enable_stream_inference
self.enable_3d_rope = enable_3d_rope
self.max_frame_num = max_frame_num
# Camera head 3D RoPE settings
self.enable_camera_3d_rope = enable_camera_3d_rope
self.camera_rope_theta = camera_rope_theta
# KV cache parameters
self.kv_cache_sliding_window = kv_cache_sliding_window
self.kv_cache_scale_frames = kv_cache_scale_frames
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
self.kv_cache_camera_only = kv_cache_camera_only
self.use_sdpa = use_sdpa
# Call base class __init__ (will call _build_aggregator)
super().__init__(
img_size=img_size,
patch_size=patch_size,
embed_dim=embed_dim,
patch_embed=patch_embed,
disable_global_rope=disable_global_rope,
enable_camera=enable_camera,
enable_point=enable_point,
enable_local_point=enable_local_point,
enable_depth=enable_depth,
enable_track=enable_track,
enable_normalize=enable_normalize,
pred_normalization=pred_normalization,
enable_3d_rope=enable_3d_rope,
use_gradient_checkpoint=use_gradient_checkpoint,
)
def _build_aggregator(self) -> nn.Module:
"""
Build streaming aggregator with KV cache support (FlashInfer backend).
Returns:
AggregatorStream module
"""
return AggregatorStream(
img_size=self.img_size,
patch_size=self.patch_size,
embed_dim=self.embed_dim,
patch_embed=self.patch_embed,
pretrained_path=self.pretrained_path,
disable_global_rope=self.disable_global_rope,
sliding_window_size=self.sliding_window_size,
num_frame_for_scale=self.num_frame_for_scale,
num_random_frames=self.num_random_frames,
attend_to_special_tokens=self.attend_to_special_tokens,
attend_to_scale_frames=self.attend_to_scale_frames,
enable_stream_inference=self.enable_stream_inference,
enable_3d_rope=self.enable_3d_rope,
max_frame_num=self.max_frame_num,
# Backend: FlashInfer (default) or SDPA (fallback)
use_flashinfer=not self.use_sdpa,
use_sdpa=self.use_sdpa,
kv_cache_sliding_window=self.kv_cache_sliding_window,
kv_cache_scale_frames=self.kv_cache_scale_frames,
kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
kv_cache_camera_only=self.kv_cache_camera_only,
use_gradient_checkpoint=self.use_gradient_checkpoint,
)
def _build_camera_head(self) -> nn.Module:
"""
Build causal camera head for streaming inference.
Returns:
CameraCausalHead module or None
"""
return CameraCausalHead(
dim_in=2 * self.embed_dim,
sliding_window_size=self.sliding_window_size,
attend_to_scale_frames=self.attend_to_scale_frames,
# KV cache parameters
kv_cache_sliding_window=self.kv_cache_sliding_window,
kv_cache_scale_frames=self.kv_cache_scale_frames,
kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
kv_cache_camera_only=self.kv_cache_camera_only,
# Camera head 3D RoPE parameters
enable_3d_rope=self.enable_camera_3d_rope,
max_frame_num=self.max_frame_num,
rope_theta=self.camera_rope_theta,
)
def _aggregate_features(
self,
images: torch.Tensor,
num_frame_for_scale: Optional[int] = None,
sliding_window_size: Optional[int] = None,
num_frame_per_block: int = 1,
**kwargs,
) -> tuple:
"""
Run aggregator to get multi-scale features.
Args:
images: Input images [B, S, 3, H, W]
num_frame_for_scale: Number of frames for scale estimation
sliding_window_size: Override sliding window size
num_frame_per_block: Number of frames per block
Returns:
(aggregated_tokens_list, patch_start_idx)
"""
aggregated_tokens_list, patch_start_idx = self.aggregator(
images,
selected_idx=[4, 11, 17, 23],
num_frame_for_scale=num_frame_for_scale,
sliding_window_size=sliding_window_size,
num_frame_per_block=num_frame_per_block,
)
return aggregated_tokens_list, patch_start_idx
def clean_kv_cache(self):
"""
Clean KV cache in aggregator.
Call this method when starting a new video sequence to clear
cached key-value pairs from previous sequences.
"""
if hasattr(self.aggregator, 'clean_kv_cache'):
self.aggregator.clean_kv_cache()
else:
logger.warning("Aggregator does not support KV cache cleaning")
if hasattr(self.camera_head, 'kv_cache'):
self.camera_head.clean_kv_cache()
else:
logger.warning("Camera head does not support KV cache cleaning")
def _set_skip_append(self, skip: bool):
"""Set _skip_append flag on all KV caches (aggregator + camera head).
When skip=True, attention layers will attend to [cached_kv + current_kv]
but will NOT store the current frame's KV in cache. This is used for
non-keyframe processing in keyframe-based streaming inference.
Args:
skip: If True, subsequent forward passes will not append KV to cache.
"""
if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None:
self.aggregator.kv_cache["_skip_append"] = skip
if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
for cache_dict in self.camera_head.kv_cache:
cache_dict["_skip_append"] = skip
def get_kv_cache_info(self) -> Dict[str, Any]:
"""
Get information about current KV cache state.
Returns:
Dictionary with cache statistics:
- num_cached_blocks: Number of blocks with cached KV
- cache_memory_mb: Approximate memory usage in MB
"""
if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None:
return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}
kv_cache = self.aggregator.kv_cache
num_cached = sum(1 for k in kv_cache.keys() if k.startswith('k_') and not k.endswith('_special'))
# Estimate memory usage
total_elements = 0
for _, v in kv_cache.items():
if v is not None and torch.is_tensor(v):
total_elements += v.numel()
# Assume bfloat16 (2 bytes per element)
cache_memory_mb = (total_elements * 2) / (1024 * 1024)
return {
"num_cached_blocks": num_cached,
"cache_memory_mb": round(cache_memory_mb, 2)
}
@torch.no_grad()
def inference_streaming(
self,
images: torch.Tensor,
num_scale_frames: Optional[int] = None,
keyframe_interval: int = 1,
output_device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
"""
Streaming inference: process scale frames first, then frame-by-frame.
This method enables efficient online inference by:
1. Processing initial scale frames together (bidirectional attention via scale token)
2. Processing remaining frames one-by-one with KV cache (causal streaming)
Keyframe mode (keyframe_interval > 1):
- Every keyframe_interval-th frame (after scale frames) is a keyframe
- Keyframes: KV is stored in cache (normal behavior)
- Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard)
- All frames produce full predictions regardless of keyframe status
- Reduces KV cache memory growth by ~1/keyframe_interval
Args:
images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
num_scale_frames: Number of initial frames for scale estimation.
If None, uses self.num_frame_for_scale.
keyframe_interval: Every N-th frame (after scale frames) is a keyframe
whose KV persists in cache. 1 = every frame is a
keyframe (default, same as original behavior).
output_device: Device to store output predictions on. If None, keeps on
the same device as the model. Set to torch.device('cpu')
to offload predictions per-frame and avoid GPU OOM on
long sequences.
Returns:
Dictionary containing predictions for all frames:
- pose_enc: [B, S, 9]
- depth: [B, S, H, W, 1]
- depth_conf: [B, S, H, W]
- world_points: [B, S, H, W, 3]
- world_points_conf: [B, S, H, W]
"""
# Normalize input shape
if len(images.shape) == 4:
images = images.unsqueeze(0)
B, S, C, H, W = images.shape
# Determine number of scale frames
scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale
scale_frames = min(scale_frames, S) # Cap to available frames
# Helper to move tensor to output device
def _to_out(t: torch.Tensor) -> torch.Tensor:
if output_device is not None:
return t.to(output_device)
return t
# Clean KV caches before starting new sequence
self.clean_kv_cache()
# Phase 1: Process scale frames together
# These frames get bidirectional attention among themselves via scale token
logger.info(f'Processing {scale_frames} scale frames...')
scale_images = images[:, :scale_frames]
scale_output = self.forward(
scale_images,
num_frame_for_scale=scale_frames,
num_frame_per_block=scale_frames, # Process all scale frames as one block
causal_inference=True,
)
# Initialize output lists with scale frame predictions (offload if needed)
all_pose_enc = [_to_out(scale_output["pose_enc"])]
all_depth = [_to_out(scale_output["depth"])] if "depth" in scale_output else []
all_depth_conf = [_to_out(scale_output["depth_conf"])] if "depth_conf" in scale_output else []
all_world_points = [_to_out(scale_output["world_points"])] if "world_points" in scale_output else []
all_world_points_conf = [_to_out(scale_output["world_points_conf"])] if "world_points_conf" in scale_output else []
del scale_output
# Phase 2: Process remaining frames one-by-one
pbar = tqdm(
range(scale_frames, S),
desc='Streaming inference',
initial=scale_frames,
total=S,
)
for i in pbar:
frame_image = images[:, i:i+1]
# Determine if this frame is a keyframe
is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0)
if not is_keyframe:
self._set_skip_append(True)
frame_output = self.forward(
frame_image,
num_frame_for_scale=scale_frames, # Keep same for scale token logic
num_frame_per_block=1, # Single frame per block
causal_inference=True,
)
if not is_keyframe:
self._set_skip_append(False)
all_pose_enc.append(_to_out(frame_output["pose_enc"]))
if "depth" in frame_output:
all_depth.append(_to_out(frame_output["depth"]))
if "depth_conf" in frame_output:
all_depth_conf.append(_to_out(frame_output["depth_conf"]))
if "world_points" in frame_output:
all_world_points.append(_to_out(frame_output["world_points"]))
if "world_points_conf" in frame_output:
all_world_points_conf.append(_to_out(frame_output["world_points_conf"]))
del frame_output
# Free GPU memory before concatenation
if output_device is not None:
# Move images to output device, then free GPU copy
images_out = _to_out(images)
del images
# Clean KV cache (no longer needed after inference)
self.clean_kv_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
images_out = images
# Concatenate all predictions along sequence dimension
predictions = {
"pose_enc": torch.cat(all_pose_enc, dim=1),
}
del all_pose_enc
if all_depth:
predictions["depth"] = torch.cat(all_depth, dim=1)
del all_depth
if all_depth_conf:
predictions["depth_conf"] = torch.cat(all_depth_conf, dim=1)
del all_depth_conf
if all_world_points:
predictions["world_points"] = torch.cat(all_world_points, dim=1)
del all_world_points
if all_world_points_conf:
predictions["world_points_conf"] = torch.cat(all_world_points_conf, dim=1)
del all_world_points_conf
# Store images for visualization
predictions["images"] = images_out
# Apply prediction normalization if enabled
if self.pred_normalization:
predictions = self._normalize_predictions(predictions)
return predictions