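Make camera head refinement iterations configurable: add camera_num_iterations to GCTStream and let CameraCausalHead.forward fall back to self.num_iterations when num_iterations is None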
This commit is contained in:
@@ -272,23 +272,27 @@ class CameraCausalHead(nn.Module):
|
||||
self.kv_cache = None
|
||||
self.frame_idx = 0
|
||||
|
||||
def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
|
||||
def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = None, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
|
||||
"""
|
||||
Forward pass to predict camera parameters.
|
||||
|
||||
Args:
|
||||
aggregated_tokens_list (list): List of token tensors from the network;
|
||||
the last tensor is used for prediction.
|
||||
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
||||
num_iterations (int, optional): Number of iterative refinement steps.
|
||||
If None, falls back to self.num_iterations (set at construction).
|
||||
sliding_window_size (int, optional): Override the sliding window size for this forward pass.
|
||||
If None, use the default self.sliding_window_size.
|
||||
|
||||
Returns:
|
||||
list: A list of predicted camera encodings (post-activation) from each iteration.
|
||||
"""
|
||||
if num_iterations is None:
|
||||
num_iterations = self.num_iterations
|
||||
|
||||
# Use passed sliding_window_size if provided, otherwise use default
|
||||
effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
|
||||
|
||||
|
||||
# Use tokens from the last block for camera prediction.
|
||||
tokens = aggregated_tokens_list[-1]
|
||||
|
||||
@@ -299,12 +303,12 @@ class CameraCausalHead(nn.Module):
|
||||
if causal_inference:
|
||||
if self.kv_cache is None:
|
||||
self.kv_cache = []
|
||||
for i in range(self.num_iterations):
|
||||
for i in range(num_iterations):
|
||||
self.kv_cache.append({"_skip_append": False})
|
||||
for j in range(self.trunk_depth):
|
||||
self.kv_cache[i][f"k_{j}"] = None
|
||||
self.kv_cache[i][f"v_{j}"] = None
|
||||
|
||||
|
||||
pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size)
|
||||
return pred_pose_enc_list
|
||||
|
||||
|
||||
@@ -75,6 +75,8 @@ class GCTStream(GCTBase):
|
||||
use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer
|
||||
# Gradient checkpointing
|
||||
use_gradient_checkpoint: bool = True,
|
||||
# Camera head iterative refinement (lower = faster inference; default 4)
|
||||
camera_num_iterations: int = 4,
|
||||
):
|
||||
"""
|
||||
Initialize GCTStream.
|
||||
@@ -123,6 +125,7 @@ class GCTStream(GCTBase):
|
||||
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
||||
self.kv_cache_camera_only = kv_cache_camera_only
|
||||
self.use_sdpa = use_sdpa
|
||||
self.camera_num_iterations = camera_num_iterations
|
||||
|
||||
# Call base class __init__ (will call _build_aggregator)
|
||||
super().__init__(
|
||||
@@ -186,6 +189,7 @@ class GCTStream(GCTBase):
|
||||
dim_in=2 * self.embed_dim,
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
attend_to_scale_frames=self.attend_to_scale_frames,
|
||||
num_iterations=self.camera_num_iterations,
|
||||
# KV cache parameters
|
||||
kv_cache_sliding_window=self.kv_cache_sliding_window,
|
||||
kv_cache_scale_frames=self.kv_cache_scale_frames,
|
||||
|
||||
@@ -166,6 +166,8 @@ class GCTStream(GCTBase):
|
||||
use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer
|
||||
# Gradient checkpointing
|
||||
use_gradient_checkpoint: bool = True,
|
||||
# Camera head iterative refinement (lower = faster inference; default 4)
|
||||
camera_num_iterations: int = 4,
|
||||
):
|
||||
"""
|
||||
Initialize GCTStream.
|
||||
@@ -214,6 +216,7 @@ class GCTStream(GCTBase):
|
||||
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
||||
self.kv_cache_camera_only = kv_cache_camera_only
|
||||
self.use_sdpa = use_sdpa
|
||||
self.camera_num_iterations = camera_num_iterations
|
||||
|
||||
# Call base class __init__ (will call _build_aggregator)
|
||||
super().__init__(
|
||||
@@ -277,6 +280,7 @@ class GCTStream(GCTBase):
|
||||
dim_in=2 * self.embed_dim,
|
||||
sliding_window_size=self.sliding_window_size,
|
||||
attend_to_scale_frames=self.attend_to_scale_frames,
|
||||
num_iterations=self.camera_num_iterations,
|
||||
# KV cache parameters
|
||||
kv_cache_sliding_window=self.kv_cache_sliding_window,
|
||||
kv_cache_scale_frames=self.kv_cache_scale_frames,
|
||||
|
||||
Reference in New Issue
Block a user