Files
lingbot-map/lingbot_map/heads/camera_head.py
LinZhuoChen f9b3ae457a first commit
2026-04-16 09:51:30 +08:00

455 lines
19 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from lingbot_map.layers import Mlp
from lingbot_map.layers.block import Block
from lingbot_map.layers.block import CameraBlock
from lingbot_map.heads.head_act import activate_pose
from lingbot_map.layers.rope import WanRotaryPosEmbed
from functools import partial
from torch.utils.checkpoint import checkpoint
class CameraHead(nn.Module):
    """
    Iterative-refinement camera head.

    Reads slice ``[:, :, 0]`` of the last tensor in ``aggregated_tokens_list``
    (presumably the per-frame camera token), normalizes it, and then for
    ``num_iterations`` steps:
      1. embeds the current pose estimate (a learned empty pose on the first step),
      2. turns it into shift / scale / gate terms that modulate the camera tokens
         adaLN-style,
      3. runs the modulated tokens through a trunk of transformer ``Block`` layers,
      4. predicts a residual update of the 9-dim pose encoding.
    The activated encoding of every step is returned.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
        enable_ulysses_cp=False,
    ):
        super().__init__()
        encoding_dims = {"absT_quaR_FoV": 9}
        if pose_encoding_type not in encoding_dims:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
        self.target_dim = encoding_dims[pose_encoding_type]

        # Activation choices forwarded to activate_pose for each output group.
        self.trans_act, self.quat_act, self.fl_act = trans_act, quat_act, fl_act
        self.trunk_depth = trunk_depth
        self.enable_ulysses_cp = enable_ulysses_cp

        # NOTE: submodules are created in the same order as before so that random
        # initialization and state_dict keys are unchanged.
        self.trunk = nn.Sequential(
            *(
                Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
                for _ in range(trunk_depth)
            )
        )
        # LayerNorms applied to the incoming camera tokens and to the trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)
        # Learnable "no pose yet" encoding used to condition the first iteration.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)
        # Produces shift, scale and gate (3 * dim_in) from the embedded pose.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
        # Parameter-free LayerNorm; the affine part comes from the modulation above.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list:
        """
        Predict camera encodings from the last aggregated token tensor.

        Args:
            aggregated_tokens_list (list): Token tensors from the backbone; only the
                last one is used.
            num_iterations (int, optional): Number of refinement steps. Defaults to 4.
            **kwargs: Accepted and ignored.

        Returns:
            list: One activated pose encoding per refinement step.
        """
        camera_tokens = aggregated_tokens_list[-1][:, :, 0]
        return self.trunk_fn(self.token_norm(camera_tokens), num_iterations)

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Run the iterative pose refinement loop.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens of shape [B, S, C].
            num_iterations (int): Number of refinement steps.

        Returns:
            list: The activated pose encoding after each step.
        """
        batch, seq_len, _ = pose_tokens.shape
        outputs = []
        estimate = None
        for _ in range(num_iterations):
            if estimate is None:
                conditioning = self.empty_pose_tokens.expand(batch, seq_len, -1)
            else:
                # Stop gradients flowing back through earlier refinement steps.
                estimate = estimate.detach()
                conditioning = estimate

            shift, scale, gate = self.poseLN_modulation(self.embed_pose(conditioning)).chunk(3, dim=-1)
            # Gated adaLN modulation with a residual connection to the raw tokens.
            hidden = gate * modulate(self.adaln_norm(pose_tokens), shift, scale) + pose_tokens

            for block in self.trunk:
                hidden = block(hidden, enable_ulysses_cp=self.enable_ulysses_cp)

            delta = self.pose_branch(self.trunk_norm(hidden))
            estimate = delta if estimate is None else estimate + delta

            outputs.append(
                activate_pose(estimate, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act)
            )
        return outputs
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Apply DiT-style affine modulation: ``x * (1 + scale) + shift``.

    ``scale`` is a residual gain (0 leaves ``x`` unchanged), and ``shift`` is added
    afterwards. Standard tensor broadcasting applies to all three arguments.
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    gain = scale + 1
    return shift + gain * x
