first commit
This commit is contained in:
0
lingbot_map/heads/__init__.py
Normal file
0
lingbot_map/heads/__init__.py
Normal file
454
lingbot_map/heads/camera_head.py
Normal file
454
lingbot_map/heads/camera_head.py
Normal file
@@ -0,0 +1,454 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from lingbot_map.layers import Mlp
|
||||
from lingbot_map.layers.block import Block
|
||||
from lingbot_map.layers.block import CameraBlock
|
||||
from lingbot_map.heads.head_act import activate_pose
|
||||
from lingbot_map.layers.rope import WanRotaryPosEmbed
|
||||
from functools import partial
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
class CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.

    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.

    The only supported encoding, "absT_quaR_FoV", is 9 values per camera
    (presumably 3 translation + 4 quaternion + 2 field-of-view — confirm
    against `activate_pose` in `lingbot_map.heads.head_act`).
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
        enable_ulysses_cp=False,
    ):
        """
        Args:
            dim_in: Channel dimension of the incoming camera tokens.
            trunk_depth: Number of transformer blocks in the refinement trunk.
            pose_encoding_type: Camera encoding scheme; only "absT_quaR_FoV" is supported.
            num_heads: Attention heads per trunk block.
            mlp_ratio: MLP expansion ratio inside each trunk block.
            init_values: LayerScale initial value passed to each Block.
            trans_act: Activation name applied to the translation component.
            quat_act: Activation name applied to the quaternion component.
            fl_act: Activation name applied to the focal/FoV component.
            enable_ulysses_cp: Forwarded to each trunk block at call time
                (Ulysses sequence/context parallelism toggle).

        Raises:
            ValueError: If `pose_encoding_type` is not "absT_quaR_FoV".
        """
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth

        self.enable_ulysses_cp = enable_ulysses_cp

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token (used as the iteration-0 pose input).
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens (token index 0 of every frame).
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine camera pose predictions.

        Each iteration re-embeds the (detached) previous pose estimate, modulates
        the camera tokens with adaLN shift/scale/gate derived from it, runs the
        trunk, and adds a predicted delta to the pose encoding.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape  # S is expected to be 1.
        pred_pose_enc = None
        pred_pose_enc_list = []

        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation, plus a residual connection
            # back to the un-modulated tokens.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            # Apply trunk blocks with enable_ulysses_cp
            for block in self.trunk:
                pose_tokens_modulated = block(pose_tokens_modulated, enable_ulysses_cp=self.enable_ulysses_cp)
            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            # Accumulate the delta on top of the previous (detached) estimate.
            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)

        return pred_pose_enc_list
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Apply an adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    All three tensors are combined elementwise (broadcasting applies).
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    gain = scale + 1
    return shift + x * gain
