# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from lingbot_map.layers import Mlp
from lingbot_map.layers.block import Block
from lingbot_map.layers.block import CameraBlock
from lingbot_map.heads.head_act import activate_pose
from lingbot_map.layers.rope import WanRotaryPosEmbed
from functools import partial
from torch.utils.checkpoint import checkpoint


class CameraHead(nn.Module):
    """
    Predicts camera parameters from aggregated token representations.

    A stack of transformer blocks (the "trunk") iteratively refines a
    camera pose encoding, starting from a learned "empty" pose token.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activation: keeps FOV values positive.
        enable_ulysses_cp=False,
    ):
        super().__init__()

        if pose_encoding_type != "absT_quaR_FoV":
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
        # absT (3) + quaternion (4) + FoV (2) components.
        self.target_dim = 9

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.enable_ulysses_cp = enable_ulysses_cp

        # Trunk: a stack of identical transformer blocks.
        blocks = [
            Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
            for _ in range(trunk_depth)
        ]
        self.trunk = nn.Sequential(*blocks)

        # Normalizations for the incoming camera token and the trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable "empty" camera pose that seeds the first refinement step.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Produces the shift / scale / gate modulation parameters.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without learnable affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list:
        """
        Predict camera parameters from the final aggregated tokens.

        Args:
            aggregated_tokens_list (list): Token tensors from the network;
                only the last entry is used for prediction.
            num_iterations (int, optional): Number of refinement steps. Defaults to 4.

        Returns:
            list: Activated camera encodings, one per refinement iteration.
        """
        # Only the last block's tokens feed the camera prediction.
        last_tokens = aggregated_tokens_list[-1]

        # The dedicated camera token sits at patch index 0.
        camera_tokens = self.token_norm(last_tokens[:, :, 0])

        return self.trunk_fn(camera_tokens, num_iterations)

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine the camera pose prediction.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens, shape [B, 1, C].
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: Activated camera encodings from every iteration.
        """
        B, S, C = pose_tokens.shape  # S is expected to be 1.
        pred_pose_enc = None
        outputs = []

        for _ in range(num_iterations):
            if pred_pose_enc is None:
                # First pass: seed with the learned empty pose.
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Later passes: condition on the previous prediction, detached
                # so gradients do not flow back through earlier iterations.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Split the modulation output into shift, scale and gate terms.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive LayerNorm modulation followed by a residual connection.
            modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            modulated = modulated + pose_tokens

            # Run the trunk, forwarding the Ulysses context-parallel flag.
            for block in self.trunk:
                modulated = block(modulated, enable_ulysses_cp=self.enable_ulysses_cp)

            # Delta update on top of the running pose encoding.
            delta = self.pose_branch(self.trunk_norm(modulated))
            pred_pose_enc = delta if pred_pose_enc is None else pred_pose_enc + delta

            # Activate translation, quaternion and field-of-view components.
            outputs.append(
                activate_pose(pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act)
            )

        return outputs
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Apply scale-and-shift (affine) modulation to the input tensor.

    Returns ``x * (1 + scale) + shift``, broadcasting as needed.
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    scaled = x * (1 + scale)
    return scaled + shift
