update
This commit is contained in:
@@ -75,6 +75,8 @@ class GCTStream(GCTBase):
|
||||
use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer
|
||||
# Gradient checkpointing
|
||||
use_gradient_checkpoint: bool = True,
|
||||
# Camera head iterative refinement (lower = faster inference; default 4)
|
||||
camera_num_iterations: int = 4,
|
||||
):
|
||||
"""
|
||||
Initialize GCTStream.
|
||||
@@ -123,6 +125,7 @@ class GCTStream(GCTBase):
|
||||
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
||||
self.kv_cache_camera_only = kv_cache_camera_only
|
||||
self.use_sdpa = use_sdpa
|
||||
self.camera_num_iterations = camera_num_iterations
|
||||
|
||||
# Call base class __init__ (will call _build_aggregator)
|
||||
super().__init__(
|
||||
@@ -186,6 +189,7 @@ class GCTStream(GCTBase):
|
||||
dim_in=2 * self.embed_dim,
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
attend_to_scale_frames=self.attend_to_scale_frames,
|
||||
num_iterations=self.camera_num_iterations,
|
||||
# KV cache parameters
|
||||
kv_cache_sliding_window=self.kv_cache_sliding_window,
|
||||
kv_cache_scale_frames=self.kv_cache_scale_frames,
|
||||
|
||||
@@ -166,6 +166,8 @@ class GCTStream(GCTBase):
|
||||
use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer
|
||||
# Gradient checkpointing
|
||||
use_gradient_checkpoint: bool = True,
|
||||
# Camera head iterative refinement (lower = faster inference; default 4)
|
||||
camera_num_iterations: int = 4,
|
||||
):
|
||||
"""
|
||||
Initialize GCTStream.
|
||||
@@ -214,6 +216,7 @@ class GCTStream(GCTBase):
|
||||
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
||||
self.kv_cache_camera_only = kv_cache_camera_only
|
||||
self.use_sdpa = use_sdpa
|
||||
self.camera_num_iterations = camera_num_iterations
|
||||
|
||||
# Call base class __init__ (will call _build_aggregator)
|
||||
super().__init__(
|
||||
@@ -277,6 +280,7 @@ class GCTStream(GCTBase):
|
||||
dim_in=2 * self.embed_dim,
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
attend_to_scale_frames=self.attend_to_scale_frames,
|
||||
num_iterations=self.camera_num_iterations,
|
||||
# KV cache parameters
|
||||
kv_cache_sliding_window=self.kv_cache_sliding_window,
|
||||
kv_cache_scale_frames=self.kv_cache_scale_frames,
|
||||
|
||||
Reference in New Issue
Block a user