|
||||
|
||||
|
||||
class CameraCausalHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.

    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.

    Causal variant: the trunk uses `CameraBlock`s that support sliding-window /
    causal attention, an optional per-iteration KV cache for streaming
    inference, and optional 3D RoPE over the frame axis.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
        num_iterations = 4,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        enable_ulysses_cp: bool = False,
        # NOTE(review): `attn_class` is accepted but never used in this class body —
        # confirm whether it should be forwarded to CameraBlock.
        attn_class: str = "flexflashattn_varlen",
        # KV cache parameters (all forwarded to each CameraBlock)
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # 3D RoPE parameters
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        rope_theta: float = 10000.0,
    ):
        """
        Raises:
            ValueError: If `pose_encoding_type` is not "absT_quaR_FoV".
        """
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.sliding_window_size = sliding_window_size
        self.enable_ulysses_cp = enable_ulysses_cp
        self.num_heads = num_heads

        # 3D RoPE for temporal position encoding
        self.enable_3d_rope = enable_3d_rope
        if enable_3d_rope:
            head_dim = dim_in // num_heads
            # For camera head: each frame has 1 token (frame_seqlen=1)
            # patch_size is (max_frames, h=1, w=1) for 3D RoPE
            # fhw_dim=None lets auto-calculation: h_dim=w_dim=2*(head_dim//6), t_dim=remainder
            self.rope3d = WanRotaryPosEmbed(
                attention_head_dim=head_dim,
                patch_size=(max_frame_num, 1, 1),
                theta=rope_theta,
                # NOTE(review): an explicit (t, h, w) split is passed here, which
                # contradicts the "auto-calculate" note above — confirm intent.
                fhw_dim=[40, 44, 44],
            )
        else:
            self.rope3d = None

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                CameraBlock(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values, elementwise_attn_output_gate=elementwise_attn_output_gate, sliding_window_size=sliding_window_size, attend_to_scale_frames=attend_to_scale_frames, num_random_frames=num_random_frames, kv_cache_sliding_window=kv_cache_sliding_window, kv_cache_scale_frames=kv_cache_scale_frames, kv_cache_cross_frame_special=kv_cache_cross_frame_special, kv_cache_include_scale_frames=kv_cache_include_scale_frames, kv_cache_camera_only=kv_cache_camera_only)
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token (iteration-0 pose input).
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

        # Default iteration count; also sizes the per-iteration KV cache list in forward().
        self.num_iterations = num_iterations

        # Streaming-inference state (not parameters): per-iteration KV cache,
        # position cache, and the global index of the next incoming frame.
        self.kv_cache = None
        self.pos_cache = None
        self.frame_idx = 0
        self.cp_size = 1

        ## Get cp size if enable ulysses cp
        if self.enable_ulysses_cp:
            # NOTE(review): only the world-size getter is used below; the other
            # two imported names appear unused — confirm.
            from torchtitan.distributed.sequence_parallel import (
                init_sequence_parallel,
                get_ulysses_sequence_parallel_rank,
                get_ulysses_sequence_parallel_world_size,
            )

            self.cp_size = get_ulysses_sequence_parallel_world_size()

    def clean_kv_cache(self):
        """Drop the streaming KV cache and reset the global frame counter."""
        del self.kv_cache
        self.kv_cache = None
        self.frame_idx = 0

    def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            mask: Optional video mask forwarded to the trunk blocks.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
                NOTE(review): the KV cache is sized from `self.num_iterations`, so passing a
                larger value here with `causal_inference=True` would index past the cache —
                confirm the two are always kept equal.
            causal_inference: If True, lazily allocate a per-iteration, per-layer KV cache.
            num_frame_per_block: Forwarded to the trunk blocks.
            num_frame_for_scale: Forwarded to the trunk blocks.
            sliding_window_size (int, optional): Override the sliding window size for this forward pass.
                If None, use the default self.sliding_window_size.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use passed sliding_window_size if provided, otherwise use default
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size

        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens (token index 0 of every frame).
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        if causal_inference:
            # Lazily build one cache dict per refinement iteration, each holding
            # per-trunk-layer key/value slots.
            if self.kv_cache is None:
                self.kv_cache = []
                for i in range(self.num_iterations):
                    self.kv_cache.append({"_skip_append": False})
                    for j in range(self.trunk_depth):
                        self.kv_cache[i][f"k_{j}"] = None
                        self.kv_cache[i][f"v_{j}"] = None

        pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, mask=None, num_iterations: int=4, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            mask: Optional video mask forwarded to each trunk block.
            num_iterations (int): Number of refinement iterations.
            num_frame_per_block: Forwarded to each trunk block.
            num_frame_for_scale: Forwarded to each trunk block.
            sliding_window_size (int, optional): Sliding window size to use.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []

        # Check if this is the first call (processing scale frames)
        # Scale frames should use batch mode attention for numerical consistency
        is_scale_frames = (self.kv_cache is not None and self.frame_idx == 0)

        # Generate 3D RoPE positions if enabled
        pos3d = None
        if self.rope3d is not None:
            # For camera tokens: shape [B, S, C] where each frame has 1 token
            # Position for frame f is (f, 0, 0) - temporal varies, spatial fixed

            # In streaming mode with KV cache, use frame_idx to track global position
            # Otherwise, generate positions from 0
            if self.kv_cache is not None:
                f_start = self.frame_idx
                f_end = self.frame_idx + S
            else:
                f_start = 0
                f_end = None  # Will use ppf as frame count

            pos3d = self.rope3d(
                ppf=S * self.cp_size,  # Total frames (with CP)
                pph=1,  # height = 1 (camera token)
                ppw=1,  # width = 1 (camera token)
                patch_start_idx=0,  # No special tokens before
                device=pose_tokens.device,
                f_start=f_start,
                f_end=f_end,
            )  # Returns [1, 1, S*cp_size, head_dim//2] complex

        for i in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation, plus a residual connection.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            # Run the trunk, handing each block iteration `i`'s KV cache dict (if any).
            for idx in range(self.trunk_depth):
                pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, pos=pos3d, video_mask=mask, num_frames=S*self.cp_size, frame_seqlen=1, kv_cache=self.kv_cache[i] if self.kv_cache is not None else None, global_idx=idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=sliding_window_size, enable_ulysses_cp=self.enable_ulysses_cp, enable_3d_rope=self.enable_3d_rope, is_scale_frames=is_scale_frames)
            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)

        # Update frame_idx for streaming mode (KV cache)
        if self.kv_cache is not None:
            self.frame_idx += S

        return pred_pose_enc_list
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    adaLN affine modulation, ``x * (1 + scale) + shift`` (broadcasting elementwise).

    Note: this duplicates the `modulate` defined earlier in this module.
    """
    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    return torch.add(shift, x * (scale + 1))
|
||||
|
||||
|
||||
|
||||
|
||||
class CameraDecoder(nn.Module):
    """
    Transformer decoder over per-view token sequences.

    Optionally projects the input to the decoder width, runs `depth`
    transformer blocks (with optional gradient checkpointing during
    training), and maps the result to `out_dim` channels.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()

        # Input projection into the decoder width (identity when disabled).
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint

        def build_block():
            # One transformer block at the decoder width.
            return Block(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                rope=rope,
            )

        self.blocks = nn.ModuleList([build_block() for _ in range(depth)])

        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
        """
        Args:
            hidden: Tokens of shape [B, V, P, C_in] (batch, views, patches, channels).
            xpos: Optional positional information forwarded to each block's `pos` argument.

        Returns:
            Tensor of shape [B, V, P, out_dim].
        """
        tokens = self.projects(hidden)
        B, V, P, C = tokens.shape
        # Fold the view axis into the batch so blocks see [B*V, P, C].
        tokens = tokens.view(B * V, P, C)

        ckpt_active = self.use_checkpoint and self.training
        for blk in self.blocks:
            if ckpt_active:
                # Recompute activations in backward to save memory.
                tokens = checkpoint(blk, tokens, pos=xpos, use_reentrant=False)
            else:
                tokens = blk(tokens, pos=xpos)

        return self.linear_out(tokens).view(B, V, P, -1)
