first commit
This commit is contained in:
0
lingbot_map/models/__init__.py
Normal file
0
lingbot_map/models/__init__.py
Normal file
359
lingbot_map/models/gct_base.py
Normal file
359
lingbot_map/models/gct_base.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
GCTBase - Base class for GCT model implementations.
|
||||
|
||||
Provides shared functionality:
|
||||
- Prediction heads (camera, depth, point)
|
||||
- Forward pass structure
|
||||
- Model hub mixin (PyTorchModelHubMixin)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
from lingbot_map.heads.dpt_head import DPTHead
|
||||
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
|
||||
from lingbot_map.utils.geometry import closed_form_inverse_se3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GCTBase(nn.Module, PyTorchModelHubMixin, ABC):
|
||||
"""
|
||||
Base class for GCT model implementations.
|
||||
|
||||
Handles shared components:
|
||||
- Prediction heads (camera, depth, point)
|
||||
- Forward pass structure
|
||||
- Input normalization
|
||||
|
||||
Subclasses must implement:
|
||||
- _build_aggregator(): Create mode-specific aggregator
|
||||
- _build_camera_head(): Create mode-specific camera head
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    # Architecture parameters
    img_size: int = 518,
    patch_size: int = 14,
    embed_dim: int = 1024,
    patch_embed: str = 'dinov2_vitl14_reg',
    disable_global_rope: bool = False,
    # Head configuration
    enable_camera: bool = True,
    enable_point: bool = True,
    enable_local_point: bool = False,
    enable_depth: bool = True,
    enable_track: bool = False,
    # Camera head sliding window
    enable_camera_sliding_window: bool = False,
    # 3D RoPE
    enable_3d_rope: bool = False,
    # Context Parallelism (kept for checkpoint compatibility but not used)
    enable_ulysses_cp: bool = False,
    # Normalization
    enable_normalize: bool = False,
    # Prediction normalization
    pred_normalization: bool = False,
    pred_normalization_detach_scale: bool = False,
    # Gradient checkpointing
    use_gradient_checkpoint: bool = True,
):
    """Record the configuration, then build the aggregator and enabled heads.

    Every argument except ``enable_ulysses_cp`` is stored on the instance
    under its own name. ``enable_ulysses_cp`` is accepted only so that
    existing configs keep loading; the attribute is always set to ``False``.

    The configuration is stored *before* any ``_build_*`` hook runs because
    the hooks read it back from ``self``. Modules are created in a fixed
    order: aggregator, camera head, point head, local point head, depth
    head. A head whose ``enable_*`` flag is false is stored as ``None``.
    """
    super().__init__()

    # Plain configuration values, mirrored one-to-one onto the instance.
    settings = {
        "img_size": img_size,
        "patch_size": patch_size,
        "embed_dim": embed_dim,
        "patch_embed": patch_embed,
        "disable_global_rope": disable_global_rope,
        # CP disabled in standalone package
        "enable_ulysses_cp": False,
        "enable_normalize": enable_normalize,
        "pred_normalization": pred_normalization,
        "pred_normalization_detach_scale": pred_normalization_detach_scale,
        "use_gradient_checkpoint": use_gradient_checkpoint,
        # Head flags
        "enable_camera": enable_camera,
        "enable_point": enable_point,
        "enable_local_point": enable_local_point,
        "enable_depth": enable_depth,
        "enable_track": enable_track,
        "enable_camera_sliding_window": enable_camera_sliding_window,
        "enable_3d_rope": enable_3d_rope,
    }
    for attr_name, attr_value in settings.items():
        setattr(self, attr_name, attr_value)

    # Subclass-specific feature aggregator.
    self.aggregator = self._build_aggregator()

    # Prediction heads, built in this order; disabled heads stay None.
    head_specs = (
        ("camera_head", enable_camera, self._build_camera_head),
        ("point_head", enable_point, self._build_point_head),
        ("local_point_head", enable_local_point, self._build_local_point_head),
        ("depth_head", enable_depth, self._build_depth_head),
    )
    for head_attr, enabled, build_head in head_specs:
        setattr(self, head_attr, build_head() if enabled else None)
|
||||
|
||||
@abstractmethod
|
||||
def _build_aggregator(self) -> nn.Module:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _build_camera_head(self) -> nn.Module:
|
||||
pass
|
||||
|
||||
def _build_depth_head(self) -> nn.Module:
    """Create the ``DPTHead`` stored as ``self.depth_head``.

    The head takes ``2 * embed_dim`` input channels and emits 2 output
    channels, using the ``"exp"`` activation for the prediction and
    ``"expp1"`` for the confidence.
    """
    head_config = {
        "dim_in": 2 * self.embed_dim,
        "patch_size": self.patch_size,
        "output_dim": 2,
        "activation": "exp",
        "conf_activation": "expp1",
    }
    return DPTHead(**head_config)
|
||||
|
||||
def _build_point_head(self) -> nn.Module:
    """Create the ``DPTHead`` stored as ``self.point_head``.

    The head takes ``2 * embed_dim`` input channels and emits 4 output
    channels, using the ``"inv_log"`` activation for the prediction and
    ``"expp1"`` for the confidence.
    """
    head_config = {
        "dim_in": 2 * self.embed_dim,
        "patch_size": self.patch_size,
        "output_dim": 4,
        "activation": "inv_log",
        "conf_activation": "expp1",
    }
    return DPTHead(**head_config)
|
||||
|
||||
def _build_local_point_head(self) -> nn.Module:
    """Create the ``DPTHead`` stored as ``self.local_point_head``.

    Configured exactly like the head from ``_build_point_head``:
    ``2 * embed_dim`` input channels, 4 output channels, ``"inv_log"``
    prediction activation and ``"expp1"`` confidence activation.
    """
    head_config = {
        "dim_in": 2 * self.embed_dim,
        "patch_size": self.patch_size,
        "output_dim": 4,
        "activation": "inv_log",
        "conf_activation": "expp1",
    }
    return DPTHead(**head_config)
|
||||
|
||||
def _normalize_input(self, images: torch.Tensor, query_points=None):
|
||||
if len(images.shape) == 4:
|
||||
images = images.unsqueeze(0)
|
||||
if query_points is not None and len(query_points.shape) == 2:
|
||||
query_points = query_points.unsqueeze(0)
|
||||
return images, query_points
|
||||
|
||||
@abstractmethod
|
||||
def _aggregate_features(
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
num_frame_for_scale: Optional[int] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
num_frame_per_block: int = 1,
|
||||
view_graphs: Optional[torch.Tensor] = None,
|
||||
causal_graphs: Optional[Union[torch.Tensor, List[np.ndarray]]] = None,
|
||||
ordered_video: Optional[torch.Tensor] = None,
|
||||
is_cp_sliced: bool = False,
|
||||
) -> tuple:
|
||||
pass
|
||||
|
||||
def _predict_camera(
|
||||
self,
|
||||
aggregated_tokens_list: list,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
causal_inference: bool = False,
|
||||
num_frame_for_scale: Optional[int] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
num_frame_per_block: int = 1,
|
||||
gather_outputs: bool = True,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if self.camera_head is None:
|
||||
return {}
|
||||
|
||||
aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
|
||||
|
||||
camera_sliding_window = sliding_window_size if self.enable_camera_sliding_window else -1
|
||||
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
pose_enc_list = self.camera_head(
|
||||
aggregated_tokens_list_fp32,
|
||||
mask=mask,
|
||||
causal_inference=causal_inference,
|
||||
num_frame_for_scale=num_frame_for_scale if num_frame_for_scale is not None else -1,
|
||||
sliding_window_size=camera_sliding_window,
|
||||
num_frame_per_block=num_frame_per_block,
|
||||
)
|
||||
|
||||
return {
|
||||
"pose_enc": pose_enc_list[-1],
|
||||
"pose_enc_list": pose_enc_list,
|
||||
}
|
||||
|
||||
def _predict_depth(
|
||||
self,
|
||||
aggregated_tokens_list: list,
|
||||
images: torch.Tensor,
|
||||
patch_start_idx: int,
|
||||
gather_outputs: bool = True,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if self.depth_head is None:
|
||||
return {}
|
||||
|
||||
aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
|
||||
images_fp32 = images.float()
|
||||
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
depth, depth_conf = self.depth_head(
|
||||
aggregated_tokens_list_fp32,
|
||||
images=images_fp32,
|
||||
patch_start_idx=patch_start_idx
|
||||
)
|
||||
|
||||
return {"depth": depth, "depth_conf": depth_conf}
|
||||
|
||||
def _predict_points(
|
||||
self,
|
||||
aggregated_tokens_list: list,
|
||||
images: torch.Tensor,
|
||||
patch_start_idx: int,
|
||||
gather_outputs: bool = True,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if self.point_head is None:
|
||||
return {}
|
||||
|
||||
aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
|
||||
images_fp32 = images.float()
|
||||
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
pts3d, pts3d_conf = self.point_head(
|
||||
aggregated_tokens_list_fp32,
|
||||
images=images_fp32,
|
||||
patch_start_idx=patch_start_idx
|
||||
)
|
||||
|
||||
return {"world_points": pts3d, "world_points_conf": pts3d_conf}
|
||||
|
||||
def _predict_local_points(
|
||||
self,
|
||||
aggregated_tokens_list: list,
|
||||
images: torch.Tensor,
|
||||
patch_start_idx: int,
|
||||
gather_outputs: bool = True,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if self.local_point_head is None:
|
||||
return {}
|
||||
|
||||
aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
|
||||
images_fp32 = images.float()
|
||||
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
pts3d, pts3d_conf = self.local_point_head(
|
||||
aggregated_tokens_list_fp32,
|
||||
images=images_fp32,
|
||||
patch_start_idx=patch_start_idx
|
||||
)
|
||||
|
||||
return {"cam_points": pts3d, "cam_points_conf": pts3d_conf}
|
||||
|
||||
def _unproject_depth_to_world(
    self,
    depth: torch.Tensor,
    pose_enc: torch.Tensor,
) -> torch.Tensor:
    """Lift per-pixel depth to 3D points using the predicted cameras.

    Steps:
      1. Decode ``pose_enc`` into extrinsics and intrinsics for an image of
         size ``(H, W)``.
      2. Pad the 3x4 extrinsics to 4x4 and invert them.
      3. Back-project the pixel grid ``(x, y, 1)`` through the inverse
         intrinsics and scale by depth.
      4. Apply the inverted extrinsics to the homogeneous points.

    Args:
        depth: Tensor unpacked as ``(B, S, H, W, C)``; its device and dtype
            are used for every intermediate tensor.
        pose_enc: Pose encoding accepted by ``pose_encoding_to_extri_intri``.

    Returns:
        Tensor of shape ``(B, S, H, W, 3)``.

    NOTE(review): treats the decoded extrinsics as world-to-camera
    (the inverse is named ``c2w``) and pixel coordinates as integer
    indices without a half-pixel offset — confirm against the helpers.
    """
    B, S, H, W, _ = depth.shape
    device, dtype = depth.device, depth.dtype

    extrinsics, intrinsics = pose_encoding_to_extri_intri(
        pose_enc, image_size_hw=(H, W), build_intrinsics=True
    )

    # reshape (not view): the decoded extrinsics are not guaranteed to be
    # contiguous, and view would raise on a non-contiguous tensor.
    extrinsics_4x4 = torch.zeros(B * S, 4, 4, device=device, dtype=dtype)
    extrinsics_4x4[:, :3, :] = extrinsics.reshape(B * S, 3, 4)
    extrinsics_4x4[:, 3, 3] = 1.0
    c2w = closed_form_inverse_se3(extrinsics_4x4).reshape(B, S, 4, 4)

    y_grid, x_grid = torch.meshgrid(
        torch.arange(H, device=device, dtype=dtype),
        torch.arange(W, device=device, dtype=dtype),
        indexing='ij',
    )
    pixel_coords = torch.stack([x_grid, y_grid, torch.ones_like(x_grid)], dim=-1)

    # Invert in the intrinsics' own dtype, then match the depth dtype so the
    # einsum below never mixes dtypes (a no-op when they already agree).
    intrinsics_inv = torch.inverse(intrinsics).to(dtype)
    camera_coords = torch.einsum('bsij,hwj->bshwi', intrinsics_inv, pixel_coords)
    camera_points = camera_coords * depth

    ones = torch.ones_like(camera_points[..., :1])
    camera_points_h = torch.cat([camera_points, ones], dim=-1)
    world_points_h = torch.einsum('bsij,bshwj->bshwi', c2w, camera_points_h)

    return world_points_h[..., :3]
|
||||
|
||||
def forward(
    self,
    images: torch.Tensor,
    query_points: Optional[torch.Tensor] = None,
    num_frame_for_scale: Optional[int] = None,
    sliding_window_size: Optional[int] = None,
    num_frame_per_block: int = 1,
    mask: Optional[torch.Tensor] = None,
    causal_inference: bool = False,
    ordered_video: Optional[torch.Tensor] = None,
    gather_outputs: bool = True,
    point_masks: Optional[torch.Tensor] = None,
    **kwargs,
) -> Dict[str, torch.Tensor]:
    """
    Forward pass of the GCT model.

    The input is batched if necessary, the aggregator is run once, and each
    enabled head (camera, depth, point, local point — in that order) adds
    its outputs to a single result dictionary.

    Args:
        images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
        query_points: Optional query points [N, 2] or [B, N, 2]
        num_frame_for_scale: Passed to the aggregator and the camera head.
        sliding_window_size: Passed to the aggregator and the camera head.
        num_frame_per_block: Passed to the aggregator and the camera head.
        mask: Accepted but not read by this method.
        causal_inference: Passed to the camera head.
        ordered_video: Passed to the camera head as its ``mask`` argument.
        gather_outputs: Passed to every ``_predict_*`` helper.
        point_masks: Accepted but not read by this method.
        **kwargs: Accepted and ignored.

    Returns:
        Dictionary containing predictions (keys of disabled heads are absent):
            - pose_enc: Camera pose encoding [B, S, 9]
            - pose_enc_list: Full list of pose encodings from the camera head
            - depth: Depth maps [B, S, H, W, 1]
            - depth_conf: Depth confidence [B, S, H, W]
            - world_points: 3D world coordinates [B, S, H, W, 3]
            - world_points_conf: Point confidence [B, S, H, W]
            - cam_points / cam_points_conf: Local point head outputs
            - images: The batched input, only when not in training mode
    """
    images, query_points = self._normalize_input(images, query_points)

    aggregated_tokens_list, patch_start_idx = self._aggregate_features(
        images,
        num_frame_for_scale=num_frame_for_scale,
        sliding_window_size=sliding_window_size,
        num_frame_per_block=num_frame_per_block,
    )

    predictions: Dict[str, torch.Tensor] = {}

    # NOTE(review): the camera head's mask is `ordered_video`, while the
    # `mask` argument of this method is unused — confirm this is intended.
    camera_out = self._predict_camera(
        aggregated_tokens_list,
        mask=ordered_video,
        causal_inference=causal_inference,
        num_frame_for_scale=num_frame_for_scale,
        sliding_window_size=sliding_window_size,
        num_frame_per_block=num_frame_per_block,
        gather_outputs=gather_outputs,
    )
    predictions.update(camera_out)

    # The dense heads share one calling convention; run them in fixed order.
    dense_predictors = (
        self._predict_depth,
        self._predict_points,
        self._predict_local_points,
    )
    for predict_dense in dense_predictors:
        dense_out = predict_dense(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        )
        predictions.update(dense_out)

    if not self.training:
        predictions["images"] = images

    return predictions
|
||||
444
lingbot_map/models/gct_stream.py
Normal file
444
lingbot_map/models/gct_stream.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
GCTStream - Streaming GCT with KV cache for online inference.
|
||||
|
||||
Provides streaming inference functionality:
|
||||
- Temporal causal attention with KV cache
|
||||
- Sliding window support
|
||||
- Efficient frame-by-frame processing
|
||||
- 3D RoPE support for temporal consistency
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Dict, Any, List
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from lingbot_map.heads.camera_head import CameraCausalHead
|
||||
from lingbot_map.models.gct_base import GCTBase
|
||||
from lingbot_map.aggregator.stream import AggregatorStream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GCTStream(GCTBase):
|
||||
"""
|
||||
Streaming GCT model with KV cache for efficient online inference.
|
||||
|
||||
Features:
|
||||
- AggregatorStream with KV cache support (FlashInfer backend)
|
||||
- CameraCausalHead for pose refinement
|
||||
- Sliding window attention for memory efficiency
|
||||
- Frame-by-frame streaming inference
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    # Architecture parameters
    img_size: int = 518,
    patch_size: int = 14,
    embed_dim: int = 1024,
    patch_embed: str = 'dinov2_vitl14_reg',
    pretrained_path: str = '',
    disable_global_rope: bool = False,
    # Head configuration
    enable_camera: bool = True,
    enable_point: bool = True,
    enable_local_point: bool = False,
    enable_depth: bool = True,
    enable_track: bool = False,
    # Normalization
    enable_normalize: bool = False,
    # Prediction normalization
    pred_normalization: bool = False,
    # Stream-specific parameters
    sliding_window_size: int = -1,
    num_frame_for_scale: int = 1,
    num_random_frames: int = 0,
    attend_to_special_tokens: bool = False,
    attend_to_scale_frames: bool = False,
    enable_stream_inference: bool = True,  # Default to True for streaming
    enable_3d_rope: bool = False,
    max_frame_num: int = 1024,
    # Camera head 3D RoPE (separate from aggregator 3D RoPE)
    enable_camera_3d_rope: bool = False,
    camera_rope_theta: float = 10000.0,
    # Scale token configuration (kept for checkpoint compat, ignored)
    use_scale_token: bool = True,
    # KV cache parameters
    kv_cache_sliding_window: int = 64,
    kv_cache_scale_frames: int = 8,
    kv_cache_cross_frame_special: bool = True,
    kv_cache_include_scale_frames: bool = True,
    kv_cache_camera_only: bool = False,
    # Backend selection
    use_sdpa: bool = False,  # If True, use SDPA (no flashinfer needed); default: FlashInfer
    # Gradient checkpointing
    use_gradient_checkpoint: bool = True,
):
    """
    Initialize GCTStream.

    Stream-specific settings are stored on the instance first, because the
    base-class constructor immediately calls ``_build_aggregator`` and
    ``_build_camera_head``, which read them back from ``self``.

    Args:
        img_size: Input image size
        patch_size: Patch size for embedding
        embed_dim: Embedding dimension
        patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.)
        pretrained_path: Path to pretrained DINOv2 weights
        disable_global_rope: Disable RoPE in global attention
        enable_camera/point/local_point/depth/track: Enable prediction heads
        enable_normalize: Enable normalization
        pred_normalization: Apply prediction normalization at the end of
            ``inference_streaming``
        sliding_window_size: Sliding window size in blocks (-1 for full causal)
        num_frame_for_scale: Number of scale estimation frames
        num_random_frames: Number of random frames for long-range dependencies
        attend_to_special_tokens: Enable cross-frame special token attention
        attend_to_scale_frames: Whether to attend to scale frames
        enable_stream_inference: Enable streaming inference with KV cache
        enable_3d_rope: Enable 3D RoPE for temporal consistency
        max_frame_num: Maximum number of frames for 3D RoPE
        enable_camera_3d_rope: Passed to the camera head as ``enable_3d_rope``
        camera_rope_theta: Passed to the camera head as ``rope_theta``
        use_scale_token: Kept for checkpoint compatibility, ignored
        kv_cache_sliding_window: Sliding window size for KV cache eviction
        kv_cache_scale_frames: Number of scale frames to keep in KV cache
        kv_cache_cross_frame_special: Keep special tokens from evicted frames
        kv_cache_include_scale_frames: Include scale frames in KV cache
        kv_cache_camera_only: Only keep camera tokens from evicted frames
        use_sdpa: Use the SDPA backend; when False the aggregator is built
            with ``use_flashinfer=True``
        use_gradient_checkpoint: Passed to the base class and the aggregator
    """
    # Plain (non-module) values only: they are assigned before the base
    # constructor runs so the build hooks can see them.
    stream_settings = {
        "pretrained_path": pretrained_path,
        "sliding_window_size": sliding_window_size,
        "num_frame_for_scale": num_frame_for_scale,
        "num_random_frames": num_random_frames,
        "attend_to_special_tokens": attend_to_special_tokens,
        "attend_to_scale_frames": attend_to_scale_frames,
        "enable_stream_inference": enable_stream_inference,
        "enable_3d_rope": enable_3d_rope,
        "max_frame_num": max_frame_num,
        # Camera head 3D RoPE settings
        "enable_camera_3d_rope": enable_camera_3d_rope,
        "camera_rope_theta": camera_rope_theta,
        # KV cache parameters
        "kv_cache_sliding_window": kv_cache_sliding_window,
        "kv_cache_scale_frames": kv_cache_scale_frames,
        "kv_cache_cross_frame_special": kv_cache_cross_frame_special,
        "kv_cache_include_scale_frames": kv_cache_include_scale_frames,
        "kv_cache_camera_only": kv_cache_camera_only,
        "use_sdpa": use_sdpa,
    }
    for attr_name, attr_value in stream_settings.items():
        setattr(self, attr_name, attr_value)

    # Base constructor stores the shared config and builds all modules.
    base_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        patch_embed=patch_embed,
        disable_global_rope=disable_global_rope,
        enable_camera=enable_camera,
        enable_point=enable_point,
        enable_local_point=enable_local_point,
        enable_depth=enable_depth,
        enable_track=enable_track,
        enable_normalize=enable_normalize,
        pred_normalization=pred_normalization,
        enable_3d_rope=enable_3d_rope,
        use_gradient_checkpoint=use_gradient_checkpoint,
    )
    super().__init__(**base_kwargs)