class CameraCausalHead(nn.Module):
    """
    Causal variant of the camera head: predicts camera parameters from token
    representations using iterative refinement.

    It applies a series of CameraBlock transformer blocks (the "trunk") to
    dedicated camera tokens, and supports streaming (causal) inference with a
    per-iteration KV cache plus optional 3D rotary position embeddings.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
        num_iterations: int = 4,  # Also sizes the KV cache: one cache dict per iteration.
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        enable_ulysses_cp: bool = False,
        attn_class: str = "flexflashattn_varlen",  # NOTE(review): not referenced in this class — confirm whether it should be forwarded to CameraBlock.
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # 3D RoPE parameters
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        rope_theta: float = 10000.0,
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            # absT (3) + quaternion (4) + FoV (2) components.
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.sliding_window_size = sliding_window_size
        self.enable_ulysses_cp = enable_ulysses_cp
        self.num_heads = num_heads

        # 3D RoPE for temporal position encoding
        self.enable_3d_rope = enable_3d_rope
        if enable_3d_rope:
            head_dim = dim_in // num_heads
            # For camera head: each frame has 1 token (frame_seqlen=1)
            # patch_size is (max_frames, h=1, w=1) for 3D RoPE
            # fhw_dim=None lets auto-calculation: h_dim=w_dim=2*(head_dim//6), t_dim=remainder
            self.rope3d = WanRotaryPosEmbed(
                attention_head_dim=head_dim,
                patch_size=(max_frame_num, 1, 1),
                theta=rope_theta,
                fhw_dim=[40, 44, 44],  # Explicit f/h/w allocation (40+44+44 = 128, the default head_dim); overrides the auto-calculation above.
            )
        else:
            self.rope3d = None

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                CameraBlock(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values, elementwise_attn_output_gate=elementwise_attn_output_gate, sliding_window_size=sliding_window_size, attend_to_scale_frames=attend_to_scale_frames, num_random_frames=num_random_frames, kv_cache_sliding_window=kv_cache_sliding_window, kv_cache_scale_frames=kv_cache_scale_frames, kv_cache_cross_frame_special=kv_cache_cross_frame_special, kv_cache_include_scale_frames=kv_cache_include_scale_frames, kv_cache_camera_only=kv_cache_camera_only)
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

        self.num_iterations = num_iterations

        # Streaming-inference state (the KV cache is allocated lazily in forward()).
        self.kv_cache = None  # list of per-iteration dicts holding per-layer "k_{j}"/"v_{j}" entries
        self.pos_cache = None
        self.frame_idx = 0  # number of frames already consumed in streaming mode
        self.cp_size = 1  # Ulysses context-parallel world size (1 when CP is disabled)

        ## Get cp size if enable ulysses cp
        if self.enable_ulysses_cp:
            from torchtitan.distributed.sequence_parallel import (
                init_sequence_parallel,
                get_ulysses_sequence_parallel_rank,
                get_ulysses_sequence_parallel_world_size,
            )
            # NOTE(review): only get_ulysses_sequence_parallel_world_size is used
            # below; the other two imported names appear unused.
            self.cp_size = get_ulysses_sequence_parallel_world_size()

    def clean_kv_cache(self):
        """Drop the streaming KV cache and reset the global frame counter."""
        del self.kv_cache
        self.kv_cache = None
        self.frame_idx = 0

    def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            mask: Optional video mask forwarded to the trunk blocks.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
            causal_inference (bool): When True, lazily allocates the per-iteration
                KV cache used for streaming decoding.
            num_frame_per_block: Frame-grouping hint forwarded to the trunk blocks.
            num_frame_for_scale: Scale-frame count forwarded to the trunk blocks.
            sliding_window_size (int, optional): Override the sliding window size for this forward pass.
                If None, use the default self.sliding_window_size.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use passed sliding_window_size if provided, otherwise use default
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size

        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens (patch index 0 of every frame).
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        if causal_inference:
            if self.kv_cache is None:
                # One cache dict per refinement iteration, each holding per-layer
                # "k_{j}"/"v_{j}" slots that the blocks fill in.
                # NOTE(review): the cache is sized by self.num_iterations while
                # trunk_fn below runs the `num_iterations` argument — confirm
                # callers keep the two values equal.
                self.kv_cache = []
                for i in range(self.num_iterations):
                    self.kv_cache.append({"_skip_append": False})
                    for j in range(self.trunk_depth):
                        self.kv_cache[i][f"k_{j}"] = None
                        self.kv_cache[i][f"v_{j}"] = None

        pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, mask=None, num_iterations: int=4, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            mask: Optional video mask forwarded to each trunk block.
            num_iterations (int): Number of refinement iterations.
            num_frame_per_block: Frame-grouping hint forwarded to each block.
            num_frame_for_scale: Scale-frame count forwarded to each block.
            sliding_window_size (int, optional): Sliding window size to use.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []

        # Check if this is the first call (processing scale frames)
        # Scale frames should use batch mode attention for numerical consistency
        is_scale_frames = (self.kv_cache is not None and self.frame_idx == 0)

        # Generate 3D RoPE positions if enabled
        pos3d = None
        if self.rope3d is not None:
            # For camera tokens: shape [B, S, C] where each frame has 1 token
            # Position for frame f is (f, 0, 0) - temporal varies, spatial fixed

            # In streaming mode with KV cache, use frame_idx to track global position
            # Otherwise, generate positions from 0
            if self.kv_cache is not None:
                f_start = self.frame_idx
                f_end = self.frame_idx + S
            else:
                f_start = 0
                f_end = None  # Will use ppf as frame count

            pos3d = self.rope3d(
                ppf=S * self.cp_size,  # Total frames (with CP)
                pph=1,  # height = 1 (camera token)
                ppw=1,  # width = 1 (camera token)
                patch_start_idx=0,  # No special tokens before
                device=pose_tokens.device,
                f_start=f_start,
                f_end=f_end,
            )  # Returns [1, 1, S*cp_size, head_dim//2] complex

        for i in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation, with a residual connection.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            # Run the trunk; each block receives the iteration-specific KV cache
            # (indexed by i) and its own layer index (global_idx).
            for idx in range(self.trunk_depth):
                pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, pos=pos3d, video_mask=mask, num_frames=S*self.cp_size, frame_seqlen=1, kv_cache=self.kv_cache[i] if self.kv_cache is not None else None, global_idx=idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=sliding_window_size, enable_ulysses_cp=self.enable_ulysses_cp, enable_3d_rope=self.enable_3d_rope, is_scale_frames=is_scale_frames)

            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)

        # Update frame_idx for streaming mode (KV cache)
        if self.kv_cache is not None:
            self.frame_idx += S

        return pred_pose_enc_list
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Modulate the input tensor using scaling and shifting parameters.
    """
    # NOTE(review): byte-for-byte duplicate of the `modulate` defined earlier in
    # this module; this redefinition shadows the first. Consider removing one copy.
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    return x * (1 + scale) + shift
class CameraDecoder(nn.Module):
    """
    Lightweight transformer decoder that maps per-view token features to
    camera outputs.

    Tokens from all views are flattened into a single batch dimension,
    refined by a stack of transformer blocks, then projected to the output
    dimension and reshaped back per view.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()

        self.use_checkpoint = use_checkpoint
        # Optional input projection into the decoder width.
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()

        def make_block():
            # One decoder block; the rope module (if any) is shared by all blocks.
            return Block(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                rope=rope,
            )

        self.blocks = nn.ModuleList(make_block() for _ in range(depth))
        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
        """
        Decode per-view token features into camera outputs.

        Args:
            hidden: Tensor of shape [B, V, P, C] (batch, views, patches, channels).
            xpos: Optional positional input forwarded to every block.

        Returns:
            Tensor of shape [B, V, P, out_dim].
        """
        hidden = self.projects(hidden)
        B, V, P, C = hidden.shape
        # Merge batch and view axes so each view is processed as its own sequence.
        hidden = hidden.view(B * V, P, C)
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Activation checkpointing: trade compute for memory in training.
                hidden = checkpoint(blk, hidden, pos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, pos=xpos)
        # Project to the output dimension and restore the view axis.
        return self.linear_out(hidden).view(B, V, P, -1)