first commit

Author: LinZhuoChen
Date: 2026-04-16 09:51:30 +08:00
Commit: f9b3ae457a
44 changed files with 11994 additions and 0 deletions

@@ -0,0 +1,774 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import numpy as np
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Rotation
try:
from lietorch import SE3, Sim3
except ImportError:
SE3 = Sim3 = None
import torch.nn.functional as F
try:
from lingbot_map.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion
except ImportError:
apply_distortion = iterative_undistortion = single_undistortion = None
def unproject_depth_map_to_point_map(
depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
) -> np.ndarray:
"""
Unproject a batch of depth maps to 3D world coordinates.
Args:
depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
Returns:
np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
"""
if isinstance(depth_map, torch.Tensor):
depth_map = depth_map.cpu().numpy()
if isinstance(extrinsics_cam, torch.Tensor):
extrinsics_cam = extrinsics_cam.cpu().numpy()
if isinstance(intrinsics_cam, torch.Tensor):
intrinsics_cam = intrinsics_cam.cpu().numpy()
world_points_list = []
for frame_idx in range(depth_map.shape[0]):
cur_world_points, _, _ = depth_to_world_coords_points(
depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
)
world_points_list.append(cur_world_points)
world_points_array = np.stack(world_points_list, axis=0)
return world_points_array
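# Usage sketch (synthetic inputs; values are illustrative only):
#
#     S, H, W = 2, 4, 6
#     depth = np.ones((S, H, W, 1), dtype=np.float32)                  # constant 1m depth
#     extrinsics = np.tile(np.eye(3, 4, dtype=np.float32), (S, 1, 1))  # identity cam-from-world
#     intrinsics = np.tile(np.array([[100.0, 0.0, W / 2],
#                                    [0.0, 100.0, H / 2],
#                                    [0.0, 0.0, 1.0]], dtype=np.float32), (S, 1, 1))
#     world_points = unproject_depth_map_to_point_map(depth, extrinsics, intrinsics)
#     assert world_points.shape == (S, H, W, 3)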
def depth_to_world_coords_points(
depth_map: np.ndarray,
extrinsic: np.ndarray,
intrinsic: np.ndarray,
eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Convert a depth map to world coordinates.
Args:
depth_map (np.ndarray): Depth map of shape (H, W).
intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
"""
if depth_map is None:
return None, None, None
# Valid depth mask
point_mask = depth_map > eps
# Convert depth map to camera coordinates
cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
# Multiply with the inverse of extrinsic matrix to transform to world coordinates
# extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
R_cam_to_world = cam_to_world_extrinsic[:3, :3]
t_cam_to_world = cam_to_world_extrinsic[:3, 3]
# Apply the rotation and translation to the camera coordinates
world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
# world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
return world_coords_points, cam_coords_points, point_mask
def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
"""
Convert a depth map to camera coordinates.
Args:
depth_map (np.ndarray): Depth map of shape (H, W).
intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
Returns:
np.ndarray: Camera coordinates (H, W, 3)
"""
H, W = depth_map.shape
assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
# Intrinsic parameters
fu, fv = intrinsic[0, 0], intrinsic[1, 1]
cu, cv = intrinsic[0, 2], intrinsic[1, 2]
# Generate grid of pixel coordinates
u, v = np.meshgrid(np.arange(W), np.arange(H))
# Unproject to camera coordinates
x_cam = (u - cu) * depth_map / fu
y_cam = (v - cv) * depth_map / fv
z_cam = depth_map
# Stack to form camera coordinates
cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
return cam_coords
def closed_form_inverse_se3(se3, R=None, T=None):
"""
Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
If `R` and `T` are provided, they must correspond to the rotation and translation
components of `se3`. Otherwise, they will be extracted from `se3`.
Args:
se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
R (optional): Nx3x3 array or tensor of rotation matrices.
T (optional): Nx3x1 array or tensor of translation vectors.
Returns:
Inverted SE3 matrices with the same type and device as `se3`.
Shapes:
se3: (N, 4, 4)
R: (N, 3, 3)
T: (N, 3, 1)
"""
# Check if se3 is a numpy array or a torch tensor
is_numpy = isinstance(se3, np.ndarray)
# Validate shapes
if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
raise ValueError(f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}.")
# Extract R and T if not provided
if R is None:
R = se3[:, :3, :3] # (N,3,3)
if T is None:
T = se3[:, :3, 3:] # (N,3,1)
# Transpose R
if is_numpy:
# Compute the transpose of the rotation for NumPy
R_transposed = np.transpose(R, (0, 2, 1))
# -R^T t for NumPy
top_right = -np.matmul(R_transposed, T)
inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
else:
R_transposed = R.transpose(1, 2) # (N,3,3)
top_right = -torch.bmm(R_transposed, T) # (N,3,1)
inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
inverted_matrix[:, :3, :3] = R_transposed
inverted_matrix[:, :3, 3:] = top_right
return inverted_matrix
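# The inverse of [R | t] is [R^T | -R^T t]; a quick numerical check on the NumPy
# path (illustrative values only):
#
#     pose = np.eye(4)[None]                                           # (1, 4, 4)
#     pose[0, :3, :3] = Rotation.from_euler("z", 30, degrees=True).as_matrix()
#     pose[0, :3, 3] = [1.0, 2.0, 3.0]
#     pose_inv = closed_form_inverse_se3(pose)
#     assert np.allclose(pose[0] @ pose_inv[0], np.eye(4), atol=1e-6)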
def closed_form_inverse_se3_general(se3, R=None, T=None):
"""
Compute the closed-form SE3 inverse for inputs with arbitrary batch dimensions.
se3: (..., 4, 4) or (..., 3, 4)
"""
batch_shape = se3.shape[:-2]
if R is None:
R = se3[..., :3, :3]
if T is None:
T = se3[..., :3, 3:]
R_transposed = R.transpose(-2, -1)
top_right = -R_transposed @ T
# Build an identity matrix and broadcast it to the batch shape
eye = torch.eye(4, 4, dtype=R.dtype, device=R.device)
inverted_matrix = eye.expand(*batch_shape, 4, 4).clone()
inverted_matrix[..., :3, :3] = R_transposed
inverted_matrix[..., :3, 3:] = top_right
return inverted_matrix
# TODO: this code can be further cleaned up
def project_world_points_to_camera_points_batch(world_points, cam_extrinsics):
"""
Transforms world-frame 3D points into camera-frame 3D points using extrinsic parameters.
Args:
world_points (torch.Tensor): 3D points of shape BxSxHxWx3.
cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4.
Returns:
torch.Tensor: Camera-frame 3D points of shape BxSxHxWx3.
"""
# TODO: merge this into project_world_points_to_cam
# device = world_points.device
# with torch.autocast(device_type=device.type, enabled=False):
ones = torch.ones_like(world_points[..., :1]) # shape: (B, S, H, W, 1)
world_points_h = torch.cat([world_points, ones], dim=-1) # shape: (B, S, H, W, 4)
# extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4)
extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3)
# world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1)
world_points_h_exp = world_points_h.unsqueeze(-1)
# Now perform the matrix multiplication
# (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1)
camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1)
return camera_points
def project_world_points_to_cam(
world_points,
cam_extrinsics,
cam_intrinsics=None,
distortion_params=None,
default=0,
only_points_cam=False,
):
"""
Transforms 3D points to 2D using extrinsic and intrinsic parameters.
Args:
world_points (torch.Tensor): 3D world points of shape Nx3.
cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
distortion_params (torch.Tensor, optional): Distortion parameters of shape BxK (K = 1, 2, or 4), used for radial distortion.
Returns:
tuple: (image_points, cam_points), where image_points has shape BxNx2 (None if only_points_cam=True) and cam_points has shape Bx3xN.
"""
device = world_points.device
# with torch.autocast(device_type=device.type, dtype=torch.double):
with torch.autocast(device_type=device.type, enabled=False):
N = world_points.shape[0] # Number of points
B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras
world_points_homogeneous = torch.cat(
[world_points, torch.ones_like(world_points[..., 0:1])], dim=1
) # Nx4
# Reshape for batch processing
world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand(
B, -1, -1
) # BxNx4
# Step 1: Apply extrinsic parameters
# Transform 3D points to camera coordinate system for all cameras
cam_points = torch.bmm(
cam_extrinsics, world_points_homogeneous.transpose(-1, -2)
)
if only_points_cam:
return None, cam_points
# Step 2: Apply intrinsic parameters and (optional) distortion
image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default)
return image_points, cam_points
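# Usage sketch (synthetic pinhole camera; values are illustrative only):
#
#     pts = torch.tensor([[0.0, 0.0, 2.0]])                  # one point 2m in front of the camera
#     extr = torch.eye(3, 4)[None]                            # identity cam-from-world, (1, 3, 4)
#     intr = torch.tensor([[[500.0, 0.0, 320.0],
#                           [0.0, 500.0, 240.0],
#                           [0.0, 0.0, 1.0]]])
#     uv, cam_pts = project_world_points_to_cam(pts, extr, intr)
#     # uv[0, 0] == (320, 240): the point lands on the principal point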
def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0):
"""
Applies intrinsic parameters and optional distortion to the given 3D points.
Args:
cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
distortion_params (torch.Tensor, optional): Distortion parameters of shape BxK, where K can be 1, 2, or 4.
default (float, optional): Default value to replace NaNs in the output.
Returns:
pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2.
"""
# Perspective division: normalize by depth to get normalized image-plane coordinates
cam_points = cam_points / cam_points[:, 2:3, :]
ndc_xy = cam_points[:, :2, :]
# Apply distortion if distortion_params are provided
if distortion_params is not None:
x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1])
distorted_xy = torch.stack([x_distorted, y_distorted], dim=1)
else:
distorted_xy = ndc_xy
# Prepare cam_points for batch matrix multiplication
cam_coords_homo = torch.cat(
(distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1
) # Bx3xN
# Apply intrinsic parameters using batch matrix multiplication
pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN
# Extract x and y coordinates
pixel_coords = pixel_coords[:, :2, :] # Bx2xN
# Replace NaNs with default value
pixel_coords = torch.nan_to_num(pixel_coords, nan=default)
return pixel_coords.transpose(1, 2) # BxNx2
def cam_from_img(pred_tracks, intrinsics, extra_params=None):
"""
Normalize predicted tracks based on camera intrinsics.
Args:
intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3].
pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2].
extra_params (torch.Tensor, optional): Distortion parameters of shape BxK, where K can be 1, 2, or 4.
Returns:
torch.Tensor: Normalized tracks tensor.
"""
# We don't want to do intrinsics_inv = torch.inverse(intrinsics) here
# otherwise we can use something like
# tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2))
principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2)
focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2)
tracks_normalized = (pred_tracks - principal_point) / focal_length
if extra_params is not None:
# Apply iterative undistortion
try:
tracks_normalized = iterative_undistortion(
extra_params, tracks_normalized
)
except Exception:
tracks_normalized = single_undistortion(
extra_params, tracks_normalized
)
return tracks_normalized
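# Usage sketch (no distortion; values are illustrative only): a pixel at the
# principal point maps to (0, 0) in normalized camera coordinates.
#
#     intr = torch.tensor([[[500.0, 0.0, 320.0],
#                           [0.0, 500.0, 240.0],
#                           [0.0, 0.0, 1.0]]])
#     tracks = torch.tensor([[[320.0, 240.0], [820.0, 240.0]]])   # (B=1, N=2, 2)
#     norm = cam_from_img(tracks, intr)
#     # norm[0] == [[0., 0.], [1., 0.]]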
## DROID-SLAM part
MIN_DEPTH = 0.2
def extract_intrinsics(intrinsics):
return intrinsics[...,None,None,:].unbind(dim=-1)
def projective_transform(
poses, depths, intrinsics, ii, jj, jacobian=False, return_depth=False
):
"""map points from ii->jj"""
# inverse project (pinhole)
X0, Jz = iproj(depths[:, ii], intrinsics[:, ii], jacobian=jacobian)
# transform
Gij = poses[:, jj] * poses[:, ii].inv()
# Gij.data[:, ii == jj] = torch.as_tensor(
# [-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device="cuda"
# )
X1, Ja = actp(Gij, X0, jacobian=jacobian)
# project (pinhole)
x1, Jp = proj(X1, intrinsics[:, jj], jacobian=jacobian, return_depth=return_depth)
# exclude points too close to camera
valid = ((X1[..., 2] > MIN_DEPTH) & (X0[..., 2] > MIN_DEPTH)).float()
valid = valid.unsqueeze(-1)
if jacobian:
# Ji transforms according to dual adjoint
Jj = torch.matmul(Jp, Ja)
Ji = -Gij[:, :, None, None, None].adjT(Jj)
Jz = Gij[:, :, None, None] * Jz
Jz = torch.matmul(Jp, Jz.unsqueeze(-1))
return x1, valid, (Ji, Jj, Jz)
return x1, valid
def induced_flow(poses, disps, intrinsics, ii, jj):
"""optical flow induced by camera motion"""
ht, wd = disps.shape[2:]
y, x = torch.meshgrid(
torch.arange(ht, device=disps.device, dtype=torch.float),
torch.arange(wd, device=disps.device, dtype=torch.float),
indexing="ij",
)
coords0 = torch.stack([x, y], dim=-1)
coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj, False)
return coords1[..., :2] - coords0, valid
def all_pairs_distance_matrix(poses, beta=2.5):
""" compute distance matrix between all pairs of poses """
poses = np.array(poses, dtype=np.float32)
poses[:,:3] *= beta # scale to balance rotation and translation
poses = SE3(torch.from_numpy(poses))
r = (poses[:,None].inv() * poses[None,:]).log()
return r.norm(dim=-1).cpu().numpy()
def pose_matrix_to_quaternion(pose):
""" convert 4x4 pose matrix to (t, q) """
q = Rotation.from_matrix(pose[..., :3, :3]).as_quat()
return np.concatenate([pose[..., :3, 3], q], axis=-1)
def compute_distance_matrix_flow(poses, disps, intrinsics):
""" compute flow magnitude between all pairs of frames """
if not isinstance(poses, SE3):
poses = torch.from_numpy(poses).float().cuda()[None]
poses = SE3(poses).inv()
disps = torch.from_numpy(disps).float().cuda()[None]
intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
N = poses.shape[1]
ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N), indexing="ij")
ii = ii.reshape(-1).cuda()
jj = jj.reshape(-1).cuda()
MAX_FLOW = 100.0
matrix = np.zeros((N, N), dtype=np.float32)
s = 2048
for i in range(0, ii.shape[0], s):
flow1, val1 = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
flow2, val2 = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s])
flow = torch.stack([flow1, flow2], dim=2)
val = torch.stack([val1, val2], dim=2)
mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
mag = mag.view(mag.shape[1], -1)
val = val.view(val.shape[1], -1)
mag = (mag * val).mean(-1) / val.mean(-1)
mag[val.mean(-1) < 0.7] = np.inf
i1 = ii[i:i+s].cpu().numpy()
j1 = jj[i:i+s].cpu().numpy()
matrix[i1, j1] = mag.cpu().numpy()
return matrix
def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4):
""" compute flow magnitude between all pairs of frames """
# if not isinstance(poses, SE3):
# poses = torch.from_numpy(poses).float().cuda()[None]
# poses = SE3(poses).inv()
# disps = torch.from_numpy(disps).float().cuda()[None]
# intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
N = poses.shape[1]
ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N), indexing="ij")
ii = ii.reshape(-1)
jj = jj.reshape(-1)
MAX_FLOW = 128.0
matrix = np.zeros((N, N), dtype=np.float32)
s = 2048
for i in range(0, ii.shape[0], s):
flow1a, val1a = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True)
flow1b, val1b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
flow2a, val2a = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True)
flow2b, val2b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
flow1 = flow1a + beta * flow1b
val1 = val1a * val2b
flow2 = flow2a + beta * flow2b
val2 = val2a * val2b
flow = torch.stack([flow1, flow2], dim=2)
val = torch.stack([val1, val2], dim=2)
mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
mag = mag.view(mag.shape[1], -1)
val = val.view(val.shape[1], -1)
mag = (mag * val).mean(-1) / val.mean(-1)
mag[val.mean(-1) < 0.8] = np.inf
i1 = ii[i:i+s].cpu().numpy()
j1 = jj[i:i+s].cpu().numpy()
matrix[i1, j1] = mag.cpu().numpy()
return matrix
def coords_grid(ht, wd, **kwargs):
y, x = torch.meshgrid(
torch.arange(ht, dtype=torch.float, **kwargs),
torch.arange(wd, dtype=torch.float, **kwargs),
indexing="ij",
)
return torch.stack([x, y], dim=-1)
def iproj(disps, intrinsics, jacobian=False):
"""pinhole camera inverse projection"""
ht, wd = disps.shape[2:]
fx, fy, cx, cy = extract_intrinsics(intrinsics)
y, x = torch.meshgrid(
torch.arange(ht, device=disps.device, dtype=torch.float),
torch.arange(wd, device=disps.device, dtype=torch.float),
indexing="ij",
)
i = torch.ones_like(disps)
X = (x - cx) / fx
Y = (y - cy) / fy
pts = torch.stack([X, Y, i, disps], dim=-1)
if jacobian:
J = torch.zeros_like(pts)
J[..., -1] = 1.0
return pts, J
return pts, None
def proj(Xs, intrinsics, jacobian=False, return_depth=False):
"""pinhole camera projection"""
fx, fy, cx, cy = extract_intrinsics(intrinsics)
X, Y, Z, D = Xs.unbind(dim=-1)
Z = torch.where(Z < 0.5 * MIN_DEPTH, torch.ones_like(Z), Z)
d = 1.0 / Z
x = fx * (X * d) + cx
y = fy * (Y * d) + cy
if return_depth:
coords = torch.stack([x, y, D * d], dim=-1)
else:
coords = torch.stack([x, y], dim=-1)
if jacobian:
B, N, H, W = d.shape
o = torch.zeros_like(d)
proj_jac = torch.stack(
[
fx * d,
o,
-fx * X * d * d,
o,
o,
fy * d,
-fy * Y * d * d,
o,
# o, o, -D*d*d, d,
],
dim=-1,
).view(B, N, H, W, 2, 4)
return coords, proj_jac
return coords, None
def actp(Gij, X0, jacobian=False):
"""action on point cloud"""
X1 = Gij[:, :, None, None] * X0
if jacobian:
X, Y, Z, d = X1.unbind(dim=-1)
o = torch.zeros_like(d)
B, N, H, W = d.shape
if isinstance(Gij, SE3):
Ja = torch.stack(
[
d,
o,
o,
o,
Z,
-Y,
o,
d,
o,
-Z,
o,
X,
o,
o,
d,
Y,
-X,
o,
o,
o,
o,
o,
o,
o,
],
dim=-1,
).view(B, N, H, W, 4, 6)
elif isinstance(Gij, Sim3):
Ja = torch.stack(
[
d,
o,
o,
o,
Z,
-Y,
X,
o,
d,
o,
-Z,
o,
X,
Y,
o,
o,
d,
Y,
-X,
o,
Z,
o,
o,
o,
o,
o,
o,
o,
],
dim=-1,
).view(B, N, H, W, 4, 7)
return X1, Ja
return X1, None
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
out = quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
].reshape(batch_dim + (4,))
return standardize_quaternion(out)
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
part is non negative.
Args:
quaternions: Quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
quaternions = F.normalize(quaternions, p=2, dim=-1)
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
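# Quick check (illustrative): the identity rotation maps to the unit quaternion
# with real part first, i.e. (1, 0, 0, 0).
#
#     q = matrix_to_quaternion(torch.eye(3)[None])
#     # q == tensor([[1., 0., 0., 0.]])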
def umeyama(X, Y):
"""
Estimates the Sim(3) transformation between `X` and `Y` point sets.
Estimates c, R and t such as c * R @ X + t ~ Y.
Parameters
----------
X : numpy.array
(m, n) shaped numpy array. m is the dimension of the points,
n is the number of points in the point set.
Y : numpy.array
(m, n) shaped numpy array. Indexes should be consistent with `X`.
That is, Y[:, i] must be the point corresponding to X[:, i].
Returns
-------
c : float
Scale factor.
R : numpy.array
(3, 3) shaped rotation matrix.
t : numpy.array
(3, 1) shaped translation vector.
"""
mu_x = X.mean(axis=1).reshape(-1, 1)
mu_y = Y.mean(axis=1).reshape(-1, 1)
var_x = np.square(X - mu_x).sum(axis=0).mean()
cov_xy = ((Y - mu_y) @ (X - mu_x).T) / X.shape[1]
U, D, VH = np.linalg.svd(cov_xy)
S = np.eye(X.shape[0])
if np.linalg.det(U) * np.linalg.det(VH) < 0:
S[-1, -1] = -1
c = np.trace(np.diag(D) @ S) / var_x
R = U @ S @ VH
t = mu_y - c * R @ mu_x
return c, R, t
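# Usage sketch (illustrative): recover a known similarity transform from
# noiseless corresponding point sets.
#
#     rng = np.random.default_rng(0)
#     X = rng.normal(size=(3, 100))                       # (m, n) = (3, 100)
#     R_true = Rotation.from_euler("xyz", [10, 20, 30], degrees=True).as_matrix()
#     t_true = np.array([[1.0], [2.0], [3.0]])
#     Y = 2.0 * R_true @ X + t_true
#     c, R_est, t_est = umeyama(X, Y)
#     # c ~= 2.0, R_est ~= R_true, t_est ~= t_true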


@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from PIL import Image
from torchvision import transforms as TF
import numpy as np
def load_and_preprocess_images_square(image_path_list, target_size=1024):
"""
Load and preprocess images by center padding to square and resizing to target size.
Also returns the position information of original pixels after transformation.
Args:
image_path_list (list): List of paths to image files
target_size (int, optional): Target size for both width and height. Defaults to 1024.
Returns:
tuple: (
torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
torch.Tensor: Array of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image
)
Raises:
ValueError: If the input list is empty
"""
# Check for empty list
if len(image_path_list) == 0:
raise ValueError("At least 1 image is required")
images = []
original_coords = [] # Renamed from position_info to be more descriptive
to_tensor = TF.ToTensor()
for image_path in image_path_list:
# Open image
img = Image.open(image_path)
# If there's an alpha channel, blend onto white background
if img.mode == "RGBA":
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
img = Image.alpha_composite(background, img)
# Convert to RGB
img = img.convert("RGB")
# Get original dimensions
width, height = img.size
# Make the image square by padding the shorter dimension
max_dim = max(width, height)
# Calculate padding
left = (max_dim - width) // 2
top = (max_dim - height) // 2
# Calculate scale factor for resizing
scale = target_size / max_dim
# Calculate final coordinates of original image in target space
x1 = left * scale
y1 = top * scale
x2 = (left + width) * scale
y2 = (top + height) * scale
# Store original image coordinates and scale
original_coords.append(np.array([x1, y1, x2, y2, width, height]))
# Create a new black square image and paste original
square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
square_img.paste(img, (left, top))
# Resize to target size
square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)
# Convert to tensor
img_tensor = to_tensor(square_img)
images.append(img_tensor)
# Stack all images
images = torch.stack(images)
original_coords = torch.from_numpy(np.array(original_coords)).float()
# Add additional dimension if single image to ensure correct shape
if len(image_path_list) == 1:
if images.dim() == 3:
images = images.unsqueeze(0)
original_coords = original_coords.unsqueeze(0)
return images, original_coords
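# Usage sketch (file paths are hypothetical):
#
#     images, original_coords = load_and_preprocess_images_square(
#         ["frame_000.png", "frame_001.png"], target_size=1024
#     )
#     # images: (2, 3, 1024, 1024); original_coords: (2, 6) = [x1, y1, x2, y2, width, height]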
def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16):
"""
A quick start function to load and preprocess images for model input.
This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
Args:
image_path_list (list): List of paths to image files
fx, fy, cx, cy (optional): Per-image normalized intrinsics (fractions of image width/height). When provided, fx/fy are rescaled in place to pixel units of the preprocessed images, cx/cy are set to the center of the preprocessed images, and all four are returned alongside the image tensor.
mode (str, optional): Preprocessing mode, either "crop" or "pad".
- "crop" (default): Sets width to image_size pixels and center crops height if needed.
- "pad": Preserves all pixels by making the largest dimension image_size pixels
and padding the smaller dimension to reach a square shape.
image_size (int, optional): Target size in pixels. Defaults to 512.
patch_size (int, optional): Patch size the output dimensions must be divisible by. Defaults to 16.
Returns:
torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
Raises:
ValueError: If the input list is empty or if mode is invalid
Notes:
- Images with different dimensions will be padded with white (value=1.0)
- A warning is printed when images have different shapes
- When mode="crop": the width is set to image_size while maintaining aspect ratio,
and the height is center-cropped if larger than image_size
- When mode="pad": the largest dimension is set to image_size while maintaining aspect ratio,
and the smaller dimension is padded to reach a square shape (image_size x image_size)
- Dimensions are adjusted to be divisible by patch_size for compatibility with model requirements
"""
# Check for empty list
if len(image_path_list) == 0:
raise ValueError("At least 1 image is required")
# Validate mode
if mode not in ["crop", "pad"]:
raise ValueError("Mode must be either 'crop' or 'pad'")
images = []
shapes = set()
to_tensor = TF.ToTensor()
target_size = image_size
# First process all images and collect their shapes
for i, image_path in enumerate(image_path_list):
# Open image
img = Image.open(image_path)
# If there's an alpha channel, blend onto white background:
if img.mode == "RGBA":
# Create white background
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
# Alpha composite onto the white background
img = Image.alpha_composite(background, img)
# Now convert to "RGB" (this step assigns white for transparent areas)
img = img.convert("RGB")
width, height = img.size
if fx is not None:
fx[i] = fx[i] * width
fy[i] = fy[i] * height
cx[i] = cx[i] * width
cy[i] = cy[i] * height
if mode == "pad":
# Make the largest dimension equal to target_size while maintaining aspect ratio
if width >= height:
new_width = target_size
new_height = round(height * (new_width / width) / patch_size) * patch_size # Make divisible by patch_size
else:
new_height = target_size
new_width = round(width * (new_height / height) / patch_size) * patch_size # Make divisible by patch_size
else: # mode == "crop"
# Original behavior: set width to target_size
new_width = target_size
# Calculate height maintaining aspect ratio, divisible by patch_size
new_height = round(height * (new_width / width) / patch_size) * patch_size
# Resize with new dimensions (width, height)
img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
img = to_tensor(img) # Convert to tensor (0, 1)
# Center crop height if it's larger than target_size (only in crop mode)
if mode == "crop" and new_height > target_size:
start_y = (new_height - target_size) // 2
img = img[:, start_y : start_y + target_size, :]
if fx is not None:
fx[i] = fx[i] * new_width / width
fy[i] = fy[i] * new_height / height
cx[i] = img.shape[2] / 2
cy[i] = img.shape[1] / 2
# For pad mode, pad to make a square of target_size x target_size
if mode == "pad":
h_padding = target_size - img.shape[1]
w_padding = target_size - img.shape[2]
if h_padding > 0 or w_padding > 0:
pad_top = h_padding // 2
pad_bottom = h_padding - pad_top
pad_left = w_padding // 2
pad_right = w_padding - pad_left
# Pad with white (value=1.0)
img = torch.nn.functional.pad(
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
)
shapes.add((img.shape[1], img.shape[2]))
images.append(img)
# Check if we have different shapes
# In theory our model can also work well with different shapes
if len(shapes) > 1:
print(f"Warning: Found images with different shapes: {shapes}")
# Find maximum dimensions
max_height = max(shape[0] for shape in shapes)
max_width = max(shape[1] for shape in shapes)
# Pad images if necessary
padded_images = []
for img in images:
h_padding = max_height - img.shape[1]
w_padding = max_width - img.shape[2]
if h_padding > 0 or w_padding > 0:
pad_top = h_padding // 2
pad_bottom = h_padding - pad_top
pad_left = w_padding // 2
pad_right = w_padding - pad_left
img = torch.nn.functional.pad(
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
)
padded_images.append(img)
images = padded_images
images = torch.stack(images) # concatenate images
# Ensure correct shape when single image
if len(image_path_list) == 1:
# Verify shape is (1, C, H, W)
if images.dim() == 3:
images = images.unsqueeze(0)
if fx is not None:
return images, fx, fy, cx, cy
return images
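# Usage sketch (file paths are hypothetical): "crop" mode fixes the width to
# image_size and center-crops the height; "pad" mode pads to a square.
#
#     images = load_and_preprocess_images(
#         ["frame_000.png", "frame_001.png"], mode="pad", image_size=512, patch_size=16
#     )
#     # images: (2, 3, 512, 512) in pad mode; pixel values in [0, 1]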


@@ -0,0 +1,331 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .rotation import quat_to_mat, mat_to_quat
import os
import numpy as np
import gzip
import json
import random
import logging
import warnings
from lingbot_map.utils.geometry import closed_form_inverse_se3, closed_form_inverse_se3_general
def extri_intri_to_pose_encoding(
extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512)
):
"""Convert camera extrinsics and intrinsics to a compact pose encoding.
This function transforms camera parameters into a unified pose encoding format,
which can be used for various downstream tasks like pose prediction or representation.
Args:
extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
where B is batch size and S is sequence length.
In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
Defined in pixels, with format:
[[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]]
where fx, fy are focal lengths and (cx, cy) is the principal point
image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
Required for computing field of view values. For example: (256, 512).
pose_encoding_type (str): Type of pose encoding to use. Currently only
supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
Returns:
torch.Tensor: Encoded camera pose parameters with shape BxSx9.
For "absT_quaR_FoV" type, the 9 dimensions are:
- [:3] = absolute translation vector T (3D)
- [3:7] = rotation as quaternion quat (4D)
- [7:] = field of view (2D)
"""
# extrinsics: BxSx3x4
# intrinsics: BxSx3x3
if pose_encoding_type == "absT_quaR_FoV":
R = extrinsics[:, :, :3, :3] # BxSx3x3
T = extrinsics[:, :, :3, 3] # BxSx3
quat = mat_to_quat(R)
# Note the order of h and w here
H, W = image_size_hw
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
else:
raise NotImplementedError
return pose_encoding
def pose_encoding_to_extri_intri(
pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512)
):
"""Convert a pose encoding back to camera extrinsics and intrinsics.
This function performs the inverse operation of extri_intri_to_pose_encoding,
reconstructing the full camera parameters from the compact encoding.
Args:
pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
where B is batch size and S is sequence length.
For "absT_quaR_FoV" type, the 9 dimensions are:
- [:3] = absolute translation vector T (3D)
- [3:7] = rotation as quaternion quat (4D)
- [7:] = field of view (2D)
image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
Required for reconstructing intrinsics from field of view values.
For example: (256, 512).
pose_encoding_type (str): Type of pose encoding used. Currently only
supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
If False, only extrinsics are returned and intrinsics will be None.
Returns:
tuple: (extrinsics, intrinsics)
- extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
a 3x1 translation vector.
- intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
or None if build_intrinsics is False. Defined in pixels, with format:
[[fx, 0, cx],
[0, fy, cy],
[0, 0, 1]]
where fx, fy are focal lengths and (cx, cy) is the principal point,
assumed to be at the center of the image (W/2, H/2).
"""
intrinsics = None
if pose_encoding_type == "absT_quaR_FoV":
T = pose_encoding[..., :3]
quat = pose_encoding[..., 3:7]
fov_h = pose_encoding[..., 7]
fov_w = pose_encoding[..., 8]
R = quat_to_mat(quat)
extrinsics = torch.cat([R, T[..., None]], dim=-1)
if build_intrinsics:
H, W = image_size_hw
fy = (H / 2.0) / torch.tan(fov_h / 2.0)
fx = (W / 2.0) / torch.tan(fov_w / 2.0)
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
intrinsics[..., 0, 0] = fx
intrinsics[..., 1, 1] = fy
intrinsics[..., 0, 2] = W / 2
intrinsics[..., 1, 2] = H / 2
intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
elif pose_encoding_type == "absT_quaR":
T = pose_encoding[..., :3]
quat = pose_encoding[..., 3:7]
R = quat_to_mat(quat)
extrinsics = torch.cat([R, T[..., None]], dim=-1)
intrinsics = None
return extrinsics, intrinsics
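# Round-trip sketch (synthetic camera; values are illustrative only):
#
#     B, S, H, W = 1, 2, 256, 512
#     extr = torch.eye(3, 4).expand(B, S, 3, 4)
#     intr = torch.tensor([[300.0, 0.0, W / 2],
#                          [0.0, 300.0, H / 2],
#                          [0.0, 0.0, 1.0]]).expand(B, S, 3, 3)
#     enc = extri_intri_to_pose_encoding(extr, intr, image_size_hw=(H, W))    # (B, S, 9)
#     extr2, intr2 = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))
#     # extr2 ~= extr and intr2 ~= intr (the principal point is assumed at the image center)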
def convert_pt3d_RT_to_opencv(Rot, Trans):
"""
Convert PyTorch3D camera extrinsics (R, T) to the OpenCV convention.
Args:
Rot: 3x3 rotation matrix in PyTorch3D format
Trans: 3D translation vector in PyTorch3D format
Returns:
extri_opencv: 3x4 extrinsic matrix in OpenCV format
"""
rot_pt3d = np.array(Rot)
trans_pt3d = np.array(Trans)
trans_pt3d[:2] *= -1
rot_pt3d[:, :2] *= -1
rot_pt3d = rot_pt3d.transpose(1, 0)
extri_opencv = np.hstack((rot_pt3d, trans_pt3d[:, None]))
return extri_opencv
def build_pair_index(N, B=1):
"""
Build indices for all possible pairs of frames.
Args:
N: Number of frames
B: Batch size
Returns:
i1, i2: Indices for all possible pairs
"""
i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)
i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]]
return i1, i2
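# Example (illustrative): for N=3 frames the unique pairs are (0,1), (0,2), (1,2).
#
#     i1, i2 = build_pair_index(3)
#     # i1 == tensor([0, 0, 1]), i2 == tensor([1, 2, 2])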
def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15):
"""
Calculate rotation angle error between ground truth and predicted rotations.
Args:
rot_gt: Ground truth rotation matrices
rot_pred: Predicted rotation matrices
batch_size: Batch size for reshaping the result
eps: Small value to avoid numerical issues
Returns:
Rotation angle error in degrees
"""
q_pred = mat_to_quat(rot_pred)
q_gt = mat_to_quat(rot_gt)
loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps)
err_q = torch.arccos(1 - 2 * loss_q)
rel_rangle_deg = err_q * 180 / np.pi
if batch_size is not None:
rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1)
return rel_rangle_deg
def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True):
"""
Calculate translation angle error between ground truth and predicted translations.
Args:
tvec_gt: Ground truth translation vectors
tvec_pred: Predicted translation vectors
batch_size: Batch size for reshaping the result
ambiguity: Whether to handle direction ambiguity
Returns:
Translation angle error in degrees
"""
rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred)
rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi
if ambiguity:
rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs())
if batch_size is not None:
rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1)
return rel_tangle_deg
def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6):
"""
Normalize the translation vectors and compute the angle between them.
Args:
t_gt: Ground truth translation vectors
t: Predicted translation vectors
eps: Small value to avoid division by zero
default_err: Default error value for invalid cases
Returns:
Angular error between translation vectors in radians
"""
t_norm = torch.norm(t, dim=1, keepdim=True)
t = t / (t_norm + eps)
t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True)
t_gt = t_gt / (t_gt_norm + eps)
loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps)
err_t = torch.acos(torch.sqrt(1 - loss_t))
err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err
return err_t
def calculate_auc_np(r_error, t_error, max_threshold=30):
"""
Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy.
Args:
r_error: numpy array representing R error values (Degree)
t_error: numpy array representing T error values (Degree)
max_threshold: Maximum threshold value for binning the histogram
Returns:
AUC value and the normalized histogram
"""
error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1)
max_errors = np.max(error_matrix, axis=1)
bins = np.arange(max_threshold + 1)
histogram, _ = np.histogram(max_errors, bins=bins)
num_pairs = float(len(max_errors))
normalized_histogram = histogram.astype(float) / num_pairs
return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
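# Usage sketch (synthetic errors; illustrative only): the AUC is the mean of the
# cumulative histogram of max(rotation error, translation error) up to the threshold.
#
#     r_err = np.array([1.0, 5.0, 40.0])
#     t_err = np.array([2.0, 3.0, 10.0])
#     auc, hist = calculate_auc_np(r_err, t_err, max_threshold=30)
#     # per-pair max errors are [2, 5, 40]; only the first two fall within the 30-degree range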
def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames):
"""
Compute rotation and translation errors between predicted and ground truth poses.
This function assumes the input poses are world-to-camera (w2c) transformations.
Args:
pred_se3: Predicted SE(3) transformations (w2c), shape (N, 4, 4)
gt_se3: Ground truth SE(3) transformations (w2c), shape (N, 4, 4)
num_frames: Number of frames (N)
Returns:
Rotation and translation angle errors in degrees
"""
pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames)
relative_pose_gt = gt_se3[pair_idx_i1].bmm(
closed_form_inverse_se3(gt_se3[pair_idx_i2])
)
relative_pose_pred = pred_se3[pair_idx_i1].bmm(
closed_form_inverse_se3(pred_se3[pair_idx_i2])
)
rel_rangle_deg = rotation_angle(
relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3]
)
rel_tangle_deg = translation_angle(
relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3]
)
return rel_rangle_deg, rel_tangle_deg
def colmap_to_opencv_intrinsics(K):
"""
Modify camera intrinsics to follow a different convention.
Coordinates of the center of the top-left pixels are by default:
- (0.5, 0.5) in Colmap
- (0,0) in OpenCV
"""
K = K.copy()
K[..., 0, 2] -= 0.5
K[..., 1, 2] -= 0.5
return K
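# Example (illustrative): only the principal point shifts by half a pixel.
#
#     K = np.array([[500.0, 0.0, 320.0],
#                   [0.0, 500.0, 240.0],
#                   [0.0, 0.0, 1.0]])
#     K_cv = colmap_to_opencv_intrinsics(K)
#     # K_cv[0, 2] == 319.5, K_cv[1, 2] == 239.5; focal lengths are unchanged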
def read_camera_parameters(filename):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7,10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
return intrinsics, extrinsics


@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
import torch
import numpy as np
import torch.nn.functional as F
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
"""
Quaternion Order: XYZW or say ijkr, scalar-last
Convert rotations given as quaternions to rotation matrices.
Args:
quaternions: quaternions with real part last,
as tensor of shape (..., 4).
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
i, j, k, r = torch.unbind(quaternions, -1)
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part last, as tensor of shape (..., 4).
Quaternion Order: XYZW or say ijkr, scalar-last
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
q_abs = _sqrt_positive_part(
torch.stack(
[1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
)
)
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked.
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
# Convert from rijk to ijkr
out = out[..., [1, 2, 3, 0]]
out = standardize_quaternion(out)
return out
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
part is non negative.
Args:
quaternions: Quaternions with real part last,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
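# Round-trip sketch (illustrative): mat_to_quat and quat_to_mat are inverses of each
# other up to the sign convention (XYZW, real part last, real part kept non-negative).
#
#     q = torch.tensor([[0.0, 0.0, 0.0, 1.0]])        # identity rotation, xyzw
#     R = quat_to_mat(q)                               # (1, 3, 3) identity matrix
#     q_back = mat_to_quat(R)                          # == q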