|
||||
679
lingbot_map/heads/dpt_head.py
Normal file
679
lingbot_map/heads/dpt_head.py
Normal file
@@ -0,0 +1,679 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
||||
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .head_act import activate_head
|
||||
from .utils import create_uv_grid, position_grid_to_embed
|
||||
|
||||
|
||||
class DPTHead(nn.Module):
    """
    DPT Head for dense prediction tasks.

    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
    backbone and produces dense predictions by fusing multi-scale features.

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        # NOTE(review): mutable list defaults are shared across instances; harmless
        # only while never mutated in place — confirm no caller mutates them.
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [0, 1, 2, 3],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
        )

        # Resize layers for upsampling feature maps: 4x up, 2x up, identity, 2x down —
        # one per intermediate layer, coarsest-to-finest pyramid.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(out_channels, features, expand=False)

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # The deepest fusion block has no lateral residual input.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            # Feature mode: single conv, no prediction/confidence head.
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            conv2_in_channels = head_features_1 // 2

            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or larger than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: Feature maps with shape [B, S, C, H, W]
                - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
        """
        B, _, _, H, W = images.shape

        S = aggregated_tokens_list[0].shape[1]

        # If frames_chunk_size is not specified or greater than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage
        assert frames_chunk_size > 0

        # Process frames in batches
        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            # Process batch of frames
            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: int = None,
        frames_end_idx: int = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        This method processes a specific chunk of frames from the sequence.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """

        B, _, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []
        dpt_idx = 0

        for layer_idx in self.intermediate_layer_idx:
            # Drop the special (camera/register) tokens, keeping only patch tokens.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]

            B, S = x.shape[0], x.shape[1]

            # Fold frames into the batch: [B*S, num_patches, C].
            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)

            # Tokens -> 2D feature map: [B*S, C, patch_h, patch_w].
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)
            dpt_idx += 1

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.view(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)

        # Unfold frames back out of the batch dimension.
        preds = preds.view(B, S, *preds.shape[1:])
        conf = conf.view(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.

        Builds a UV-grid positional embedding matching x's spatial size and
        channel count, scales it by `ratio`, and adds it to x.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        # [h, w, C] -> [1, C, h, w], broadcast across the batch.
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Args:
            features (List[Tensor]): List of feature maps from different layers.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down refinement, deepest first; free each level as soon as it is consumed.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
|
||||
|
||||
|
||||
################################################################################
|
||||
# Modules
|
||||
################################################################################
|
||||
|
||||
|
||||
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """Build a FeatureFusionBlock with this head's standard configuration.

    Uses in-place ReLU, no deconv/batch-norm/expansion, and align_corners
    interpolation; only `size`, `has_residual`, and `groups` vary per call.
    """
    fixed = dict(deconv=False, bn=False, expand=False, align_corners=True)
    return FeatureFusionBlock(
        features,
        nn.ReLU(inplace=True),
        size=size,
        has_residual=has_residual,
        groups=groups,
        **fixed,
    )
|
||||
|
||||
|
||||
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
|
||||
scratch = nn.Module()
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape
|
||||
out_shape3 = out_shape
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape
|
||||
|
||||
if expand:
|
||||
out_shape1 = out_shape
|
||||
out_shape2 = out_shape * 2
|
||||
out_shape3 = out_shape * 4
|
||||
if len(in_shape) >= 4:
|
||||
out_shape4 = out_shape * 8
|
||||
|
||||
scratch.layer1_rn = nn.Conv2d(
|
||||
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer2_rn = nn.Conv2d(
|
||||
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
scratch.layer3_rn = nn.Conv2d(
|
||||
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
if len(in_shape) >= 4:
|
||||
scratch.layer4_rn = nn.Conv2d(
|
||||
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||
)
|
||||
return scratch
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
    """Two 3x3 convs with pre-activation and an identity skip connection."""

    def __init__(self, features, activation, bn, groups=1):
        """Init.

        Args:
            features (int): channel count kept throughout the unit
            activation: activation module applied before each conv
            bn (bool): stored for API compatibility; no norm layers are created
            groups (int): conv groups for both convolutions
        """
        super().__init__()

        self.bn = bn
        self.groups = groups
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=groups)
        self.conv1 = nn.Conv2d(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2d(features, features, **conv_kwargs)

        # Norm layers are intentionally left unset; kept as attributes so the
        # forward pass keeps its structure if they are ever enabled.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Apply activation->conv (twice, with optional norms) and add the input back.

        Args:
            x (tensor): input feature map

        Returns:
            tensor: refined map of the same shape as the input
        """
        residual = x
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            x = conv(self.activation(x))
            if norm is not None:
                x = norm(x)
        return self.skip_add.add(x, residual)
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
    """Fuse an upstream decoder feature with a skip feature, then upsample."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Init.

        Args:
            features (int): number of input channels
            activation: activation module forwarded to the residual units
            deconv (bool): stored only; no transposed conv is created here
            bn (bool): forwarded to the residual units
            expand (bool): halve the channel count in the output projection
            align_corners (bool): interpolation alignment flag
            size: fixed output size used when forward() gets no explicit size
            has_residual (bool): whether a skip input is fused in
            groups (int): conv groups for all convolutions
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand
        self.has_residual = has_residual
        self.size = size

        out_features = features // 2 if self.expand == True else features
        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs, size=None):
        """Fuse inputs, refine, upsample, and project.

        Args:
            *xs: one or two feature maps; xs[1] (if present and has_residual)
                is run through resConfUnit1 and added to xs[0].
            size: explicit spatial size for the upsampling step.

        Returns:
            tensor: output
        """
        fused = xs[0]

        if self.has_residual:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        # Pick the interpolation target: explicit size wins, then the
        # configured size, otherwise plain 2x upsampling.
        if size is not None:
            modifier = {"size": size}
        elif self.size is not None:
            modifier = {"size": self.size}
        else:
            modifier = {"scale_factor": 2}

        fused = custom_interpolate(fused, **modifier, mode="bilinear", align_corners=self.align_corners)
        return self.out_conv(fused)
|
||||
|
||||
|
||||
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Interpolate while sidestepping the INT_MAX element limit of
    nn.functional.interpolate by splitting oversized inputs into chunks
    along the batch dimension.

    Args:
        x: input tensor of shape (B, C, H, W).
        size: target (H, W); derived from scale_factor when omitted.
        scale_factor: used only when size is None.
        mode: interpolation mode.
        align_corners: interpolation alignment flag.
    """
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736

    # Elements the resized output would hold if done in a single call.
    total_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if total_elements <= INT_MAX:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    n_chunks = (total_elements // INT_MAX) + 1
    resized = [
        nn.functional.interpolate(part, size=size, mode=mode, align_corners=align_corners)
        for part in torch.chunk(x, chunks=n_chunks, dim=0)
    ]
    return torch.cat(resized, dim=0).contiguous()
|
||||
|
||||
class DPTHead_Update(nn.Module):
    """DPT-style decoder head.

    Projects four transformer token maps onto a feature pyramid, fuses them
    coarse-to-fine through refinenet blocks, and emits a single-channel dense
    map (with the intermediate fusion paths also returned by default).
    """

    def __init__(
        self,
        in_channels,
        features=256,
        use_bn=False,
        out_channels=[256, 512, 1024, 1024],
        use_clstoken=False
    ):
        # NOTE(review): out_channels uses a mutable default; it is only read
        # here, so this is safe in practice but fragile.
        super(DPTHead_Update, self).__init__()

        self.use_clstoken = use_clstoken

        # 1x1 convs mapping the backbone token width to each pyramid stage.
        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in out_channels
        ])

        # Rescale each stage relative to the token grid: 4x up, 2x up,
        # identity, 2x down.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])

        # When a CLS token is available, concatenate it to every patch token
        # and project back to the original width before the stage projections.
        if use_clstoken:
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))

        # 3x3 convs bringing every pyramid stage to a common `features` width.
        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        # Coarse-to-fine fusion blocks; use_bn is forwarded to their
        # residual units (see _make_fusion_block_slam).
        self.scratch.refinenet1 = _make_fusion_block_slam(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block_slam(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block_slam(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block_slam(features, use_bn)

        head_features_1 = features
        head_features_2 = 32

        # Output head: a channel-halving conv, then (only when the caller
        # asks for the final map) a small conv stack down to 1 channel.
        self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True),
            nn.Identity(),
        )

    def forward(self, out_features, patch_h, patch_w, return_intermediate=True):
        """Decode four transformer feature maps into a dense prediction.

        Args:
            out_features: sequence of 4 token tensors, each (B, N, C); when
                use_clstoken is True each entry is a (tokens, cls_token) pair.
            patch_h: token grid height (N == patch_h * patch_w assumed).
            patch_w: token grid width.
            return_intermediate: if True, return the pre-head map together
                with the four fusion paths; otherwise apply output_conv2 and
                return only the final 1-channel map.

        Returns:
            Either (out, path_1, path_2, path_3, path_4) or the final map.
        """
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                # Broadcast the CLS token over all patch tokens and fuse it in.
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))

            # (B, N, C) -> (B, C, patch_h, patch_w)
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[i](x)
            x = self.resize_layers[i](x)

            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        # Bring every stage to the shared decoder width.
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Coarse-to-fine fusion; each step upsamples to the next stage's size.
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
        out = self.scratch.output_conv1(path_1)
        # Upsample to 14x the token grid (presumably the ViT patch size — TODO confirm).
        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        if return_intermediate:
            return out, path_1, path_2, path_3, path_4
        else:
            out = self.scratch.output_conv2(out)
            return out
