first commit

LinZhuoChen
2026-04-16 09:51:30 +08:00
commit f9b3ae457a
44 changed files with 11994 additions and 0 deletions

@@ -0,0 +1,454 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from lingbot_map.layers import Mlp
from lingbot_map.layers.block import Block
from lingbot_map.layers.block import CameraBlock
from lingbot_map.heads.head_act import activate_pose
from lingbot_map.layers.rope import WanRotaryPosEmbed
from functools import partial
from torch.utils.checkpoint import checkpoint
class CameraHead(nn.Module):
"""
CameraHead predicts camera parameters from token representations using iterative refinement.
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
"""
def __init__(
self,
dim_in: int = 2048,
trunk_depth: int = 4,
pose_encoding_type: str = "absT_quaR_FoV",
num_heads: int = 16,
mlp_ratio: int = 4,
init_values: float = 0.01,
trans_act: str = "linear",
quat_act: str = "linear",
fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
enable_ulysses_cp=False,
):
super().__init__()
if pose_encoding_type == "absT_quaR_FoV":
self.target_dim = 9
else:
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
self.trans_act = trans_act
self.quat_act = quat_act
self.fl_act = fl_act
self.trunk_depth = trunk_depth
self.enable_ulysses_cp = enable_ulysses_cp
# Build the trunk using a sequence of transformer blocks.
self.trunk = nn.Sequential(
*[
Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
for _ in range(trunk_depth)
]
)
# Normalizations for camera token and trunk output.
self.token_norm = nn.LayerNorm(dim_in)
self.trunk_norm = nn.LayerNorm(dim_in)
# Learnable empty camera pose token.
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
self.embed_pose = nn.Linear(self.target_dim, dim_in)
# Module for producing modulation parameters: shift, scale, and a gate.
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
# Adaptive layer normalization without affine parameters.
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list:
"""
Forward pass to predict camera parameters.
Args:
aggregated_tokens_list (list): List of token tensors from the network;
the last tensor is used for prediction.
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
Returns:
list: A list of predicted camera encodings (post-activation) from each iteration.
"""
# Use tokens from the last block for camera prediction.
tokens = aggregated_tokens_list[-1]
# Extract the camera tokens
pose_tokens = tokens[:, :, 0]
pose_tokens = self.token_norm(pose_tokens)
pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
return pred_pose_enc_list
def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
"""
Iteratively refine camera pose predictions.
Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C] (one token per frame).
num_iterations (int): Number of refinement iterations.
Returns:
list: List of activated camera encodings from each iteration.
"""
        B, S, C = pose_tokens.shape  # one camera token per frame
pred_pose_enc = None
pred_pose_enc_list = []
for _ in range(num_iterations):
# Use a learned empty pose for the first iteration.
if pred_pose_enc is None:
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
else:
# Detach the previous prediction to avoid backprop through time.
pred_pose_enc = pred_pose_enc.detach()
module_input = self.embed_pose(pred_pose_enc)
# Generate modulation parameters and split them into shift, scale, and gate components.
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
# Adaptive layer normalization and modulation.
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
            # Apply the trunk blocks, forwarding the Ulysses context-parallel flag.
for block in self.trunk:
pose_tokens_modulated = block(pose_tokens_modulated, enable_ulysses_cp=self.enable_ulysses_cp)
# Compute the delta update for the pose encoding.
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
if pred_pose_enc is None:
pred_pose_enc = pred_pose_enc_delta
else:
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
# Apply final activation functions for translation, quaternion, and field-of-view.
activated_pose = activate_pose(
pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
)
pred_pose_enc_list.append(activated_pose)
return pred_pose_enc_list
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""
Modulate the input tensor using scaling and shifting parameters.
"""
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
return x * (1 + scale) + shift
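# Illustrative check (not from the original file): modulate() is the DiT-style adaLN
# update. With x = torch.ones(1, 1, 4) and shift = 0.5, scale = 2.0 broadcast to x's
# shape, modulate(x, shift, scale) returns 1 * (1 + 2.0) + 0.5 = 3.5 everywhere;
# poseLN_modulation learns such shift/scale/gate triples from the current pose estimate.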
class CameraCausalHead(nn.Module):
"""
    CameraCausalHead predicts camera parameters from token representations using iterative refinement.
    It applies a series of causal transformer blocks (the "trunk") to dedicated camera tokens and
    supports streaming inference via a per-iteration KV cache.
"""
def __init__(
self,
dim_in: int = 2048,
trunk_depth: int = 4,
pose_encoding_type: str = "absT_quaR_FoV",
num_heads: int = 16,
mlp_ratio: int = 4,
init_values: float = 0.01,
trans_act: str = "linear",
quat_act: str = "linear",
fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
num_iterations = 4,
elementwise_attn_output_gate: bool = False,
sliding_window_size: int = -1,
attend_to_scale_frames: bool = False,
num_random_frames: int = 0,
enable_ulysses_cp: bool = False,
attn_class: str = "flexflashattn_varlen",
# KV cache parameters
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
# 3D RoPE parameters
enable_3d_rope: bool = False,
max_frame_num: int = 1024,
rope_theta: float = 10000.0,
):
super().__init__()
if pose_encoding_type == "absT_quaR_FoV":
self.target_dim = 9
else:
raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
self.trans_act = trans_act
self.quat_act = quat_act
self.fl_act = fl_act
self.trunk_depth = trunk_depth
self.sliding_window_size = sliding_window_size
self.enable_ulysses_cp = enable_ulysses_cp
self.num_heads = num_heads
# 3D RoPE for temporal position encoding
self.enable_3d_rope = enable_3d_rope
if enable_3d_rope:
head_dim = dim_in // num_heads
# For camera head: each frame has 1 token (frame_seqlen=1)
# patch_size is (max_frames, h=1, w=1) for 3D RoPE
            # fhw_dim=None auto-calculates the split: h_dim = w_dim = 2 * (head_dim // 6), t_dim = remainder
self.rope3d = WanRotaryPosEmbed(
attention_head_dim=head_dim,
patch_size=(max_frame_num, 1, 1),
theta=rope_theta,
                fhw_dim=[40, 44, 44],  # explicit (t, h, w) split; pass None to auto-calculate
)
else:
self.rope3d = None
# Build the trunk using a sequence of transformer blocks.
self.trunk = nn.Sequential(
*[
                CameraBlock(
                    dim=dim_in,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                    elementwise_attn_output_gate=elementwise_attn_output_gate,
                    sliding_window_size=sliding_window_size,
                    attend_to_scale_frames=attend_to_scale_frames,
                    num_random_frames=num_random_frames,
                    kv_cache_sliding_window=kv_cache_sliding_window,
                    kv_cache_scale_frames=kv_cache_scale_frames,
                    kv_cache_cross_frame_special=kv_cache_cross_frame_special,
                    kv_cache_include_scale_frames=kv_cache_include_scale_frames,
                    kv_cache_camera_only=kv_cache_camera_only,
                )
for _ in range(trunk_depth)
]
)
# Normalizations for camera token and trunk output.
self.token_norm = nn.LayerNorm(dim_in)
self.trunk_norm = nn.LayerNorm(dim_in)
# Learnable empty camera pose token.
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
self.embed_pose = nn.Linear(self.target_dim, dim_in)
# Module for producing modulation parameters: shift, scale, and a gate.
self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
# Adaptive layer normalization without affine parameters.
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
self.num_iterations = num_iterations
self.kv_cache = None
self.pos_cache = None
self.frame_idx = 0
self.cp_size = 1
        # Get the context-parallel world size when Ulysses CP is enabled.
        if self.enable_ulysses_cp:
            from torchtitan.distributed.sequence_parallel import (
                get_ulysses_sequence_parallel_world_size,
            )
            self.cp_size = get_ulysses_sequence_parallel_world_size()
def clean_kv_cache(self):
del self.kv_cache
self.kv_cache = None
self.frame_idx = 0
    def forward(
        self,
        aggregated_tokens_list: list,
        mask=None,
        num_iterations: int = 4,
        causal_inference=False,
        num_frame_per_block=1,
        num_frame_for_scale=-1,
        sliding_window_size=None,
        **kwargs,
    ) -> list:
"""
Forward pass to predict camera parameters.
Args:
aggregated_tokens_list (list): List of token tensors from the network;
the last tensor is used for prediction.
num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
sliding_window_size (int, optional): Override the sliding window size for this forward pass.
If None, use the default self.sliding_window_size.
Returns:
list: A list of predicted camera encodings (post-activation) from each iteration.
"""
# Use passed sliding_window_size if provided, otherwise use default
effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
# Use tokens from the last block for camera prediction.
tokens = aggregated_tokens_list[-1]
# Extract the camera tokens
pose_tokens = tokens[:, :, 0]
pose_tokens = self.token_norm(pose_tokens)
if causal_inference:
if self.kv_cache is None:
self.kv_cache = []
for i in range(self.num_iterations):
self.kv_cache.append({"_skip_append": False})
for j in range(self.trunk_depth):
self.kv_cache[i][f"k_{j}"] = None
self.kv_cache[i][f"v_{j}"] = None
        pred_pose_enc_list = self.trunk_fn(
            pose_tokens,
            mask,
            num_iterations,
            num_frame_per_block=num_frame_per_block,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=effective_sliding_window_size,
        )
return pred_pose_enc_list
    def trunk_fn(
        self,
        pose_tokens: torch.Tensor,
        mask=None,
        num_iterations: int = 4,
        num_frame_per_block=1,
        num_frame_for_scale=-1,
        sliding_window_size=None,
    ) -> list:
"""
Iteratively refine camera pose predictions.
Args:
pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
num_iterations (int): Number of refinement iterations.
sliding_window_size (int, optional): Sliding window size to use.
Returns:
list: List of activated camera encodings from each iteration.
"""
B, S, C = pose_tokens.shape
pred_pose_enc = None
pred_pose_enc_list = []
# Check if this is the first call (processing scale frames)
# Scale frames should use batch mode attention for numerical consistency
is_scale_frames = (self.kv_cache is not None and self.frame_idx == 0)
# Generate 3D RoPE positions if enabled
pos3d = None
if self.rope3d is not None:
# For camera tokens: shape [B, S, C] where each frame has 1 token
# Position for frame f is (f, 0, 0) - temporal varies, spatial fixed
# In streaming mode with KV cache, use frame_idx to track global position
# Otherwise, generate positions from 0
if self.kv_cache is not None:
f_start = self.frame_idx
f_end = self.frame_idx + S
else:
f_start = 0
f_end = None # Will use ppf as frame count
pos3d = self.rope3d(
ppf=S * self.cp_size, # Total frames (with CP)
pph=1, # height = 1 (camera token)
ppw=1, # width = 1 (camera token)
patch_start_idx=0, # No special tokens before
device=pose_tokens.device,
f_start=f_start,
f_end=f_end,
) # Returns [1, 1, S*cp_size, head_dim//2] complex
for i in range(num_iterations):
# Use a learned empty pose for the first iteration.
if pred_pose_enc is None:
module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
else:
# Detach the previous prediction to avoid backprop through time.
pred_pose_enc = pred_pose_enc.detach()
module_input = self.embed_pose(pred_pose_enc)
# Generate modulation parameters and split them into shift, scale, and gate components.
shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
# Adaptive layer normalization and modulation.
pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
pose_tokens_modulated = pose_tokens_modulated + pose_tokens
for idx in range(self.trunk_depth):
                pose_tokens_modulated = self.trunk[idx](
                    pose_tokens_modulated,
                    pos=pos3d,
                    video_mask=mask,
                    num_frames=S * self.cp_size,
                    frame_seqlen=1,
                    kv_cache=self.kv_cache[i] if self.kv_cache is not None else None,
                    global_idx=idx,
                    num_frame_per_block=num_frame_per_block,
                    num_frame_for_scale=num_frame_for_scale,
                    sliding_window_size=sliding_window_size,
                    enable_ulysses_cp=self.enable_ulysses_cp,
                    enable_3d_rope=self.enable_3d_rope,
                    is_scale_frames=is_scale_frames,
                )
# Compute the delta update for the pose encoding.
pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
if pred_pose_enc is None:
pred_pose_enc = pred_pose_enc_delta
else:
pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
# Apply final activation functions for translation, quaternion, and field-of-view.
activated_pose = activate_pose(
pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
)
pred_pose_enc_list.append(activated_pose)
# Update frame_idx for streaming mode (KV cache)
if self.kv_cache is not None:
self.frame_idx += S
return pred_pose_enc_list
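# Streaming sketch (illustrative summary, not from the original file): during causal
# inference the head is called once per incoming block of frames with
# causal_inference=True. kv_cache holds one dict per refinement iteration (k_j/v_j per
# trunk layer), frame_idx advances by S after every call so the 3D RoPE positions stay
# globally aligned, and clean_kv_cache() resets both between sequences.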
class CameraDecoder(nn.Module):
def __init__(
self,
in_dim,
out_dim,
dec_embed_dim=512,
depth=5,
dec_num_heads=8,
mlp_ratio=4,
rope=None,
need_project=True,
use_checkpoint=False,
):
super().__init__()
self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
self.use_checkpoint = use_checkpoint
self.blocks = nn.ModuleList([
Block(
dim=dec_embed_dim,
num_heads=dec_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
proj_bias=True,
ffn_bias=True,
drop_path=0.0,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
ffn_layer=Mlp,
init_values=None,
qk_norm=False,
# attn_class=MemEffAttentionRope,
rope=rope
) for _ in range(depth)])
self.linear_out = nn.Linear(dec_embed_dim, out_dim)
def forward(self, hidden, xpos=None):
hidden = self.projects(hidden)
B, V, P, C = hidden.shape
        hidden = hidden.view(B * V, P, C)
for i, blk in enumerate(self.blocks):
if self.use_checkpoint and self.training:
hidden = checkpoint(blk, hidden, pos=xpos, use_reentrant=False)
else:
hidden = blk(hidden, pos=xpos)
out = self.linear_out(hidden).view(B, V, P, -1)
return out
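
For orientation, a minimal self-contained sketch of the iterative refinement loop that trunk_fn implements above; all names and sizes below (toy_trunk, C=64, 4 iterations) are illustrative and not part of this commit:

import torch
import torch.nn as nn

B, S, C, target_dim = 2, 8, 64, 9                 # batch, frames, width, pose encoding size
embed_pose = nn.Linear(target_dim, C)
empty_pose = nn.Parameter(torch.zeros(1, 1, target_dim))  # learned "empty" pose
toy_trunk = nn.Sequential(nn.Linear(C, C), nn.GELU(), nn.Linear(C, C))
pose_branch = nn.Linear(C, target_dim)

pose_tokens = torch.randn(B, S, C)                # one camera token per frame
pred = None
for _ in range(4):                                # num_iterations
    # First pass is conditioned on the empty pose; later passes on the detached estimate.
    cond = embed_pose(empty_pose.expand(B, S, -1) if pred is None else pred.detach())
    delta = pose_branch(toy_trunk(pose_tokens + cond))   # per-iteration correction
    pred = delta if pred is None else pred + delta
print(pred.shape)                                 # torch.Size([2, 8, 9])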

@@ -0,0 +1,679 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
import os
from typing import List, Dict, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from .head_act import activate_head
from .utils import create_uv_grid, position_grid_to_embed
class DPTHead(nn.Module):
"""
DPT Head for dense prediction tasks.
This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
(https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
backbone and produces dense predictions by fusing multi-scale features.
Args:
dim_in (int): Input dimension (channels).
patch_size (int, optional): Patch size. Default is 14.
output_dim (int, optional): Number of output channels. Default is 4.
activation (str, optional): Activation type. Default is "inv_log".
conf_activation (str, optional): Confidence activation type. Default is "expp1".
features (int, optional): Feature channels for intermediate representations. Default is 256.
out_channels (List[int], optional): Output channels for each intermediate layer.
intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
pos_embed (bool, optional): Whether to use positional embedding. Default is True.
feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
"""
def __init__(
self,
dim_in: int,
patch_size: int = 14,
output_dim: int = 4,
activation: str = "inv_log",
conf_activation: str = "expp1",
features: int = 256,
out_channels: List[int] = [256, 512, 1024, 1024],
intermediate_layer_idx: List[int] = [0, 1, 2, 3],
pos_embed: bool = True,
feature_only: bool = False,
down_ratio: int = 1,
) -> None:
super(DPTHead, self).__init__()
self.patch_size = patch_size
self.activation = activation
self.conf_activation = conf_activation
self.pos_embed = pos_embed
self.feature_only = feature_only
self.down_ratio = down_ratio
self.intermediate_layer_idx = intermediate_layer_idx
self.norm = nn.LayerNorm(dim_in)
# Projection layers for each output channel from tokens.
self.projects = nn.ModuleList(
[nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
)
# Resize layers for upsampling feature maps.
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
),
]
)
self.scratch = _make_scratch(out_channels, features, expand=False)
# Attach additional modules to scratch.
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features)
self.scratch.refinenet2 = _make_fusion_block(features)
self.scratch.refinenet3 = _make_fusion_block(features)
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
head_features_1 = features
head_features_2 = 32
if feature_only:
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
else:
self.scratch.output_conv1 = nn.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
)
conv2_in_channels = head_features_1 // 2
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
)
def forward(
self,
aggregated_tokens_list: List[torch.Tensor],
images: torch.Tensor,
patch_start_idx: int,
frames_chunk_size: int = 8,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Forward pass through the DPT head, supports processing by chunking frames.
Args:
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
patch_start_idx (int): Starting index for patch tokens in the token sequence.
Used to separate patch tokens from other tokens (e.g., camera or register tokens).
frames_chunk_size (int, optional): Number of frames to process in each chunk.
If None or larger than S, all frames are processed at once. Default: 8.
Returns:
Tensor or Tuple[Tensor, Tensor]:
- If feature_only=True: Feature maps with shape [B, S, C, H, W]
                - Otherwise: Tuple of (predictions, confidence) with shapes [B, S, H, W, output_dim - 1] and [B, S, H, W]
"""
B, _, _, H, W = images.shape
S = aggregated_tokens_list[0].shape[1]
# If frames_chunk_size is not specified or greater than S, process all frames at once
if frames_chunk_size is None or frames_chunk_size >= S:
return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
# Otherwise, process frames in chunks to manage memory usage
assert frames_chunk_size > 0
# Process frames in batches
all_preds = []
all_conf = []
for frames_start_idx in range(0, S, frames_chunk_size):
frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
# Process batch of frames
if self.feature_only:
chunk_output = self._forward_impl(
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
)
all_preds.append(chunk_output)
else:
chunk_preds, chunk_conf = self._forward_impl(
aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
)
all_preds.append(chunk_preds)
all_conf.append(chunk_conf)
# Concatenate results along the sequence dimension
if self.feature_only:
return torch.cat(all_preds, dim=1)
else:
return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
def _forward_impl(
self,
aggregated_tokens_list: List[torch.Tensor],
images: torch.Tensor,
patch_start_idx: int,
frames_start_idx: int = None,
frames_end_idx: int = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Implementation of the forward pass through the DPT head.
This method processes a specific chunk of frames from the sequence.
Args:
aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
images (Tensor): Input images with shape [B, S, 3, H, W].
patch_start_idx (int): Starting index for patch tokens.
frames_start_idx (int, optional): Starting index for frames to process.
frames_end_idx (int, optional): Ending index for frames to process.
Returns:
Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
"""
B, _, _, H, W = images.shape
patch_h, patch_w = H // self.patch_size, W // self.patch_size
out = []
dpt_idx = 0
for layer_idx in self.intermediate_layer_idx:
x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
if frames_start_idx is not None and frames_end_idx is not None:
x = x[:, frames_start_idx:frames_end_idx]
B, S = x.shape[0], x.shape[1]
x = x.reshape(B * S, -1, x.shape[-1])
x = self.norm(x)
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[dpt_idx](x)
if self.pos_embed:
x = self._apply_pos_embed(x, W, H)
x = self.resize_layers[dpt_idx](x)
out.append(x)
dpt_idx += 1
# Fuse features from multiple layers.
out = self.scratch_forward(out)
# Interpolate fused output to match target image resolution.
out = custom_interpolate(
out,
(int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
mode="bilinear",
align_corners=True,
)
if self.pos_embed:
out = self._apply_pos_embed(out, W, H)
if self.feature_only:
return out.view(B, S, *out.shape[1:])
out = self.scratch.output_conv2(out)
preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
preds = preds.view(B, S, *preds.shape[1:])
conf = conf.view(B, S, *conf.shape[1:])
return preds, conf
def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
"""
Apply positional embedding to tensor x.
"""
patch_w = x.shape[-1]
patch_h = x.shape[-2]
pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
pos_embed = pos_embed * ratio
pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
return x + pos_embed
def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
"""
Forward pass through the fusion blocks.
Args:
features (List[Tensor]): List of feature maps from different layers.
Returns:
Tensor: Fused feature map.
"""
layer_1, layer_2, layer_3, layer_4 = features
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
del layer_4_rn, layer_4
out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
del layer_3_rn, layer_3
out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
del layer_2_rn, layer_2
out = self.scratch.refinenet1(out, layer_1_rn)
del layer_1_rn, layer_1
out = self.scratch.output_conv1(out)
return out
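# Resolution walk-through (illustrative, assuming a 16x16 patch grid and patch_size=14):
# resize_layers deliver the four levels at 64x64, 32x32, 16x16 and 8x8; refinenet4..1
# then fuse coarse-to-fine at 16, 32, 64 and finally 128 (scale_factor=2), before
# custom_interpolate brings the fused map to patch_h * patch_size = 224 pixels.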
################################################################################
# Modules
################################################################################
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
return FeatureFusionBlock(
features,
nn.ReLU(inplace=True),
deconv=False,
bn=False,
expand=False,
align_corners=True,
size=size,
has_residual=has_residual,
groups=groups,
)
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn, groups=1):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = groups
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.norm1 = None
self.norm2 = None
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.norm1 is not None:
out = self.norm1(out)
out = self.activation(out)
out = self.conv2(out)
if self.norm2 is not None:
out = self.norm2(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Feature fusion block."""
def __init__(
self,
features,
activation,
deconv=False,
bn=False,
expand=False,
align_corners=True,
size=None,
has_residual=True,
groups=1,
):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups = groups
self.expand = expand
out_features = features
        if self.expand:
out_features = features // 2
self.out_conv = nn.Conv2d(
features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
)
if has_residual:
self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
self.has_residual = has_residual
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
self.skip_add = nn.quantized.FloatFunctional()
self.size = size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if self.has_residual:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
output = self.out_conv(output)
return output
def custom_interpolate(
x: torch.Tensor,
size: Tuple[int, int] = None,
scale_factor: float = None,
mode: str = "bilinear",
align_corners: bool = True,
) -> torch.Tensor:
"""
Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
"""
if size is None:
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
INT_MAX = 1610612736
    # Element count of the *output* tensor; chunk to stay below 32-bit indexing limits.
    output_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
    if output_elements > INT_MAX:
        chunks = torch.chunk(x, chunks=(output_elements // INT_MAX) + 1, dim=0)
interpolated_chunks = [
nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
]
x = torch.cat(interpolated_chunks, dim=0)
return x.contiguous()
else:
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
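# Worked example (illustrative): upsampling x of shape [8, 256, 518, 518] by 4x targets
# 8 * 256 * 2072 * 2072 ≈ 8.79e9 output elements > INT_MAX (1610612736), so the batch is
# split into (8792440832 // 1610612736) + 1 = 6 chunks along dim 0, each chunk is
# interpolated independently, and the results are concatenated back together.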
class DPTHead_Update(nn.Module):
def __init__(
self,
in_channels,
features=256,
use_bn=False,
out_channels=[256, 512, 1024, 1024],
use_clstoken=False
):
super(DPTHead_Update, self).__init__()
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList([
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
) for out_channel in out_channels
])
self.resize_layers = nn.ModuleList([
nn.ConvTranspose2d(
in_channels=out_channels[0],
out_channels=out_channels[0],
kernel_size=4,
stride=4,
padding=0),
nn.ConvTranspose2d(
in_channels=out_channels[1],
out_channels=out_channels[1],
kernel_size=2,
stride=2,
padding=0),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3],
out_channels=out_channels[3],
kernel_size=3,
stride=2,
padding=1)
])
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(
nn.Sequential(
nn.Linear(2 * in_channels, in_channels),
nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block_slam(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block_slam(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block_slam(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block_slam(features, use_bn)
head_features_1 = features
head_features_2 = 32
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w, return_intermediate=True):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
if return_intermediate:
return out, path_1, path_2, path_3, path_4
else:
out = self.scratch.output_conv2(out)
return out
def _make_fusion_block_slam(features, use_bn, size=None):
return FeatureFusionBlock_slam(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class FeatureFusionBlock_slam(nn.Module):
"""Feature fusion block.
"""
def __init__(
self,
features,
activation,
deconv=False,
bn=False,
expand=False,
align_corners=True,
size=None
):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_slam, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
        self.groups = 1
self.expand = expand
out_features = features
        if self.expand:
out_features = features // 2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
        self.size = size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
output = self.out_conv(output)
return output
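
A usage sketch for DPTHead (the module path and all shapes below are assumptions for illustration, not confirmed by this diff):

import torch
from lingbot_map.heads.dpt_head import DPTHead    # import path assumed

head = DPTHead(dim_in=2048, patch_size=14, output_dim=4)  # 3 point channels + 1 confidence
B, S, H, W = 1, 2, 224, 224
patch_start_idx = 5                               # e.g. camera + register tokens come first
n_patches = (H // 14) * (W // 14)                 # 16 * 16 = 256 patch tokens per frame
tokens = [torch.randn(B, S, patch_start_idx + n_patches, 2048) for _ in range(4)]
images = torch.rand(B, S, 3, H, W)
preds, conf = head(tokens, images, patch_start_idx)
print(preds.shape, conf.shape)                    # [1, 2, 224, 224, 3] and [1, 2, 224, 224]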

@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
"""
Activate pose parameters with specified activation functions.
Args:
pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
trans_act: Activation type for translation component
quat_act: Activation type for quaternion component
fl_act: Activation type for focal length component
Returns:
Activated pose parameters tensor
"""
T = pred_pose_enc[..., :3]
quat = pred_pose_enc[..., 3:7]
fl = pred_pose_enc[..., 7:] # or fov
T = base_pose_act(T, trans_act)
quat = base_pose_act(quat, quat_act)
fl = base_pose_act(fl, fl_act) # or fov
pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
return pred_pose_enc
def base_pose_act(pose_enc, act_type="linear"):
"""
Apply basic activation function to pose parameters.
Args:
pose_enc: Tensor containing encoded pose parameters
act_type: Activation type ("linear", "inv_log", "exp", "relu")
Returns:
Activated pose parameters
"""
if act_type == "linear":
return pose_enc
elif act_type == "inv_log":
return inverse_log_transform(pose_enc)
elif act_type == "exp":
return torch.exp(pose_enc)
elif act_type == "relu":
return F.relu(pose_enc)
else:
raise ValueError(f"Unknown act_type: {act_type}")
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
"""
Process network output to extract 3D points and confidence values.
Args:
out: Network output tensor (B, C, H, W)
activation: Activation type for 3D points
conf_activation: Activation type for confidence values
Returns:
Tuple of (3D points tensor, confidence tensor)
"""
    # Move channels to the last dimension: (B, C, H, W) -> (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)
# Split into xyz (first C-1 channels) and confidence (last channel)
xyz = fmap[:, :, :, :-1]
conf = fmap[:, :, :, -1]
if activation == "norm_exp":
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
xyz_normed = xyz / d
pts3d = xyz_normed * torch.expm1(d)
elif activation == "norm":
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
elif activation == "exp":
pts3d = torch.exp(xyz)
elif activation == "relu":
pts3d = F.relu(xyz)
elif activation == "inv_log":
pts3d = inverse_log_transform(xyz)
elif activation == "xy_inv_log":
xy, z = xyz.split([2, 1], dim=-1)
z = inverse_log_transform(z)
pts3d = torch.cat([xy * z, z], dim=-1)
elif activation == "sigmoid":
pts3d = torch.sigmoid(xyz)
elif activation == "linear":
pts3d = xyz
else:
raise ValueError(f"Unknown activation: {activation}")
if conf_activation == "expp1":
conf_out = 1 + conf.exp()
elif conf_activation == "expp0":
conf_out = conf.exp()
elif conf_activation == "sigmoid":
conf_out = torch.sigmoid(conf)
else:
raise ValueError(f"Unknown conf_activation: {conf_activation}")
return pts3d, conf_out
def inverse_log_transform(y):
"""
Apply inverse log transform: sign(y) * (exp(|y|) - 1)
Args:
y: Input tensor
Returns:
Transformed tensor
"""
return torch.sign(y) * (torch.expm1(torch.abs(y)))
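
A quick round-trip check (illustrative; log_transform is an assumed forward companion inferred from the docstring, and the import path is assumed):

import torch
from lingbot_map.heads.head_act import inverse_log_transform  # import path assumed

def log_transform(x):
    # Assumed encoding: sign(x) * log1p(|x|) compresses large magnitudes.
    return torch.sign(x) * torch.log1p(torch.abs(x))

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(torch.allclose(inverse_log_transform(log_transform(x)), x, atol=1e-3))  # True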

lingbot_map/heads/utils.py (new file, 109 lines)

@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
"""
Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
Args:
pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
embed_dim: Output channel dimension for embeddings
Returns:
Tensor of shape (H, W, embed_dim) with positional embeddings
"""
H, W, grid_dim = pos_grid.shape
assert grid_dim == 2
pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
# Process x and y coordinates separately
    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)  # (H*W, D/2)
    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)  # (H*W, D/2)
    # Combine and reshape
    emb = torch.cat([emb_x, emb_y], dim=-1)  # (H*W, D)
return emb.view(H, W, embed_dim) # [H, W, D]
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
"""
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
Args:
- embed_dim: The embedding dimension.
- pos: The position to generate the embedding from.
Returns:
- emb: The generated 1D positional embedding.
"""
assert embed_dim % 2 == 0
device = pos.device
omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
omega /= embed_dim / 2.0
omega = 1.0 / omega_0**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb.float()
# Inspired by https://github.com/microsoft/moge
def create_uv_grid(
width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
"""
    Create a normalized UV grid of shape (height, width, 2).
The grid spans horizontally and vertically according to an aspect ratio,
ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
corner is at (x_span, y_span), normalized by the diagonal of the plane.
Args:
width (int): Number of points horizontally.
height (int): Number of points vertically.
aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
dtype (torch.dtype, optional): Data type of the resulting tensor.
device (torch.device, optional): Device on which the tensor is created.
Returns:
        torch.Tensor: A (height, width, 2) tensor of UV coordinates.
"""
# Derive aspect ratio if not explicitly provided
if aspect_ratio is None:
aspect_ratio = float(width) / float(height)
# Compute normalized spans for X and Y
diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
span_x = aspect_ratio / diag_factor
span_y = 1.0 / diag_factor
# Establish the linspace boundaries
left_x = -span_x * (width - 1) / width
right_x = span_x * (width - 1) / width
top_y = -span_y * (height - 1) / height
bottom_y = span_y * (height - 1) / height
# Generate 1D coordinates
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
# Create 2D meshgrid (width x height) and stack into UV
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
uv_grid = torch.stack((uu, vv), dim=-1)
return uv_grid
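
A shape sketch for the two helpers above (illustrative; the import path is assumed):

import torch
from lingbot_map.heads.utils import create_uv_grid, position_grid_to_embed  # path assumed

grid = create_uv_grid(width=16, height=12)        # (12, 16, 2): rows index height, cols width
emb = position_grid_to_embed(grid, embed_dim=64)  # (12, 16, 64): sin/cos of u and v, concatenated
print(grid.shape, emb.shape)
print(float(grid.norm(dim=-1).max()))             # < 1.0: spans are normalized by the plane diagonal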