feat: fuse_trajectory — Umeyama weighted alignment lingbot→world + graceful fallbacks
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
183
fuse/fuse_trajectory.py
Normal file
183
fuse/fuse_trajectory.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# fuse/fuse_trajectory.py
|
||||
"""
|
||||
Umeyama-based trajectory fusion.
|
||||
Aligns lingbot local poses to world coordinates using USBL/GPS absolute fixes.
|
||||
Gracefully handles missing lingbot poses or no absolute AUV fixes.
|
||||
"""
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import h5py
|
||||
|
||||
|
||||
def umeyama(src: np.ndarray, dst: np.ndarray,
|
||||
weights: np.ndarray | None = None) -> tuple[float, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Weighted Umeyama: find scale, R, t such that dst ≈ scale * R @ src + t
|
||||
src, dst : (N, 3) float64
|
||||
weights : (N,) non-negative
|
||||
Raises ValueError if N < 3.
|
||||
"""
|
||||
N = len(src)
|
||||
if N < 3:
|
||||
raise ValueError(f"Umeyama requires at least 3 point pairs, got {N}")
|
||||
if weights is None:
|
||||
weights = np.ones(N, dtype=np.float64)
|
||||
w = weights / weights.sum()
|
||||
mu_s = (w[:, None] * src).sum(0)
|
||||
mu_d = (w[:, None] * dst).sum(0)
|
||||
src_c, dst_c = src - mu_s, dst - mu_d
|
||||
cov = (dst_c * w[:, None]).T @ src_c
|
||||
U, D, Vt = np.linalg.svd(cov)
|
||||
S = np.eye(3)
|
||||
if np.linalg.det(U) * np.linalg.det(Vt) < 0:
|
||||
S[2, 2] = -1
|
||||
R = U @ S @ Vt
|
||||
var_s = (w * np.sum(src_c ** 2, axis=1)).sum()
|
||||
scale = float(np.sum(D * np.diag(S)) / var_s) if var_s > 0 else 1.0
|
||||
t = mu_d - scale * R @ mu_s
|
||||
return scale, R, t
|
||||
|
||||
|
||||
def fuse(fixes_h5: str, poses_npz: str, out_h5: str,
|
||||
outlier_sigma: float = 2.0) -> None:
|
||||
|
||||
# 1. Load fixes
|
||||
with h5py.File(fixes_h5, "r") as f:
|
||||
usv_e = f["usv_gps/easting"][:]
|
||||
usv_n = f["usv_gps/northing"][:]
|
||||
usv_t = f["usv_gps/t_ns"][:]
|
||||
utm_zone = f["usv_gps"].attrs.get("utm_zone", "31T")
|
||||
|
||||
auv_lat = f["auv_mcap/lat"][:]
|
||||
auv_lon = f["auv_mcap/lon"][:]
|
||||
auv_dep = f["auv_mcap/depth_m"][:]
|
||||
auv_t = f["auv_mcap/t_ns"][:]
|
||||
|
||||
usbl_n = f["usbl_fixes/north_m"][:] if "usbl_fixes" in f else np.array([])
|
||||
usbl_e = f["usbl_fixes/east_m"][:] if "usbl_fixes" in f else np.array([])
|
||||
usbl_d = f["usbl_fixes/depth_m"][:] if "usbl_fixes" in f else np.array([])
|
||||
usbl_t = f["usbl_fixes/t_ns"][:] if "usbl_fixes" in f else np.array([], dtype=np.int64)
|
||||
|
||||
# 2. Check for lingbot poses
|
||||
poses_path = Path(poses_npz)
|
||||
if not poses_path.exists():
|
||||
print(f"WARNING: lingbot poses not found at {poses_npz}")
|
||||
print("Saving sources only — re-run after Plan 2 generates lingbot_poses.npz")
|
||||
with h5py.File(out_h5, "w") as f:
|
||||
f.attrs["status"] = "no_lingbot_poses"
|
||||
f.attrs["utm_zone"] = utm_zone
|
||||
return
|
||||
|
||||
# 3. Load lingbot poses
|
||||
data = np.load(poses_npz)
|
||||
poses_34 = data["poses"] # (N, 3, 4)
|
||||
pose_t = data["timestamps_ns"] # (N,) int64
|
||||
ling_xyz = poses_34[:, :3, 3] # camera positions in local frame
|
||||
|
||||
# 4. Build absolute AUV reference points
|
||||
has_usbl_fixes = (len(usbl_n) > 0 and
|
||||
not np.all(usbl_n == 0) and
|
||||
not np.all(usbl_e == 0))
|
||||
|
||||
has_auv_gps = not np.all(auv_lat == 0.0)
|
||||
|
||||
if not has_usbl_fixes and not has_auv_gps:
|
||||
print("WARNING: No absolute AUV position fixes available.")
|
||||
print("Cannot perform Umeyama alignment — saving local lingbot trajectory only.")
|
||||
_save_local_only(out_h5, pose_t, poses_34, ling_xyz, utm_zone)
|
||||
return
|
||||
|
||||
# 5. Build world reference: interpolate AUV absolute pos at lingbot timestamps
|
||||
if has_auv_gps:
|
||||
from pyproj import Transformer
|
||||
zone_num = int("".join(c for c in utm_zone if c.isdigit()))
|
||||
hemi = "north" if utm_zone[-1].upper() >= "N" else "south"
|
||||
proj = f"+proj=utm +zone={zone_num} +{hemi} +ellps=WGS84"
|
||||
tr = Transformer.from_crs("EPSG:4326", proj, always_xy=True)
|
||||
auv_x, auv_y = tr.transform(auv_lon, auv_lat)
|
||||
auv_xyz_world = np.column_stack([auv_x, auv_y, -auv_dep])
|
||||
ref_t = auv_t
|
||||
else:
|
||||
# USBL relative offsets + USV GPS → AUV absolute
|
||||
usv_e_i = np.interp(usbl_t.astype(float), usv_t.astype(float), usv_e)
|
||||
usv_n_i = np.interp(usbl_t.astype(float), usv_t.astype(float), usv_n)
|
||||
auv_xyz_world = np.column_stack([usv_e_i + usbl_e, usv_n_i + usbl_n, -usbl_d])
|
||||
ref_t = usbl_t
|
||||
|
||||
# 6. Match lingbot timestamps → world reference points
|
||||
src_pts, dst_pts, weights = [], [], []
|
||||
for i, pt in enumerate(pose_t):
|
||||
idx = np.argmin(np.abs(ref_t - pt))
|
||||
if np.abs(ref_t[idx] - pt) > 5e9: # > 5s gap → skip
|
||||
continue
|
||||
src_pts.append(ling_xyz[i])
|
||||
dst_pts.append(auv_xyz_world[idx])
|
||||
weights.append(1.0)
|
||||
|
||||
if len(src_pts) < 3:
|
||||
print(f"WARNING: Only {len(src_pts)} correspondences (need ≥3). Saving local only.")
|
||||
_save_local_only(out_h5, pose_t, poses_34, ling_xyz, utm_zone)
|
||||
return
|
||||
|
||||
src = np.array(src_pts)
|
||||
dst = np.array(dst_pts)
|
||||
w = np.array(weights)
|
||||
|
||||
# 7. Umeyama with outlier rejection
|
||||
scale, R, t = umeyama(src, dst, w)
|
||||
residuals = np.linalg.norm((scale * (R @ src.T).T + t) - dst, axis=1)
|
||||
sigma = residuals.std()
|
||||
mask = residuals < outlier_sigma * sigma if sigma > 0 else np.ones(len(src), dtype=bool)
|
||||
print(f"Correspondences: {len(src)} total, {mask.sum()} after outlier rejection (σ={sigma:.3f}m)")
|
||||
scale, R, t = umeyama(src[mask], dst[mask], w[mask])
|
||||
|
||||
# 8. Transform all lingbot poses to world frame
|
||||
N = len(pose_t)
|
||||
poses_world = np.zeros((N, 4, 4))
|
||||
poses_world[:, 3, 3] = 1.0
|
||||
for i in range(N):
|
||||
R_local = poses_34[i, :3, :3]
|
||||
t_local = poses_34[i, :3, 3]
|
||||
poses_world[i, :3, :3] = R @ R_local
|
||||
poses_world[i, :3, 3] = scale * R @ t_local + t
|
||||
|
||||
xyz_world = poses_world[:, :3, 3]
|
||||
|
||||
# 9. Write trajectory_world.h5
|
||||
with h5py.File(out_h5, "w") as f:
|
||||
f.attrs["status"] = "aligned"
|
||||
f.attrs["utm_zone"] = utm_zone
|
||||
|
||||
al = f.create_group("alignment")
|
||||
al.attrs["scale"] = scale
|
||||
al.attrs["rmse_m"] = float(residuals[mask].mean())
|
||||
al.attrs["n_correspondences"] = int(mask.sum())
|
||||
al.create_dataset("R", data=R)
|
||||
al.create_dataset("t", data=t)
|
||||
|
||||
pw = f.create_group("poses_world")
|
||||
pw.create_dataset("t_ns", data=pose_t, compression="gzip")
|
||||
pw.create_dataset("x_m", data=xyz_world[:, 0], compression="gzip")
|
||||
pw.create_dataset("y_m", data=xyz_world[:, 1], compression="gzip")
|
||||
pw.create_dataset("z_m", data=xyz_world[:, 2], compression="gzip")
|
||||
pw.create_dataset("T_4x4", data=poses_world, compression="gzip")
|
||||
|
||||
print(f"Fusion OK: scale={scale:.4f} RMSE={residuals[mask].mean():.3f}m → {out_h5}")
|
||||
|
||||
|
||||
def _save_local_only(out_h5: str, pose_t, poses_34, ling_xyz, utm_zone):
|
||||
with h5py.File(out_h5, "w") as f:
|
||||
f.attrs["status"] = "local_only"
|
||||
f.attrs["utm_zone"] = utm_zone
|
||||
pw = f.create_group("poses_world")
|
||||
pw.create_dataset("t_ns", data=pose_t, compression="gzip")
|
||||
pw.create_dataset("x_m", data=ling_xyz[:, 0], compression="gzip")
|
||||
pw.create_dataset("y_m", data=ling_xyz[:, 1], compression="gzip")
|
||||
pw.create_dataset("z_m", data=ling_xyz[:, 2], compression="gzip")
|
||||
pw.create_dataset("T_4x4", data=poses_34, compression="gzip")
|
||||
print(f"Saved local lingbot trajectory (no world alignment) → {out_h5}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fuse(sys.argv[1], sys.argv[2], sys.argv[3])
|
||||
28
tests/test_fuse_trajectory.py
Normal file
28
tests/test_fuse_trajectory.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import tempfile, os
|
||||
import numpy as np
|
||||
import h5py
|
||||
|
||||
def test_fuse_creates_output_without_lingbot():
|
||||
"""When lingbot_poses.npz doesn't exist, fuse creates HDF5 with sources only."""
|
||||
from fuse.fuse_trajectory import fuse
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
out_h5 = os.path.join(tmpdir, "traj.h5")
|
||||
# Create minimal sparse_fixes.h5
|
||||
fixes_h5 = os.path.join(tmpdir, "fixes.h5")
|
||||
with h5py.File(fixes_h5, "w") as f:
|
||||
grp = f.create_group("usv_gps")
|
||||
grp.create_dataset("t_ns", data=np.array([1000, 2000], dtype=np.int64))
|
||||
grp.create_dataset("easting", data=np.array([100.0, 101.0]))
|
||||
grp.create_dataset("northing",data=np.array([200.0, 201.0]))
|
||||
grp.create_dataset("rtk_status", data=np.array([0, 0], dtype=np.int8))
|
||||
grp.attrs["utm_zone"] = "31T"
|
||||
grp2 = f.create_group("auv_mcap")
|
||||
grp2.create_dataset("t_ns", data=np.array([1000, 2000], dtype=np.int64))
|
||||
grp2.create_dataset("lat", data=np.array([0.0, 0.0]))
|
||||
grp2.create_dataset("lon", data=np.array([0.0, 0.0]))
|
||||
grp2.create_dataset("depth_m",data=np.array([5.0, 6.0]))
|
||||
|
||||
fuse(fixes_h5, "/nonexistent/lingbot.npz", out_h5)
|
||||
assert os.path.exists(out_h5)
|
||||
with h5py.File(out_h5, "r") as f:
|
||||
assert "status" in f.attrs
|
||||
41
tests/test_umeyama.py
Normal file
41
tests/test_umeyama.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
def test_umeyama_identity():
|
||||
from fuse.fuse_trajectory import umeyama
|
||||
src = np.random.default_rng(0).standard_normal((10, 3))
|
||||
scale, R, t = umeyama(src, src)
|
||||
assert abs(scale - 1.0) < 1e-5
|
||||
assert np.allclose(R, np.eye(3), atol=1e-5)
|
||||
assert np.allclose(t, np.zeros(3), atol=1e-5)
|
||||
|
||||
def test_umeyama_known_transform():
|
||||
from fuse.fuse_trajectory import umeyama
|
||||
rng = np.random.default_rng(42)
|
||||
src = rng.standard_normal((20, 3))
|
||||
true_scale = 2.5
|
||||
true_R = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=float)
|
||||
true_t = np.array([1.0, 2.0, 3.0])
|
||||
dst = true_scale * (src @ true_R.T) + true_t
|
||||
scale, R, t = umeyama(src, dst)
|
||||
assert abs(scale - true_scale) < 1e-4
|
||||
assert np.allclose(R, true_R, atol=1e-4)
|
||||
assert np.allclose(t, true_t, atol=1e-4)
|
||||
|
||||
def test_umeyama_weighted():
|
||||
from fuse.fuse_trajectory import umeyama
|
||||
rng = np.random.default_rng(0)
|
||||
src = rng.standard_normal((15, 3))
|
||||
true_scale, true_t = 1.5, np.array([0.5, -0.5, 1.0])
|
||||
dst = true_scale * src + true_t
|
||||
weights = np.ones(15)
|
||||
weights[0] = 0.0 # outlier with zero weight
|
||||
scale, R, t = umeyama(src, dst, weights=weights)
|
||||
assert abs(scale - true_scale) < 1e-3
|
||||
assert np.allclose(t, true_t, atol=1e-3)
|
||||
|
||||
def test_umeyama_raises_on_few_points():
|
||||
from fuse.fuse_trajectory import umeyama
|
||||
src = np.random.default_rng(0).standard_normal((2, 3))
|
||||
with pytest.raises(ValueError, match="at least 3"):
|
||||
umeyama(src, src)
|
||||
Reference in New Issue
Block a user