|
||||
|
||||
def _build_aggregator(self) -> nn.Module:
    """
    Build streaming aggregator with KV cache support.

    Most arguments are copied verbatim from the attribute of the same name
    on ``self``. The only derived value is ``use_flashinfer``, which is the
    negation of ``self.use_sdpa``.

    Returns:
        AggregatorStream module
    """
    # Attributes forwarded under their own name.
    passthrough_names = (
        "img_size",
        "patch_size",
        "embed_dim",
        "patch_embed",
        "pretrained_path",
        "disable_global_rope",
        "sliding_window_size",
        "num_frame_for_scale",
        "num_random_frames",
        "attend_to_special_tokens",
        "attend_to_scale_frames",
        "enable_stream_inference",
        "enable_3d_rope",
        "max_frame_num",
        "use_sdpa",
        "kv_cache_sliding_window",
        "kv_cache_scale_frames",
        "kv_cache_cross_frame_special",
        "kv_cache_include_scale_frames",
        "kv_cache_camera_only",
        "use_gradient_checkpoint",
    )
    aggregator_kwargs = {name: getattr(self, name) for name in passthrough_names}

    # Backend: FlashInfer (default) or SDPA (fallback)
    aggregator_kwargs["use_flashinfer"] = not self.use_sdpa

    return AggregatorStream(**aggregator_kwargs)
