This commit is contained in:
LinZhuoChen
2026-04-18 02:26:23 +08:00
parent 3d1fe1d8e2
commit 01c99afc41
329 changed files with 34 additions and 7 deletions

View File

@@ -272,23 +272,27 @@ class CameraCausalHead(nn.Module):
self.kv_cache = None
self.frame_idx = 0
def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = None, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
"""
Forward pass to predict camera parameters.
Args:
aggregated_tokens_list (list): List of token tensors from the network;
the last tensor is used for prediction.
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
num_iterations (int, optional): Number of iterative refinement steps.
If None, falls back to self.num_iterations (set at construction).
sliding_window_size (int, optional): Override the sliding window size for this forward pass.
If None, use the default self.sliding_window_size.
Returns:
list: A list of predicted camera encodings (post-activation) from each iteration.
"""
if num_iterations is None:
num_iterations = self.num_iterations
# Use passed sliding_window_size if provided, otherwise use default
effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
# Use tokens from the last block for camera prediction.
tokens = aggregated_tokens_list[-1]
@@ -299,12 +303,12 @@ class CameraCausalHead(nn.Module):
if causal_inference:
if self.kv_cache is None:
self.kv_cache = []
for i in range(self.num_iterations):
for i in range(num_iterations):
self.kv_cache.append({"_skip_append": False})
for j in range(self.trunk_depth):
self.kv_cache[i][f"k_{j}"] = None
self.kv_cache[i][f"v_{j}"] = None
pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size)
return pred_pose_enc_list