first commit
This commit is contained in:
125
lingbot_map/heads/head_act.py
Normal file
125
lingbot_map/heads/head_act.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """
    Apply a separate activation to each component of an encoded pose.

    The last dimension of the input is sliced into three groups, each group is
    passed through ``base_pose_act`` with its own activation name, and the
    results are concatenated back along the last dimension in the same order.

    Args:
        pred_pose_enc: Tensor containing encoded pose parameters laid out along
            the last dimension as [translation, quaternion, focal length].
            Indices ``0:3`` are treated as translation, ``3:7`` as quaternion,
            and everything from index 7 on as focal length (or field of view).
        trans_act: Activation type for the translation component
        quat_act: Activation type for the quaternion component
        fl_act: Activation type for the focal length (or fov) component

    Returns:
        Activated pose parameters tensor

    Raises:
        ValueError: Propagated from ``base_pose_act`` when an activation type
            is not recognized.
    """
    # Slice every component up front, pairing each with its activation name.
    components = (
        (pred_pose_enc[..., :3], trans_act),   # translation
        (pred_pose_enc[..., 3:7], quat_act),   # quaternion
        (pred_pose_enc[..., 7:], fl_act),      # focal length, or fov
    )

    activated = [base_pose_act(chunk, act_name) for chunk, act_name in components]

    return torch.cat(activated, dim=-1)
|
||||
|
||||
|
||||
def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply one named element-wise activation to pose parameters.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: Activation type. ``"linear"`` returns the input unchanged,
            ``"inv_log"`` applies ``inverse_log_transform``, ``"exp"`` applies
            the exponential, and ``"relu"`` applies ReLU.

    Returns:
        Activated pose parameters

    Raises:
        ValueError: If ``act_type`` is not one of the names listed above.
    """
    # Identity: hand the very same tensor back, no copy is made.
    if act_type == "linear":
        return pose_enc

    if act_type == "inv_log":
        return inverse_log_transform(pose_enc)

    if act_type == "exp":
        return pose_enc.exp()

    if act_type == "relu":
        return F.relu(pose_enc)

    raise ValueError(f"Unknown act_type: {act_type}")
|
||||
|
||||
|
||||
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Process network output to extract 3D points and confidence values.

    The channel dimension is moved last, the final channel is taken as the raw
    confidence, and the remaining channels are taken as the raw point values.
    Each part is then passed through its own activation.

    Args:
        out: Network output tensor (B, C, H, W)
        activation: Activation type for 3D points. One of:
            "norm_exp"   - unit direction scaled by expm1 of the vector norm
            "norm"       - unit direction (vector divided by its norm)
            "exp"        - element-wise exponential
            "relu"       - element-wise ReLU
            "inv_log"    - element-wise ``inverse_log_transform``
            "xy_inv_log" - z = inverse_log_transform(z), then (x * z, y * z, z);
                           requires exactly 3 point channels
            "sigmoid"    - element-wise sigmoid
            "linear"     - identity
        conf_activation: Activation type for confidence values. One of
            "expp1" (1 + exp), "expp0" (exp) or "sigmoid".

    Returns:
        Tuple of (3D points tensor, confidence tensor) with shapes
        (B, H, W, C-1) and (B, H, W) respectively.

    Raises:
        ValueError: If ``activation`` or ``conf_activation`` is not recognized.
    """
    # Lower bound on the vector norm before dividing by it. Without it an
    # all-zero point vector yields 0 / 0 = NaN.
    min_norm = 1e-8

    # Move channels from dim 1 to the last dimension: (B, C, H, W) => (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)

    # Split into xyz (first C-1 channels) and confidence (last channel)
    xyz = fmap[..., :-1]
    conf = fmap[..., -1]

    if activation == "norm_exp":
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=min_norm)
        xyz_normed = xyz / d
        # expm1 keeps a zero-length input mapped to the origin.
        pts3d = xyz_normed * torch.expm1(d)
    elif activation == "norm":
        # Clamp exactly like the "norm_exp" branch so zero vectors map to zero
        # instead of NaN.
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True).clamp(min=min_norm)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        # The split sizes [2, 1] mean this branch only accepts 3 point channels.
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        conf_out = 1 + conf.exp()
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out
|
||||
|
||||
|
||||
def inverse_log_transform(y):
    """
    Apply inverse log transform: sign(y) * (exp(|y|) - 1)

    The magnitude is computed with ``expm1`` on the absolute value and the
    sign of the input is then re-applied, so zero maps to zero and the
    function is odd-symmetric.

    Args:
        y: Input tensor

    Returns:
        Transformed tensor
    """
    magnitude = torch.expm1(y.abs())
    return y.sign() * magnitude
|
||||
Reference in New Issue
Block a user