445 lines
18 KiB
Python
445 lines
18 KiB
Python
"""
GCTStream - Streaming GCT with KV cache for online inference.

Provides streaming inference functionality:
    - Temporal causal attention with KV cache
    - Sliding window support
    - Efficient frame-by-frame processing
    - 3D RoPE support for temporal consistency
"""
|
|
|
|
import logging
|
|
import torch
|
|
import torch.nn as nn
|
|
from typing import Optional, Dict, Any, List
|
|
from tqdm.auto import tqdm
|
|
|
|
from lingbot_map.heads.camera_head import CameraCausalHead
|
|
from lingbot_map.models.gct_base import GCTBase
|
|
from lingbot_map.aggregator.stream import AggregatorStream
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GCTStream(GCTBase):
    """
    Streaming GCT model with KV cache for efficient online inference.

    Features:
        - AggregatorStream with KV cache support (FlashInfer backend by
          default, PyTorch SDPA when ``use_sdpa=True``)
        - CameraCausalHead for pose refinement
        - Sliding window attention for memory efficiency
        - Frame-by-frame streaming inference
    """

    # Aggregator layer indices whose outputs are handed to the prediction heads.
    _SELECTED_LAYER_IDX = (4, 11, 17, 23)

    # Prediction keys collected per frame by ``inference_streaming``, in output order.
    _STREAM_OUTPUT_KEYS = (
        "pose_enc",
        "depth",
        "depth_conf",
        "world_points",
        "world_points_conf",
    )

    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        pretrained_path: str = '',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        # Stream-specific parameters
        sliding_window_size: int = -1,
        num_frame_for_scale: int = 1,
        num_random_frames: int = 0,
        attend_to_special_tokens: bool = False,
        attend_to_scale_frames: bool = False,
        enable_stream_inference: bool = True,  # Default to True for streaming
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        # Camera head 3D RoPE (separate from aggregator 3D RoPE)
        enable_camera_3d_rope: bool = False,
        camera_rope_theta: float = 10000.0,
        # Scale token configuration (kept for checkpoint compat, ignored)
        use_scale_token: bool = True,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # Backend selection
        use_sdpa: bool = False,  # If True, use SDPA (no flashinfer needed); default: FlashInfer
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
    ):
        """
        Initialize GCTStream.

        Args:
            img_size: Input image size.
            patch_size: Patch size for embedding.
            embed_dim: Embedding dimension.
            patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.).
            pretrained_path: Path to pretrained DINOv2 weights (passed to the aggregator).
            disable_global_rope: Disable RoPE in global attention.
            enable_camera/point/local_point/depth/track: Enable the corresponding
                prediction heads (forwarded to GCTBase).
            enable_normalize: Enable normalization (forwarded to GCTBase).
            pred_normalization: If True, ``inference_streaming`` passes its outputs
                through ``self._normalize_predictions`` (also forwarded to GCTBase).
            sliding_window_size: Sliding window size in blocks (-1 for full causal).
            num_frame_for_scale: Number of scale estimation frames.
            num_random_frames: Number of random frames for long-range dependencies.
            attend_to_special_tokens: Enable cross-frame special token attention.
            attend_to_scale_frames: Whether to attend to scale frames.
            enable_stream_inference: Enable streaming inference with KV cache.
            enable_3d_rope: Enable 3D RoPE in the aggregator for temporal consistency.
            max_frame_num: Maximum number of frames for 3D RoPE (aggregator and camera head).
            enable_camera_3d_rope: Enable 3D RoPE inside the camera head
                (independent of ``enable_3d_rope``).
            camera_rope_theta: Passed as ``rope_theta`` to CameraCausalHead.
            use_scale_token: Kept for checkpoint compatibility, ignored.
            kv_cache_sliding_window: Sliding window size for KV cache eviction.
            kv_cache_scale_frames: Number of scale frames to keep in KV cache.
            kv_cache_cross_frame_special: Keep special tokens from evicted frames.
            kv_cache_include_scale_frames: Include scale frames in KV cache.
            kv_cache_camera_only: Only keep camera tokens from evicted frames.
            use_sdpa: If True, the aggregator uses SDPA instead of FlashInfer.
            use_gradient_checkpoint: Enable gradient checkpointing (forwarded to
                GCTBase; the aggregator reads it back from ``self``).
        """
        # These must exist before super().__init__(), which builds the
        # aggregator / camera head via _build_aggregator / _build_camera_head.
        self.pretrained_path = pretrained_path
        self.sliding_window_size = sliding_window_size
        self.num_frame_for_scale = num_frame_for_scale
        self.num_random_frames = num_random_frames
        self.attend_to_special_tokens = attend_to_special_tokens
        self.attend_to_scale_frames = attend_to_scale_frames
        self.enable_stream_inference = enable_stream_inference
        self.enable_3d_rope = enable_3d_rope
        self.max_frame_num = max_frame_num
        # Camera head 3D RoPE settings
        self.enable_camera_3d_rope = enable_camera_3d_rope
        self.camera_rope_theta = camera_rope_theta
        # KV cache parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
        self.use_sdpa = use_sdpa

        # Call base class __init__ (will call _build_aggregator)
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            patch_embed=patch_embed,
            disable_global_rope=disable_global_rope,
            enable_camera=enable_camera,
            enable_point=enable_point,
            enable_local_point=enable_local_point,
            enable_depth=enable_depth,
            enable_track=enable_track,
            enable_normalize=enable_normalize,
            pred_normalization=pred_normalization,
            enable_3d_rope=enable_3d_rope,
            use_gradient_checkpoint=use_gradient_checkpoint,
        )

    def _build_aggregator(self) -> nn.Module:
        """
        Build the streaming aggregator with KV cache support.

        The attention backend is FlashInfer unless ``self.use_sdpa`` is True,
        in which case SDPA is used instead.

        Returns:
            AggregatorStream module.
        """
        return AggregatorStream(
            img_size=self.img_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            patch_embed=self.patch_embed,
            pretrained_path=self.pretrained_path,
            disable_global_rope=self.disable_global_rope,
            sliding_window_size=self.sliding_window_size,
            num_frame_for_scale=self.num_frame_for_scale,
            num_random_frames=self.num_random_frames,
            attend_to_special_tokens=self.attend_to_special_tokens,
            attend_to_scale_frames=self.attend_to_scale_frames,
            enable_stream_inference=self.enable_stream_inference,
            enable_3d_rope=self.enable_3d_rope,
            max_frame_num=self.max_frame_num,
            # Backend: FlashInfer (default) or SDPA (fallback)
            use_flashinfer=not self.use_sdpa,
            use_sdpa=self.use_sdpa,
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            use_gradient_checkpoint=self.use_gradient_checkpoint,
        )

    def _build_camera_head(self) -> nn.Module:
        """
        Build the causal camera head for streaming inference.

        This method always constructs a head. Whether it is called at all is
        decided by the caller (presumably GCTBase, based on ``enable_camera``
        — verify there).

        Returns:
            CameraCausalHead module.
        """
        return CameraCausalHead(
            dim_in=2 * self.embed_dim,
            sliding_window_size=self.sliding_window_size,
            attend_to_scale_frames=self.attend_to_scale_frames,
            # KV cache parameters
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            # Camera head 3D RoPE parameters
            enable_3d_rope=self.enable_camera_3d_rope,
            max_frame_num=self.max_frame_num,
            rope_theta=self.camera_rope_theta,
        )

    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        **kwargs,
    ) -> tuple:
        """
        Run the aggregator to get multi-scale features.

        Args:
            images: Input images [B, S, 3, H, W].
            num_frame_for_scale: Number of frames for scale estimation.
            sliding_window_size: Override sliding window size.
            num_frame_per_block: Number of frames per block.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            (aggregated_tokens_list, patch_start_idx) as returned by the aggregator,
            using the layers listed in ``_SELECTED_LAYER_IDX``.
        """
        aggregated_tokens_list, patch_start_idx = self.aggregator(
            images,
            # Fresh list per call so the aggregator can never mutate a shared default.
            selected_idx=list(self._SELECTED_LAYER_IDX),
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )
        return aggregated_tokens_list, patch_start_idx

    def clean_kv_cache(self):
        """
        Clear the KV caches of the aggregator and (if present) the camera head.

        Call this when starting a new video sequence so that cached key-value
        pairs from previous sequences are not attended to. A warning is logged
        for any existing component that has no ``clean_kv_cache`` method. A
        disabled camera head (``None``) is skipped silently.
        """
        aggregator_clean = getattr(self.aggregator, 'clean_kv_cache', None)
        if callable(aggregator_clean):
            aggregator_clean()
        else:
            logger.warning("Aggregator does not support KV cache cleaning")

        camera_head = getattr(self, 'camera_head', None)
        if camera_head is None:
            # Camera head disabled: there is no cache to clean.
            return
        # Check for the method we actually call (previously this checked for a
        # `kv_cache` attribute, which could still fail with AttributeError).
        camera_clean = getattr(camera_head, 'clean_kv_cache', None)
        if callable(camera_clean):
            camera_clean()
        else:
            logger.warning("Camera head does not support KV cache cleaning")

    def _set_skip_append(self, skip: bool):
        """Set the ``_skip_append`` flag on all KV caches (aggregator + camera head).

        When skip=True, attention layers will attend to [cached_kv + current_kv]
        but will NOT store the current frame's KV in cache. This is used for
        non-keyframe processing in keyframe-based streaming inference.

        The aggregator's cache is treated as a single mapping. The camera head's
        cache is iterated as a sequence of mappings, and each one gets the flag.

        Args:
            skip: If True, subsequent forward passes will not append KV to cache.
        """
        aggregator_cache = getattr(self.aggregator, 'kv_cache', None)
        if aggregator_cache is not None:
            aggregator_cache["_skip_append"] = skip

        # getattr(None, ..., None) is None, so a disabled camera head is skipped.
        camera_cache = getattr(self.camera_head, 'kv_cache', None)
        if camera_cache is not None:
            for cache_dict in camera_cache:
                cache_dict["_skip_append"] = skip

    def get_kv_cache_info(self) -> Dict[str, Any]:
        """
        Get information about the aggregator's current KV cache state.

        Returns:
            Dictionary with cache statistics:
                - num_cached_blocks: number of string cache keys that start with
                  'k_' and do not end with '_special'
                - cache_memory_mb: total size in MB of all tensor values in the
                  aggregator cache, using each tensor's real element size
                  (rounded to 2 decimals)
        """
        kv_cache = getattr(self.aggregator, 'kv_cache', None)
        if kv_cache is None:
            return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}

        num_cached = sum(
            1
            for key in kv_cache.keys()
            if isinstance(key, str) and key.startswith('k_') and not key.endswith('_special')
        )

        # Use each tensor's actual element size instead of assuming bfloat16;
        # non-tensor entries (e.g. the "_skip_append" flag) are ignored.
        total_bytes = sum(
            value.numel() * value.element_size()
            for value in kv_cache.values()
            if torch.is_tensor(value)
        )
        cache_memory_mb = total_bytes / (1024 * 1024)

        return {
            "num_cached_blocks": num_cached,
            "cache_memory_mb": round(cache_memory_mb, 2),
        }

    @torch.no_grad()
    def inference_streaming(
        self,
        images: torch.Tensor,
        num_scale_frames: Optional[int] = None,
        keyframe_interval: int = 1,
        output_device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Streaming inference: process scale frames first, then frame-by-frame.

        The KV caches are cleared before the first frame. Then:
            1. The first ``num_scale_frames`` frames are processed together as a
               single block (bidirectional attention via scale token).
            2. The remaining frames are processed one at a time with the KV cache
               (causal streaming).

        Keyframe mode (keyframe_interval > 1):
            - Every keyframe_interval-th frame (after the scale frames) is a keyframe.
            - Keyframes: KV is stored in the cache (normal behavior).
            - Non-keyframes: KV is NOT stored in the cache (they attend to cached
              plus their own KV, which is then discarded).
            - All frames produce full predictions regardless of keyframe status.
            - Reduces KV cache memory growth by ~1/keyframe_interval.

        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1].
            num_scale_frames: Number of initial frames for scale estimation. If None,
                uses self.num_frame_for_scale. Clamped to the range [1, S].
            keyframe_interval: Every N-th frame (after scale frames) is a keyframe
                whose KV persists in cache. Values <= 1 make every frame a
                keyframe (the default, same as the original behavior).
            output_device: Device to store output predictions on. If None, outputs
                stay on the device ``forward`` produced them on. Set to
                torch.device('cpu') to offload predictions per frame and avoid GPU
                OOM on long sequences. When set, the KV caches are also cleared
                after the last frame and the CUDA cache is emptied (if CUDA is
                available).

        Returns:
            Dictionary of predictions concatenated over all frames along dim 1.
            Only keys that ``forward`` produced are present:
                - pose_enc: [B, S, 9]
                - depth: [B, S, H, W, 1]
                - depth_conf: [B, S, H, W]
                - world_points: [B, S, H, W, 3]
                - world_points_conf: [B, S, H, W]
                - images: the 5-D input images (moved to output_device if given)
            If ``self.pred_normalization`` is set, the dictionary is passed through
            ``self._normalize_predictions`` before being returned.

        Raises:
            ValueError: If ``images`` is not 4-D or 5-D, or contains no frames.
        """
        # Normalize input shape to [B, S, C, H, W].
        if images.dim() == 4:
            images = images.unsqueeze(0)
        if images.dim() != 5:
            raise ValueError(
                "Expected images of shape [S, 3, H, W] or [B, S, 3, H, W], "
                f"got {tuple(images.shape)}"
            )
        S = images.shape[1]
        if S == 0:
            raise ValueError("inference_streaming requires at least one frame")

        # Number of scale frames. At least one frame must form the first block
        # (0 would run forward() on an empty slice), and no more than S exist.
        scale_frames = self.num_frame_for_scale if num_scale_frames is None else num_scale_frames
        scale_frames = max(1, min(scale_frames, S))

        # Move a tensor to the requested output device (no-op when None).
        def _to_out(tensor: torch.Tensor) -> torch.Tensor:
            return tensor if output_device is None else tensor.to(output_device)

        collected: Dict[str, List[torch.Tensor]] = {key: [] for key in self._STREAM_OUTPUT_KEYS}

        # Append (offloaded) predictions from one forward() output; missing keys are skipped.
        def _collect(output: Dict[str, torch.Tensor]) -> None:
            for key, chunks in collected.items():
                if key in output:
                    chunks.append(_to_out(output[key]))

        # Clean KV caches before starting new sequence
        self.clean_kv_cache()

        # Phase 1: process scale frames together as one block.
        logger.info('Processing %d scale frames...', scale_frames)
        scale_output = self.forward(
            images[:, :scale_frames],
            num_frame_for_scale=scale_frames,
            num_frame_per_block=scale_frames,  # Process all scale frames as one block
            causal_inference=True,
        )
        _collect(scale_output)
        del scale_output

        # Phase 2: process remaining frames one-by-one.
        progress = tqdm(
            range(scale_frames, S),
            desc='Streaming inference',
            initial=scale_frames,
            total=S,
        )
        for i in progress:
            is_keyframe = keyframe_interval <= 1 or (i - scale_frames) % keyframe_interval == 0

            if not is_keyframe:
                self._set_skip_append(True)
            try:
                frame_output = self.forward(
                    images[:, i:i + 1],
                    num_frame_for_scale=scale_frames,  # Keep same for scale token logic
                    num_frame_per_block=1,  # Single frame per block
                    causal_inference=True,
                )
            finally:
                # Always restore append mode, even if forward() raised; otherwise
                # the caches would silently stop storing KV on later calls.
                if not is_keyframe:
                    self._set_skip_append(False)

            _collect(frame_output)
            del frame_output

        # Free accelerator memory before concatenation when offloading.
        if output_device is not None:
            images_out = _to_out(images)
            del images
            # KV cache is no longer needed after inference.
            self.clean_kv_cache()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            images_out = images

        # Concatenate predictions along the sequence dimension, releasing each
        # per-frame list right after its concat to keep peak memory down.
        predictions: Dict[str, torch.Tensor] = {}
        for key in self._STREAM_OUTPUT_KEYS:
            chunks = collected.pop(key)
            if chunks:
                predictions[key] = torch.cat(chunks, dim=1)
            del chunks

        # Store images for visualization
        predictions["images"] = images_out

        # Apply prediction normalization if enabled
        if self.pred_normalization:
            predictions = self._normalize_predictions(predictions)

        return predictions
|