# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lingbot_map.layers import Mlp
from lingbot_map.layers.block import Block
from lingbot_map.layers.block import CameraBlock
from lingbot_map.heads.head_act import activate_pose
from lingbot_map.layers.rope import WanRotaryPosEmbed
from functools import partial
from torch.utils.checkpoint import checkpoint


class CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.

    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
        enable_ulysses_cp=False,
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.enable_ulysses_cp = enable_ulysses_cp

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token (used as the "previous prediction" on iteration 0).
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network; the last tensor
                is used for prediction. Index 0 along dim 2 is taken as the camera token
                (presumably tokens are [B, S, P, C] — TODO confirm).
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

        Returns:
            list: ``num_iterations`` predicted camera encodings (post-activation), one per iteration.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C]
                (one token per entry along S; presumably one per frame — TODO confirm).
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation, with a residual connection.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            # Apply trunk blocks with enable_ulysses_cp
            for block in self.trunk:
                pose_tokens_modulated = block(pose_tokens_modulated, enable_ulysses_cp=self.enable_ulysses_cp)

            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)

        return pred_pose_enc_list


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Modulate the input tensor using scaling and shifting parameters.

    Computes ``x * (1 + scale) + shift`` (DiT-style adaLN modulation).
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    return x * (1 + scale) + shift