|
||||
|
||||
def _build_camera_head(self) -> nn.Module:
    """
    Build causal camera head for streaming inference.

    The head uses ``2 * embed_dim`` input channels, shares the sliding
    window and KV cache settings with the aggregator, and has its own 3D
    RoPE switch (``enable_camera_3d_rope``) and theta (``camera_rope_theta``).

    Returns:
        CameraCausalHead module
    """
    # KV cache settings are forwarded under their own name.
    kv_cache_names = (
        "kv_cache_sliding_window",
        "kv_cache_scale_frames",
        "kv_cache_cross_frame_special",
        "kv_cache_include_scale_frames",
        "kv_cache_camera_only",
    )
    head_kwargs = {name: getattr(self, name) for name in kv_cache_names}

    head_kwargs.update(
        dim_in=2 * self.embed_dim,
        sliding_window_size=self.sliding_window_size,
        attend_to_scale_frames=self.attend_to_scale_frames,
        # Camera head 3D RoPE parameters
        enable_3d_rope=self.enable_camera_3d_rope,
        max_frame_num=self.max_frame_num,
        rope_theta=self.camera_rope_theta,
    )
    return CameraCausalHead(**head_kwargs)
|
||||
|
||||
def _aggregate_features(
|
||||
self,
|
||||
images: torch.Tensor,
|
||||
num_frame_for_scale: Optional[int] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
num_frame_per_block: int = 1,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""
|
||||
Run aggregator to get multi-scale features.
|
||||
|
||||
Args:
|
||||
images: Input images [B, S, 3, H, W]
|
||||
num_frame_for_scale: Number of frames for scale estimation
|
||||
sliding_window_size: Override sliding window size
|
||||
num_frame_per_block: Number of frames per block
|
||||
|
||||
Returns:
|
||||
(aggregated_tokens_list, patch_start_idx)
|
||||
"""
|
||||
aggregated_tokens_list, patch_start_idx = self.aggregator(
|
||||
images,
|
||||
selected_idx=[4, 11, 17, 23],
|
||||
num_frame_for_scale=num_frame_for_scale,
|
||||
sliding_window_size=sliding_window_size,
|
||||
num_frame_per_block=num_frame_per_block,
|
||||
)
|
||||
return aggregated_tokens_list, patch_start_idx
|
||||
|
||||
def clean_kv_cache(self):
    """
    Clean the KV caches of the aggregator and the camera head.

    Call this method when starting a new video sequence to clear
    cached key-value pairs from previous sequences.

    A component is cleaned when it exposes a ``clean_kv_cache`` method;
    otherwise a warning is logged. A disabled camera head (``None``) is
    skipped silently, since there is nothing to clean.
    """
    if hasattr(self.aggregator, 'clean_kv_cache'):
        self.aggregator.clean_kv_cache()
    else:
        logger.warning("Aggregator does not support KV cache cleaning")

    camera_head = self.camera_head
    if camera_head is None:
        # enable_camera=False: no head was built, so no cache exists.
        return

    # Probe for the method that is actually called (the aggregator branch
    # does the same), rather than for a `kv_cache` attribute.
    if hasattr(camera_head, 'clean_kv_cache'):
        camera_head.clean_kv_cache()
    else:
        logger.warning("Camera head does not support KV cache cleaning")
|
||||
|
||||
def _set_skip_append(self, skip: bool):
|
||||
"""Set _skip_append flag on all KV caches (aggregator + camera head).
|
||||
|
||||
When skip=True, attention layers will attend to [cached_kv + current_kv]
|
||||
but will NOT store the current frame's KV in cache. This is used for
|
||||
non-keyframe processing in keyframe-based streaming inference.
|
||||
|
||||
Args:
|
||||
skip: If True, subsequent forward passes will not append KV to cache.
|
||||
"""
|
||||
if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None:
|
||||
self.aggregator.kv_cache["_skip_append"] = skip
|
||||
if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
|
||||
for cache_dict in self.camera_head.kv_cache:
|
||||
cache_dict["_skip_append"] = skip
|
||||
|
||||
def get_kv_cache_info(self) -> Dict[str, Any]:
    """
    Get information about current KV cache state.

    Only the aggregator's cache is inspected.

    Returns:
        Dictionary with cache statistics:
            - num_cached_blocks: Number of keys that start with ``k_`` and
              do not end with ``_special``
            - cache_memory_mb: Memory held by all tensor values in MiB,
              rounded to two decimals. Computed from each tensor's real
              element size, so it is correct for any cache dtype.
    """
    kv_cache = getattr(self.aggregator, 'kv_cache', None)
    if kv_cache is None:
        return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}

    # The cache also holds non-block entries (e.g. the "_skip_append" flag),
    # so keys are filtered by name and values by type.
    num_cached = sum(
        1
        for key in kv_cache
        if isinstance(key, str) and key.startswith('k_') and not key.endswith('_special')
    )

    total_bytes = sum(
        value.numel() * value.element_size()
        for value in kv_cache.values()
        if torch.is_tensor(value)
    )
    cache_memory_mb = total_bytes / (1024 * 1024)

    return {
        "num_cached_blocks": num_cached,
        "cache_memory_mb": round(cache_memory_mb, 2),
    }
|
||||
|
||||
@torch.no_grad()
def inference_streaming(
    self,
    images: torch.Tensor,
    num_scale_frames: Optional[int] = None,
    keyframe_interval: int = 1,
    output_device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
    """
    Streaming inference: process scale frames first, then frame-by-frame.

    This method enables efficient online inference by:
    1. Processing initial scale frames together (bidirectional attention via scale token)
    2. Processing remaining frames one-by-one with KV cache (causal streaming)

    Keyframe mode (keyframe_interval > 1):
    - Every keyframe_interval-th frame (after scale frames) is a keyframe
    - Keyframes: KV is stored in cache (normal behavior)
    - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard)
    - All frames produce full predictions regardless of keyframe status
    - Reduces KV cache memory growth by ~1/keyframe_interval

    Args:
        images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
        num_scale_frames: Number of initial frames for scale estimation.
                          If None, uses self.num_frame_for_scale.
        keyframe_interval: Every N-th frame (after scale frames) is a keyframe
                           whose KV persists in cache. 1 = every frame is a
                           keyframe (default, same as original behavior).
        output_device: Device to store output predictions on. If None, keeps on
                       the same device as the model. Set to torch.device('cpu')
                       to offload predictions per-frame and avoid GPU OOM on
                       long sequences.

    Returns:
        Dictionary containing predictions for all frames. A key is present
        only when the corresponding head is enabled:
            - pose_enc: [B, S, 9]
            - depth: [B, S, H, W, 1]
            - depth_conf: [B, S, H, W]
            - world_points: [B, S, H, W, 3]
            - world_points_conf: [B, S, H, W]
            - cam_points / cam_points_conf: local point head outputs
            - images: the (batched) input images
    """
    # Per-frame outputs that are accumulated along the sequence axis (dim 1).
    # Every key is optional so that a disabled head (including the camera
    # head) is skipped instead of raising KeyError.
    stream_keys = (
        "pose_enc",
        "depth",
        "depth_conf",
        "world_points",
        "world_points_conf",
        "cam_points",
        "cam_points_conf",
    )

    # Normalize input shape; the 5-way unpack also rejects other ranks.
    if len(images.shape) == 4:
        images = images.unsqueeze(0)
    _, S, _, _, _ = images.shape

    # Determine number of scale frames, capped to the available frames.
    scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale
    scale_frames = min(scale_frames, S)

    def _to_out(t: torch.Tensor) -> torch.Tensor:
        """Move a tensor to ``output_device`` when one was requested."""
        if output_device is not None:
            return t.to(output_device)
        return t

    collected = {key: [] for key in stream_keys}

    def _collect(output: Dict[str, torch.Tensor]) -> None:
        """Store (and offload) the per-frame tensors present in ``output``."""
        for key in stream_keys:
            if key in output:
                collected[key].append(_to_out(output[key]))

    # Clean KV caches before starting new sequence
    self.clean_kv_cache()

    # Phase 1: Process scale frames together
    # These frames get bidirectional attention among themselves via scale token
    logger.info(f'Processing {scale_frames} scale frames...')
    scale_output = self.forward(
        images[:, :scale_frames],
        num_frame_for_scale=scale_frames,
        num_frame_per_block=scale_frames,  # Process all scale frames as one block
        causal_inference=True,
    )
    _collect(scale_output)
    del scale_output

    # Phase 2: Process remaining frames one-by-one
    pbar = tqdm(
        range(scale_frames, S),
        desc='Streaming inference',
        initial=scale_frames,
        total=S,
    )
    for i in pbar:
        frame_image = images[:, i:i+1]

        # Determine if this frame is a keyframe
        is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0)

        if not is_keyframe:
            self._set_skip_append(True)
        try:
            frame_output = self.forward(
                frame_image,
                num_frame_for_scale=scale_frames,  # Keep same for scale token logic
                num_frame_per_block=1,  # Single frame per block
                causal_inference=True,
            )
        finally:
            # Always restore the flag so a failed frame cannot leave the
            # caches permanently in skip-append mode.
            if not is_keyframe:
                self._set_skip_append(False)

        _collect(frame_output)
        del frame_output

    # Free GPU memory before concatenation
    if output_device is not None:
        # Move images to output device, then free GPU copy
        images_out = _to_out(images)
        del images
        # Clean KV cache (no longer needed after inference)
        self.clean_kv_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    else:
        images_out = images

    # Concatenate all predictions along sequence dimension, releasing each
    # per-frame list as soon as it has been merged.
    predictions = {}
    for key in stream_keys:
        chunks = collected.pop(key)
        if chunks:
            predictions[key] = torch.cat(chunks, dim=1)
        del chunks

    # Store images for visualization
    predictions["images"] = images_out

    # Apply prediction normalization if enabled. The hook is looked up
    # defensively: when it is not provided, the results of the whole run
    # are returned un-normalized instead of being lost to an AttributeError.
    if self.pred_normalization:
        normalize = getattr(self, "_normalize_predictions", None)
        if callable(normalize):
            predictions = normalize(predictions)
        else:
            logger.warning(
                "pred_normalization is enabled but _normalize_predictions is "
                "not implemented; returning un-normalized predictions"
            )

    return predictions
|
||||
Reference in New Issue
Block a user