42 lines
1.5 KiB
Python
42 lines
1.5 KiB
Python
import numpy as np
|
|
import pytest
|
|
|
|
def test_umeyama_identity():
    """Aligning a point set to itself should give scale 1, R = I and t = 0."""
    from fuse.fuse_trajectory import umeyama

    points = np.random.default_rng(0).standard_normal((10, 3))
    s, rot, trans = umeyama(points, points)

    tol = 1e-5
    assert abs(s - 1.0) < tol
    assert np.allclose(rot, np.eye(3), atol=tol)
    assert np.allclose(trans, np.zeros(3), atol=tol)
|
|
|
|
def test_umeyama_known_transform():
    """Recover a known scale, rotation and translation from exactly mapped points."""
    from fuse.fuse_trajectory import umeyama

    gen = np.random.default_rng(42)
    source = gen.standard_normal((20, 3))

    expected_scale = 2.5
    # +90 degree rotation about the z axis (maps the x axis onto the y axis).
    expected_rot = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=float)
    expected_trans = np.array([1.0, 2.0, 3.0])
    # Rows are points, so s * R @ x + t is written as s * (X @ R.T) + t.
    target = expected_scale * (source @ expected_rot.T) + expected_trans

    s, rot, trans = umeyama(source, target)

    tol = 1e-4
    assert abs(s - expected_scale) < tol
    assert np.allclose(rot, expected_rot, atol=tol)
    assert np.allclose(trans, expected_trans, atol=tol)
|
|
|
|
def test_umeyama_weighted():
    """A zero-weight outlier must not influence the fitted transform.

    ``dst`` is an exact scaled-and-translated copy of ``src`` (no rotation),
    except that point 0 is displaced far away. Because that point gets zero
    weight, ``umeyama`` should still recover the true scale, rotation and
    translation from the remaining inliers. An implementation that ignored
    ``weights`` would be dragged off by the outlier and fail these checks.
    """
    from fuse.fuse_trajectory import umeyama

    rng = np.random.default_rng(0)
    src = rng.standard_normal((15, 3))
    true_scale, true_t = 1.5, np.array([0.5, -0.5, 1.0])
    dst = true_scale * src + true_t
    # Corrupt point 0 so that its zero weight below is actually exercised;
    # without this the "outlier" is a perfect inlier and weighting is untested.
    dst[0] += np.array([100.0, -100.0, 100.0])

    weights = np.ones(15)
    weights[0] = 0.0  # ignore the corrupted point

    scale, R, t = umeyama(src, dst, weights=weights)

    assert abs(scale - true_scale) < 1e-3
    # The ground-truth transform has no rotation component.
    assert np.allclose(R, np.eye(3), atol=1e-3)
    assert np.allclose(t, true_t, atol=1e-3)
|
|
|
|
def test_umeyama_raises_on_few_points():
    """Two correspondences are too few: expect a ValueError mentioning 'at least 3'."""
    from fuse.fuse_trajectory import umeyama

    too_few = np.random.default_rng(0).standard_normal((2, 3))
    with pytest.raises(ValueError, match="at least 3"):
        umeyama(too_few, too_few)
|