From 76bba217dcee049ebd9a7de02c23287def451ba9 Mon Sep 17 00:00:00 2001 From: Floppyrj45 Date: Fri, 24 Apr 2026 10:27:55 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20stitch.py=20--poses=20trajectory=5Fworl?= =?UTF-8?q?d.h5=20=E2=80=94=20T=5Finit=20depuis=20poses=20monde,=20remplac?= =?UTF-8?q?e=20RANSAC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/stitch.py | 44 +++++++++++++++++++++++++++++++++++++- tests/__init__.py | 0 tests/test_stitch_poses.py | 43 +++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_stitch_poses.py diff --git a/scripts/stitch.py b/scripts/stitch.py index 6355dab..29c179d 100644 --- a/scripts/stitch.py +++ b/scripts/stitch.py @@ -77,6 +77,35 @@ def icp_refine(src, dst, init_transform, voxel_size: float): return result.transformation +def _load_world_poses(h5_path: str, n_plys: int) -> list[np.ndarray]: + """Load world-frame transforms from trajectory_world.h5, one per PLY. + + Divides the pose sequence into n_plys equal chunks. + Returns T_i_to_ref (4x4) for each PLY, where T_0_to_ref = I. + """ + import h5py + with h5py.File(h5_path, "r") as f: + if "poses_world" not in f or "T_4x4" not in f["poses_world"]: + raise ValueError(f"{h5_path}: missing poses_world/T_4x4") + T_all = f["poses_world/T_4x4"][:] # (M, 4, 4) + + M = len(T_all) + chunk = max(1, M // n_plys) + + avg_T = [] + for i in range(n_plys): + start = i * chunk + end = min(start + chunk, M) + chunk_T = T_all[start:end] + avg_t = chunk_T[:, :3, 3].mean(0) + T_rep = chunk_T[0].copy() + T_rep[:3, 3] = avg_t + avg_T.append(T_rep) + + T0_inv = np.linalg.inv(avg_T[0]) + return [T0_inv @ avg_T[i] for i in range(n_plys)] + + def main(): ap = argparse.ArgumentParser() ap.add_argument("output", type=Path, help="Output merged PLY") @@ -89,6 +118,9 @@ def main(): help="Use identity as init transform and refine with ICP only") ap.add_argument("--merge-voxel", type=float, default=0.02, help="Final voxel downsampling on merged cloud (0 = no downsample)") + ap.add_argument("--poses", type=str, default=None, + help="Path to trajectory_world.h5 — use world poses as T_init " + "for ICP (replaces RANSAC). Requires h5py.") args = ap.parse_args() try: @@ -108,11 +140,21 @@ def main(): merged = clouds[0] ref_down, ref_fpfh = preprocess(clouds[0], args.voxel) + # Load pose-guided transforms if available + world_transforms = None + if args.poses: + print(f"Loading world poses from {args.poses}...") + world_transforms = _load_world_poses(args.poses, len(clouds)) + print(f"Pose-guided init: {len(world_transforms)} transforms loaded") + for i, src_pcd in enumerate(clouds[1:], start=1): print(f"\nAligning {args.inputs[i].name} → {args.inputs[0].name}...") src_down, src_fpfh = preprocess(src_pcd, args.voxel) - if args.icp_only or args.no_ransac: + if world_transforms is not None: + init_tf = world_transforms[i] + print(" Using world pose T_init (no RANSAC)") + elif args.icp_only or args.no_ransac: init_tf = np.eye(4) else: print(" RANSAC global registration...") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_stitch_poses.py b/tests/test_stitch_poses.py new file mode 100644 index 0000000..5577b3b --- /dev/null +++ b/tests/test_stitch_poses.py @@ -0,0 +1,43 @@ +import sys, os, tempfile +import numpy as np +import h5py +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +def _make_test_h5(path, n_poses=100): + T = np.tile(np.eye(4), (n_poses, 1, 1)).astype(np.float64) + for i in range(n_poses): + T[i, 0, 3] = float(i) * 0.1 + with h5py.File(path, "w") as f: + pw = f.create_group("poses_world") + pw.create_dataset("T_4x4", data=T) + pw.create_dataset("t_ns", data=np.arange(n_poses, dtype=np.int64) * int(1e8)) + f.attrs["status"] = "aligned" + +def test_load_world_poses_returns_n_transforms(): + from scripts.stitch import _load_world_poses + with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp: + path = tmp.name + try: + _make_test_h5(path, n_poses=100) + transforms = _load_world_poses(path, 4) + assert len(transforms) == 4 + assert np.allclose(transforms[0], np.eye(4), atol=1e-6), "T_0 must be identity" + assert not np.allclose(transforms[1], np.eye(4)), "T_1 must differ from identity" + for T in transforms: + assert T.shape == (4, 4) + finally: + os.unlink(path) + +def test_load_world_poses_single_ply(): + from scripts.stitch import _load_world_poses + with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp: + path = tmp.name + try: + _make_test_h5(path, n_poses=10) + transforms = _load_world_poses(path, 1) + assert len(transforms) == 1 + assert np.allclose(transforms[0], np.eye(4), atol=1e-6) + finally: + os.unlink(path)