This commit is contained in:
LinZhuoChen
2026-04-18 02:26:23 +08:00
parent 3d1fe1d8e2
commit 01c99afc41
329 changed files with 34 additions and 7 deletions

View File

@@ -166,6 +166,8 @@ class GCTStream(GCTBase):
use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer
# Gradient checkpointing
use_gradient_checkpoint: bool = True,
# Camera head iterative refinement (lower = faster inference; default 4)
camera_num_iterations: int = 4,
):
"""
Initialize GCTStream.
@@ -214,6 +216,7 @@ class GCTStream(GCTBase):
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
self.kv_cache_camera_only = kv_cache_camera_only
self.use_sdpa = use_sdpa
self.camera_num_iterations = camera_num_iterations
# Call base class __init__ (will call _build_aggregator)
super().__init__(
@@ -277,6 +280,7 @@ class GCTStream(GCTBase):
dim_in=2 * self.embed_dim,
sliding_window_size=self.sliding_window_size,
attend_to_scale_frames=self.attend_to_scale_frames,
num_iterations=self.camera_num_iterations,
# KV cache parameters
kv_cache_sliding_window=self.kv_cache_sliding_window,
kv_cache_scale_frames=self.kv_cache_scale_frames,