class CameraCausalHead(nn.Module):
    """
    Frame-causal / streaming variant of the camera head.

    Like ``CameraHead``, it iteratively refines a pose encoding from slice
    ``[:, :, 0]`` of the last aggregated token tensor. Its trunk, however, is
    built from ``CameraBlock`` layers. These receive frame-level attention
    options (video mask, sliding window, scale frames) and, once a KV cache
    exists, one cache dict per refinement iteration. Successive ``forward``
    calls can therefore continue a video; ``frame_idx`` tracks how many frames
    have been consumed. Optional 3D RoPE encodes each camera token's frame index.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
        num_iterations: int = 4,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        enable_ulysses_cp: bool = False,
        attn_class: str = "flexflashattn_varlen",
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # 3D RoPE parameters
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        rope_theta: float = 10000.0,
    ):
        super().__init__()
        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.sliding_window_size = sliding_window_size
        self.enable_ulysses_cp = enable_ulysses_cp
        self.num_heads = num_heads
        # 3D RoPE for temporal position encoding.
        self.enable_3d_rope = enable_3d_rope
        if enable_3d_rope:
            head_dim = dim_in // num_heads
            # Each frame contributes exactly one camera token, so the RoPE grid is
            # (max_frame_num, 1, 1): only the temporal coordinate varies.
            # NOTE(review): fhw_dim is hard-coded, NOT auto-calculated. [40, 44, 44]
            # sums to 128 == head_dim for the defaults (2048 // 16); presumably it
            # must be changed if head_dim changes -- confirm against WanRotaryPosEmbed.
            self.rope3d = WanRotaryPosEmbed(
                attention_head_dim=head_dim,
                patch_size=(max_frame_num, 1, 1),
                theta=rope_theta,
                fhw_dim=[40, 44, 44],
            )
        else:
            self.rope3d = None
        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                CameraBlock(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                    elementwise_attn_output_gate=elementwise_attn_output_gate,
                    sliding_window_size=sliding_window_size,
                    attend_to_scale_frames=attend_to_scale_frames,
                    num_random_frames=num_random_frames,
                    kv_cache_sliding_window=kv_cache_sliding_window,
                    kv_cache_scale_frames=kv_cache_scale_frames,
                    kv_cache_cross_frame_special=kv_cache_cross_frame_special,
                    kv_cache_include_scale_frames=kv_cache_include_scale_frames,
                    kv_cache_camera_only=kv_cache_camera_only,
                )
                for _ in range(trunk_depth)
            ]
        )
        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)
        # Learnable empty camera pose token.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)
        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
        # Default number of refinement iterations the KV cache is sized for.
        self.num_iterations = num_iterations
        # Streaming state: list of per-iteration cache dicts (None = no caching).
        self.kv_cache = None
        self.pos_cache = None
        # Global index of the next frame to be processed in streaming mode.
        self.frame_idx = 0
        self.cp_size = 1
        if self.enable_ulysses_cp:
            # Context-parallel world size; the blocks are told num_frames = S * cp_size.
            from torchtitan.distributed.sequence_parallel import (
                get_ulysses_sequence_parallel_world_size,
            )
            self.cp_size = get_ulysses_sequence_parallel_world_size()

    def _make_iteration_kv_cache(self) -> dict:
        """Return a fresh cache dict for one refinement iteration: one k/v slot per trunk block."""
        cache = {"_skip_append": False}
        for j in range(self.trunk_depth):
            cache[f"k_{j}"] = None
            cache[f"v_{j}"] = None
        return cache

    def clean_kv_cache(self):
        """Drop the streaming KV cache and restart frame counting from 0."""
        del self.kv_cache
        self.kv_cache = None
        self.frame_idx = 0

    def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            mask: Forwarded to each CameraBlock as ``video_mask``.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
            causal_inference (bool, optional): If True and no cache exists yet, allocate
                a KV cache. The cache persists across calls until ``clean_kv_cache()``.
            num_frame_per_block, num_frame_for_scale: Forwarded to each CameraBlock.
            sliding_window_size (int, optional): Override the sliding window size for this
                forward pass. If None, use the default self.sliding_window_size.
            **kwargs: Accepted and ignored.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.

        Raises:
            ValueError: If an existing KV cache has fewer iteration slots than
                ``num_iterations``.
        """
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
        # Use the camera-token slice of the last block's tokens.
        tokens = aggregated_tokens_list[-1]
        pose_tokens = self.token_norm(tokens[:, :, 0])
        if causal_inference and self.kv_cache is None:
            # trunk_fn indexes self.kv_cache[i] for every iteration it runs, so
            # allocate enough slots for the configured AND the requested count.
            num_slots = max(self.num_iterations, num_iterations)
            self.kv_cache = [self._make_iteration_kv_cache() for _ in range(num_slots)]
        return self.trunk_fn(
            pose_tokens,
            mask,
            num_iterations,
            num_frame_per_block=num_frame_per_block,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=effective_sliding_window_size,
        )

    def trunk_fn(self, pose_tokens: torch.Tensor, mask=None, num_iterations: int = 4, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            mask: Forwarded to each CameraBlock as ``video_mask``.
            num_iterations (int): Number of refinement iterations.
            num_frame_per_block, num_frame_for_scale: Forwarded to each CameraBlock.
            sliding_window_size (int, optional): Sliding window size to use.

        Returns:
            list: List of activated camera encodings from each iteration.

        Raises:
            ValueError: If a KV cache exists with fewer slots than ``num_iterations``.
        """
        B, S, C = pose_tokens.shape
        if self.kv_cache is not None and len(self.kv_cache) < num_iterations:
            raise ValueError(
                f"num_iterations={num_iterations} exceeds the {len(self.kv_cache)} KV-cache "
                "slots allocated; call clean_kv_cache() before changing the iteration count."
            )
        pred_pose_enc = None
        pred_pose_enc_list = []
        # The first call with a cache processes the scale frames; those use
        # batch-mode attention for numerical consistency.
        is_scale_frames = (self.kv_cache is not None and self.frame_idx == 0)
        # 3D RoPE positions: frame f gets position (f, 0, 0). With a KV cache the
        # frame range continues from frame_idx; otherwise it starts at 0.
        pos3d = None
        if self.rope3d is not None:
            if self.kv_cache is not None:
                f_start = self.frame_idx
                f_end = self.frame_idx + S
            else:
                f_start = 0
                f_end = None  # Will use ppf as frame count
            pos3d = self.rope3d(
                ppf=S * self.cp_size,  # Total frames (with CP)
                pph=1,  # height = 1 (camera token)
                ppw=1,  # width = 1 (camera token)
                patch_start_idx=0,  # No special tokens before
                device=pose_tokens.device,
                f_start=f_start,
                f_end=f_end,
            )
        for i in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)
            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
            # Adaptive layer normalization and modulation.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens
            iteration_cache = self.kv_cache[i] if self.kv_cache is not None else None
            for idx in range(self.trunk_depth):
                pose_tokens_modulated = self.trunk[idx](
                    pose_tokens_modulated,
                    pos=pos3d,
                    video_mask=mask,
                    num_frames=S * self.cp_size,
                    frame_seqlen=1,
                    kv_cache=iteration_cache,
                    global_idx=idx,
                    num_frame_per_block=num_frame_per_block,
                    num_frame_for_scale=num_frame_for_scale,
                    sliding_window_size=sliding_window_size,
                    enable_ulysses_cp=self.enable_ulysses_cp,
                    enable_3d_rope=self.enable_3d_rope,
                    is_scale_frames=is_scale_frames,
                )
            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)
        # Advance the global frame counter for streaming mode (KV cache).
        if self.kv_cache is not None:
            self.frame_idx += S
        return pred_pose_enc_list
# NOTE: `modulate` is defined once, earlier in this module; the verbatim duplicate
# definition that used to live here was removed (it only shadowed the identical one).
class CameraDecoder(nn.Module):
    """
    Stack of transformer ``Block`` layers applied per (batch, view) token group.

    ``forward`` takes ``hidden`` of shape [B, V, P, in_dim]. It optionally projects
    it to ``dec_embed_dim`` and flattens it to [B * V, P, dec_embed_dim], so each
    (batch, view) pair is a separate sequence for the blocks. It then runs the
    blocks and projects to ``out_dim``, returning [B, V, P, out_dim].
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        """
        Args:
            in_dim: Channel size of the input tokens.
            out_dim: Channel size of the output tokens.
            dec_embed_dim: Width of the decoder blocks.
            depth: Number of decoder blocks.
            dec_num_heads: Attention heads per block.
            mlp_ratio: MLP hidden-size ratio of each block.
            rope: Forwarded to every ``Block`` as its ``rope`` argument.
            need_project: If False, the input is used as-is (``nn.Identity``);
                presumably it must then already have ``dec_embed_dim`` channels.
            use_checkpoint: Use activation checkpointing for the blocks while training.
        """
        super().__init__()
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            Block(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                rope=rope,
            ) for _ in range(depth)])
        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
        """
        Decode tokens.

        Args:
            hidden: Tensor of shape [B, V, P, in_dim] (or [B, V, P, dec_embed_dim]
                when ``need_project=False``).
            xpos: Forwarded unchanged to every block as ``pos``.

        Returns:
            Tensor of shape [B, V, P, out_dim].
        """
        hidden = self.projects(hidden)
        B, V, P, C = hidden.shape
        # reshape rather than view: with need_project=False the Identity passes the
        # caller's tensor through, which may be non-contiguous (view would raise).
        hidden = hidden.reshape(B * V, P, C)
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Recompute block activations in backward to save memory.
                hidden = checkpoint(blk, hidden, pos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, pos=xpos)
        return self.linear_out(hidden).reshape(B, V, P, -1)