|
||||
|
||||
def _make_fusion_block_slam(features, use_bn, size=None):
    """Build a FeatureFusionBlock_slam with this head's fixed settings.

    Uses a non-inplace ReLU, no deconvolution, no channel expansion, and
    align_corners=True; only the bn flag and optional fixed size vary.
    """
    shared = dict(deconv=False, expand=False, align_corners=True)
    return FeatureFusionBlock_slam(
        features,
        nn.ReLU(False),
        bn=use_bn,
        size=size,
        **shared,
    )
|
||||
|
||||
|
||||
class FeatureFusionBlock_slam(nn.Module):
    """Feature fusion block used by the SLAM-style DPT head.

    Like FeatureFusionBlock, but with conv groups fixed at 1 and with the
    skip branch enabled implicitly by passing a second positional input.
    """

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None
    ):
        """Init.

        Args:
            features (int): number of input channels
            activation: activation module forwarded to the residual units
            deconv (bool): stored only; no transposed conv is created here
            bn (bool): forwarded to the residual units
            expand (bool): halve the channel count in the output projection
            align_corners (bool): interpolation alignment flag
            size: fixed output size used when forward() gets no explicit size
        """
        super(FeatureFusionBlock_slam, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1
        self.expand = expand
        self.size = size

        out_features = features // 2 if self.expand == True else features
        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs, size=None):
        """Fuse one or two feature maps, refine, upsample, and project.

        Args:
            *xs: one or two feature maps; a second input is treated as the
                skip connection and fused through resConfUnit1.
            size: explicit spatial size for the upsampling step.

        Returns:
            tensor: output
        """
        fused = xs[0]

        if len(xs) == 2:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        # Explicit size wins, then the configured size, else 2x upsampling.
        if size is not None:
            modifier = {"size": size}
        elif self.size is not None:
            modifier = {"size": self.size}
        else:
            modifier = {"scale_factor": 2}

        fused = nn.functional.interpolate(fused, **modifier, mode="bilinear", align_corners=self.align_corners)

        return self.out_conv(fused)