class CameraCausalHead(nn.Module):
    """
    Causal / streaming variant of CameraHead.

    Predicts camera parameters from camera tokens by iterative refinement, like CameraHead,
    but the trunk is built from ``CameraBlock``s that accept an attention mask, sliding-window
    settings and a per-iteration KV cache. Optionally uses 3D RoPE over the frame axis.
    ``self.frame_idx`` tracks how many frames have been consumed while a KV cache is active.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
        num_iterations=4,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        enable_ulysses_cp: bool = False,
        attn_class: str = "flexflashattn_varlen",
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # 3D RoPE parameters
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        rope_theta: float = 10000.0,
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.sliding_window_size = sliding_window_size
        self.enable_ulysses_cp = enable_ulysses_cp
        self.num_heads = num_heads

        # 3D RoPE for temporal position encoding
        self.enable_3d_rope = enable_3d_rope
        if enable_3d_rope:
            head_dim = dim_in // num_heads
            # Camera head: each frame contributes exactly 1 token (frame_seqlen=1), so the
            # spatial grid is 1x1 and only the frame axis varies: patch_size=(max_frames, 1, 1).
            # fhw_dim is an explicit (frame, h, w) split of the head dimension:
            # 40 + 44 + 44 = 128, which equals head_dim for the defaults (2048 // 16).
            # NOTE(review): this split is hard-coded (not auto-calculated) and presumably must
            # be adjusted if head_dim changes — confirm against WanRotaryPosEmbed.
            self.rope3d = WanRotaryPosEmbed(
                attention_head_dim=head_dim,
                patch_size=(max_frame_num, 1, 1),
                theta=rope_theta,
                fhw_dim=[40, 44, 44],
            )
        else:
            self.rope3d = None

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                CameraBlock(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                    elementwise_attn_output_gate=elementwise_attn_output_gate,
                    sliding_window_size=sliding_window_size,
                    attend_to_scale_frames=attend_to_scale_frames,
                    num_random_frames=num_random_frames,
                    kv_cache_sliding_window=kv_cache_sliding_window,
                    kv_cache_scale_frames=kv_cache_scale_frames,
                    kv_cache_cross_frame_special=kv_cache_cross_frame_special,
                    kv_cache_include_scale_frames=kv_cache_include_scale_frames,
                    kv_cache_camera_only=kv_cache_camera_only,
                )
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

        # Number of KV-cache slots allocated in causal mode (one per refinement iteration).
        self.num_iterations = num_iterations
        self.kv_cache = None
        self.pos_cache = None
        self.frame_idx = 0
        self.cp_size = 1

        ## Get cp size if enable ulysses cp
        if self.enable_ulysses_cp:
            from torchtitan.distributed.sequence_parallel import get_ulysses_sequence_parallel_world_size

            self.cp_size = get_ulysses_sequence_parallel_world_size()

    def clean_kv_cache(self):
        """Drop the KV cache and reset the streaming frame counter."""
        self.kv_cache = None
        self.frame_idx = 0

    def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network; the last tensor
                is used for prediction (index 0 along dim 2 is taken as the camera token).
            mask: Forwarded to every trunk block as ``video_mask``.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
                When a KV cache is active this must not exceed ``self.num_iterations``.
            causal_inference (bool): If True and no KV cache exists yet, allocate one with
                ``self.num_iterations`` slots (one dict per iteration).
            num_frame_per_block, num_frame_for_scale: Forwarded to every trunk block.
            sliding_window_size (int, optional): Override the sliding window size for this forward pass.
                If None, use the default self.sliding_window_size.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.

        Raises:
            ValueError: If a KV cache is active and ``num_iterations`` exceeds its slot count.
        """
        # Use passed sliding_window_size if provided, otherwise use default
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size

        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        if causal_inference and self.kv_cache is None:
            # One cache dict per refinement iteration, holding k_j / v_j entries per trunk block.
            self.kv_cache = []
            for _ in range(self.num_iterations):
                iter_cache = {"_skip_append": False}
                for j in range(self.trunk_depth):
                    iter_cache[f"k_{j}"] = None
                    iter_cache[f"v_{j}"] = None
                self.kv_cache.append(iter_cache)

        pred_pose_enc_list = self.trunk_fn(
            pose_tokens,
            mask,
            num_iterations,
            num_frame_per_block=num_frame_per_block,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=effective_sliding_window_size,
        )
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, mask=None, num_iterations: int = 4, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            mask: Forwarded to every trunk block as ``video_mask``.
            num_iterations (int): Number of refinement iterations.
            num_frame_per_block, num_frame_for_scale: Forwarded to every trunk block.
            sliding_window_size (int, optional): Sliding window size to use.

        Returns:
            list: List of activated camera encodings from each iteration.

        Raises:
            ValueError: If a KV cache is active and ``num_iterations`` exceeds its slot count.
        """
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []

        # Iteration i uses cache slot self.kv_cache[i]; fail with a clear message instead of an
        # IndexError halfway through the forward pass.
        if self.kv_cache is not None and num_iterations > len(self.kv_cache):
            raise ValueError(
                f"num_iterations={num_iterations} exceeds the {len(self.kv_cache)} KV-cache slots "
                f"allocated for this head (constructed with num_iterations={self.num_iterations})."
            )

        # Check if this is the first call (processing scale frames)
        # Scale frames should use batch mode attention for numerical consistency
        is_scale_frames = self.kv_cache is not None and self.frame_idx == 0

        # Generate 3D RoPE positions if enabled
        pos3d = None
        if self.rope3d is not None:
            # Each frame has a single camera token, so only the frame coordinate varies.
            # With a KV cache, offset by frame_idx so positions continue across streaming calls.
            if self.kv_cache is not None:
                f_start = self.frame_idx
                f_end = self.frame_idx + S
            else:
                f_start = 0
                f_end = None  # Will use ppf as frame count
            pos3d = self.rope3d(
                ppf=S * self.cp_size,  # Total frames (with CP)
                pph=1,  # height = 1 (camera token)
                ppw=1,  # width = 1 (camera token)
                patch_start_idx=0,  # No special tokens before
                device=pose_tokens.device,
                f_start=f_start,
                f_end=f_end,
            )
            # Original note: expected [1, 1, S*cp_size, head_dim//2] complex — see WanRotaryPosEmbed.

        for i in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation, with a residual connection.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            iter_kv_cache = self.kv_cache[i] if self.kv_cache is not None else None
            for idx in range(self.trunk_depth):
                pose_tokens_modulated = self.trunk[idx](
                    pose_tokens_modulated,
                    pos=pos3d,
                    video_mask=mask,
                    num_frames=S * self.cp_size,
                    frame_seqlen=1,
                    kv_cache=iter_kv_cache,
                    global_idx=idx,
                    num_frame_per_block=num_frame_per_block,
                    num_frame_for_scale=num_frame_for_scale,
                    sliding_window_size=sliding_window_size,
                    enable_ulysses_cp=self.enable_ulysses_cp,
                    enable_3d_rope=self.enable_3d_rope,
                    is_scale_frames=is_scale_frames,
                )

            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)

        # Update frame_idx for streaming mode (KV cache): once per call, after all iterations.
        if self.kv_cache is not None:
            self.frame_idx += S

        return pred_pose_enc_list


class CameraDecoder(nn.Module):
    """
    Stack of transformer ``Block``s applied per view.

    Input of shape [B, V, P, in_dim] is optionally projected to ``dec_embed_dim``, flattened to
    [B*V, P, dec_embed_dim], run through ``depth`` blocks (with optional activation
    checkpointing during training), then linearly mapped to [B, V, P, out_dim].
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()
        # When need_project is False, in_dim must already equal dec_embed_dim.
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            Block(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                rope=rope,
            )
            for _ in range(depth)
        ])

        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
        """
        Args:
            hidden: Tensor of shape [B, V, P, in_dim].
            xpos: Positions forwarded to every block as ``pos``.

        Returns:
            Tensor of shape [B, V, P, out_dim].
        """
        hidden = self.projects(hidden)
        B, V, P, C = hidden.shape
        # reshape (not view): with need_project=False, nn.Identity passes the caller's tensor
        # through unchanged, and it may be non-contiguous.
        hidden = hidden.reshape(B * V, P, C)

        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                hidden = checkpoint(blk, hidden, pos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, pos=xpos)

        out = self.linear_out(hidden).reshape(B, V, P, -1)
        return out