update
This commit is contained in:
@@ -272,23 +272,27 @@ class CameraCausalHead(nn.Module):
|
||||
self.kv_cache = None
|
||||
self.frame_idx = 0
|
||||
|
||||
def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
|
||||
def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = None, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
|
||||
"""
|
||||
Forward pass to predict camera parameters.
|
||||
|
||||
Args:
|
||||
aggregated_tokens_list (list): List of token tensors from the network;
|
||||
the last tensor is used for prediction.
|
||||
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
|
||||
num_iterations (int, optional): Number of iterative refinement steps.
|
||||
If None, falls back to self.num_iterations (set at construction).
|
||||
sliding_window_size (int, optional): Override the sliding window size for this forward pass.
|
||||
If None, use the default self.sliding_window_size.
|
||||
|
||||
Returns:
|
||||
list: A list of predicted camera encodings (post-activation) from each iteration.
|
||||
"""
|
||||
if num_iterations is None:
|
||||
num_iterations = self.num_iterations
|
||||
|
||||
# Use passed sliding_window_size if provided, otherwise use default
|
||||
effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
|
||||
|
||||
|
||||
# Use tokens from the last block for camera prediction.
|
||||
tokens = aggregated_tokens_list[-1]
|
||||
|
||||
@@ -299,12 +303,12 @@ class CameraCausalHead(nn.Module):
|
||||
if causal_inference:
|
||||
if self.kv_cache is None:
|
||||
self.kv_cache = []
|
||||
for i in range(self.num_iterations):
|
||||
for i in range(num_iterations):
|
||||
self.kv_cache.append({"_skip_append": False})
|
||||
for j in range(self.trunk_depth):
|
||||
self.kv_cache[i][f"k_{j}"] = None
|
||||
self.kv_cache[i][f"v_{j}"] = None
|
||||
|
||||
|
||||
pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size)
|
||||
return pred_pose_enc_list
|
||||
|
||||
|
||||
Reference in New Issue
Block a user