|
||||
125
lingbot_map/heads/head_act.py
Normal file
125
lingbot_map/heads/head_act.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """
    Apply per-component activations to an encoded camera pose.

    The last dim is laid out as [translation (3), quaternion (4), focal
    length / FoV (rest)]; each slice gets its own activation.

    Args:
        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
        trans_act: Activation type for translation component
        quat_act: Activation type for quaternion component
        fl_act: Activation type for focal length component

    Returns:
        Activated pose parameters tensor of the same shape.
    """
    components = (
        (pred_pose_enc[..., :3], trans_act),
        (pred_pose_enc[..., 3:7], quat_act),
        (pred_pose_enc[..., 7:], fl_act),  # focal length or fov
    )
    activated = [base_pose_act(part, act) for part, act in components]
    return torch.cat(activated, dim=-1)
|
||||
|
||||
|
||||
def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply one element-wise activation to pose parameters.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: Activation type ("linear", "inv_log", "exp", "relu")

    Returns:
        Activated pose parameters

    Raises:
        ValueError: if act_type is not one of the supported names.
    """
    # Lambdas defer name resolution so e.g. "linear" works even before
    # inverse_log_transform is needed.
    dispatch = {
        "linear": lambda t: t,
        "inv_log": lambda t: inverse_log_transform(t),
        "exp": torch.exp,
        "relu": F.relu,
    }
    act = dispatch.get(act_type)
    if act is None:
        raise ValueError(f"Unknown act_type: {act_type}")
    return act(pose_enc)
|
||||
|
||||
|
||||
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Convert raw network output into 3D points plus per-pixel confidence.

    Args:
        out: Network output tensor (B, C, H, W); the last channel is the
            confidence logit, the remaining C-1 channels encode the points.
        activation: Activation type for 3D points
        conf_activation: Activation type for confidence values

    Returns:
        Tuple of (3D points tensor (B, H, W, C-1), confidence tensor (B, H, W))

    Raises:
        ValueError: for an unknown activation or conf_activation name.
    """
    # Channels-last layout: (B, C, H, W) -> (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)

    # Last channel is confidence, the rest encode the 3D point.
    xyz, conf = fmap[:, :, :, :-1], fmap[:, :, :, -1]

    if activation == "norm_exp":
        # Keep the direction of xyz; remap its length through expm1.
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        pts3d = (xyz / d) * torch.expm1(d)
    elif activation == "norm":
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        # Decode depth first, then scale the (x, y) ray by it.
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        conf_out = conf.exp() + 1
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out
|
||||
|
||||
|
||||
def inverse_log_transform(y):
    """
    Invert a signed log transform: sign(y) * (exp(|y|) - 1).

    Maps 0 -> 0 and is odd-symmetric, so the input's sign is preserved
    while magnitudes are expanded exponentially.

    Args:
        y: Input tensor

    Returns:
        Transformed tensor of the same shape.
    """
    magnitude = torch.expm1(y.abs())
    return y.sign() * magnitude
|
||||
109
lingbot_map/heads/utils.py
Normal file
109
lingbot_map/heads/utils.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert a 2D position grid (H, W, 2) into sinusoidal embeddings (H, W, embed_dim).

    Each axis gets half the channels: x coordinates fill the first
    embed_dim // 2 channels and y coordinates fill the rest.

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension for embeddings
        omega_0: Base frequency of the sinusoids

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    flat = pos_grid.reshape(-1, grid_dim)  # (H*W, 2)

    half = embed_dim // 2
    per_axis = [
        make_sincos_pos_embed(half, flat[:, axis], omega_0=omega_0)  # (H*W, half)
        for axis in (0, 1)
    ]

    # Concatenate x- and y-embeddings, then restore the spatial layout.
    return torch.cat(per_axis, dim=-1).view(H, W, embed_dim)
|
||||
|
||||
|
||||
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    Build a 1D sine/cosine positional embedding for the given positions.

    Args:
        embed_dim: Embedding dimension; must be even (half sine, half cosine).
        pos: Positions to embed; flattened to shape (M,).
        omega_0: Base of the geometric frequency schedule.

    Returns:
        Float tensor of shape (M, embed_dim): [sin | cos] concatenated.
    """
    assert embed_dim % 2 == 0
    device = pos.device
    # MPS lacks float64 support, so fall back to float32 there.
    freq_dtype = torch.float32 if device.type == "mps" else torch.double
    omega = torch.arange(embed_dim // 2, dtype=freq_dtype, device=device)
    omega /= embed_dim / 2.0
    omega = 1.0 / omega_0**omega  # (D/2,) geometric frequencies

    # Outer product of positions and frequencies: (M, D/2).
    angles = torch.einsum("m,d->md", pos.reshape(-1), omega)

    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (M, D)
    return emb.float()
|
||||
|
||||
|
||||
# Inspired by https://github.com/microsoft/moge
|
||||
|
||||
|
||||
def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV coordinate grid with 2 channels in the last dim.

    Coordinates span symmetric ranges along x and y whose extents are the
    plane's width/height normalized by its diagonal, sampled at pixel
    centers (hence the (n - 1) / n shrink of each span). The grid is built
    with meshgrid "xy" indexing over the x and y samples.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: UV coordinates with last-dim channels (u, v).
    """
    # Derive aspect ratio if not explicitly provided.
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Normalize both spans by the plane's diagonal.
    diag = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag
    span_y = 1.0 / diag

    # Pixel-center bounds along each axis.
    shrink_x = (width - 1) / width
    shrink_y = (height - 1) / height
    x_coords = torch.linspace(-span_x * shrink_x, span_x * shrink_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(-span_y * shrink_y, span_y * shrink_y, steps=height, dtype=dtype, device=device)

    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    return torch.stack((uu, vv), dim=-1)
|
||||
Reference in New Issue
Block a user