Files
cosma-qc/scripts/loop_closure_sweep.py

120 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Sweep loop closure params on cached hashes."""
import argparse, csv, math, pickle
from pathlib import Path
import numpy as np
from PIL import Image
import imagehash
def _load_or_compute_hashes(frames, frames_dir, hash_size, cache_file):
    """Return one pHash per frame, reusing the pickle cache when it matches.

    The cache is reused only when it was written for the same frame count,
    hash size and ``frames_dir`` string; otherwise every frame is hashed
    again and the cache file is overwritten.

    Args:
        frames: sorted list of frame image paths.
        frames_dir: the --frames-dir value, stored verbatim as a cache key.
        hash_size: pHash side length (the hash has hash_size**2 bits).
        cache_file: path of the pickle cache to read and (re)write.

    Returns:
        list of ``imagehash.ImageHash`` objects in ``frames`` order.
    """
    meta = {
        'frames_count': len(frames),
        'hash_size': hash_size,
        'frames_dir': str(frames_dir),
    }
    cache_path = Path(cache_file)
    if cache_path.exists():
        # NOTE(security): pickle.load can execute code embedded in the file.
        # The default cache path is under /tmp, which other users can write
        # to on shared machines -- use a private --hash-cache path there.
        try:
            with open(cache_path, 'rb') as fh:
                cached = pickle.load(fh)
        except (OSError, EOFError, pickle.UnpicklingError) as err:
            # A broken cache is only a cache: fall through and recompute.
            print(f'[cache] ignoring unreadable cache {cache_path}: {err}', flush=True)
            cached = None
        if (isinstance(cached, dict) and 'hashes' in cached
                and all(cached.get(key) == value for key, value in meta.items())):
            hashes = cached['hashes']
            print(f'[cache] loaded {len(hashes)} hashes from {cache_path}', flush=True)
            return hashes

    print(f'[hash] computing {len(frames)} pHashes (size={hash_size})...', flush=True)
    hashes = []
    for i, frame in enumerate(frames):
        # Close each image file promptly rather than relying on GC.
        with Image.open(frame) as img:
            hashes.append(imagehash.phash(img, hash_size=hash_size))
        if i % 200 == 0:
            print(f' {i}/{len(frames)}', flush=True)
    with open(cache_path, 'wb') as fh:
        pickle.dump({'hashes': hashes, **meta}, fh)
    print(f'[cache] saved to {cache_file}', flush=True)
    return hashes


def _read_dvl_track(csv_path):
    """Read the ``east_m`` / ``north_m`` columns of the DVL CSV as float arrays."""
    with open(csv_path, newline='') as fh:
        rows = list(csv.DictReader(fh))
    east = np.array([float(row['east_m']) for row in rows], dtype=float)
    north = np.array([float(row['north_m']) for row in rows], dtype=float)
    return east, north


def _find_candidate_loops(hashes, min_sep, max_dist):
    """Return ``(i, j, dist)`` for every frame pair that looks like a revisit.

    A pair qualifies when ``j - i >= min_sep`` and the Hamming distance
    between the two hashes is ``<= max_dist``. Results are ordered by ``i``
    and then ``j``, the same order a nested ``for i: for j:`` scan produces.
    """
    if not hashes:
        return []
    # One row of hash bits per frame. Counting differing bits row-wise is the
    # same comparison ImageHash.__sub__ performs, but vectorised per row.
    bits = np.array([np.asarray(h.hash).ravel() for h in hashes])
    n_frames = len(bits)
    loops = []
    for i in range(n_frames):
        start = i + min_sep
        if start >= n_frames:
            break
        dists = np.count_nonzero(bits[start:] != bits[i], axis=1)
        for offset in np.flatnonzero(dists <= max_dist):
            loops.append((i, start + int(offset), int(dists[offset])))
    return loops


def _apply_loop_corrections(east, north, loops):
    """Return copies of the track with each loop closure applied in order.

    For a loop ``(i, j, _)`` the offset that moves frame ``j`` onto frame
    ``i`` is spread linearly over frames ``i+1..j`` (fraction ``(k-i)/(j-i)``)
    and applied in full to every frame after ``j``. Loops whose ``j`` falls
    beyond the end of the track are skipped.
    """
    e_c = east.copy()
    n_c = north.copy()
    for i, j, _ in loops:
        if j >= len(e_c):
            continue
        dx = e_c[i] - e_c[j]
        dy = n_c[i] - n_c[j]
        span = j - i
        ramp = np.arange(1, span + 1) / span  # (k - i) / span for k in i+1..j
        e_c[i + 1:j + 1] += dx * ramp
        n_c[i + 1:j + 1] += dy * ramp
        e_c[j + 1:] += dx
        n_c[j + 1:] += dy
    return e_c, n_c


def _bbox(east, north):
    """Return the (width, height) extent of a track in metres."""
    return float(east.max() - east.min()), float(north.max() - north.min())


def _plot_track(ax, east, north, line_style, title):
    """Draw a track with a green start dot and a red end triangle."""
    ax.plot(east, north, line_style, linewidth=1.0)
    ax.plot(east[0], north[0], 'go', markersize=10)
    ax.plot(east[-1], north[-1], 'r^', markersize=10)
    ax.set_title(title)
    ax.set_xlabel('East')
    ax.set_ylabel('North')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)


def main():
    """Sweep pHash loop-closure thresholds over one dive and plot the results.

    Hashes every ``frame_*.jpg`` in --frames-dir (with an on-disk cache) and
    finds frame pairs at least --min-sep frames apart whose hashes are close.
    For each of four distance thresholds it applies those loop closures to
    the DVL east/north track. It then writes a 2x3 figure: the original
    track, one corrected track per threshold, and a summary panel.

    Frames and DVL rows are matched by index (assumes one DVL row per frame
    -- TODO confirm with the extraction step).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--frames-dir', required=True)
    ap.add_argument('--dvl-csv', required=True)
    ap.add_argument('--hash-cache', default='/tmp/phash_cache.pkl')
    ap.add_argument('--hash-size', type=int, default=16)  # bigger for finer discrimination
    ap.add_argument('--out-plot', required=True)
    ap.add_argument('--min-sep', type=int, default=60)
    ap.add_argument('--dataset-label', default='GX039839',
                    help='label shown in the figure title')
    args = ap.parse_args()
    if args.min_sep < 1:
        ap.error('--min-sep must be >= 1 (a frame cannot close a loop with itself)')

    frames = sorted(Path(args.frames_dir).glob('frame_*.jpg'))
    if not frames:
        ap.error(f'no frame_*.jpg files found in {args.frames_dir}')
    hashes = _load_or_compute_hashes(frames, args.frames_dir, args.hash_size, args.hash_cache)

    # Thresholds are fractions of the hash length so the sweep scales with
    # --hash-size (e.g. 12% is distance ~8 for size 8 and ~30 for size 16).
    n_bits = args.hash_size * args.hash_size
    thresholds = [int(n_bits * frac) for frac in (0.05, 0.08, 0.12, 0.18)]
    print(f'[loop] hash bits={n_bits}, sweep thresholds: {thresholds}', flush=True)

    e_orig, n_orig = _read_dvl_track(args.dvl_csv)
    if e_orig.size == 0:
        ap.error(f'no rows in DVL CSV {args.dvl_csv}')
    if len(e_orig) != len(hashes):
        print(f'[warn] {len(hashes)} frames vs {len(e_orig)} DVL rows; loops are '
              f'matched by index and any beyond the DVL track are skipped', flush=True)

    # Search once at the loosest threshold; tighter thresholds are subsets
    # with the same ordering, so filtering reproduces a per-threshold search.
    all_loops = _find_candidate_loops(hashes, args.min_sep, max(thresholds))
    results = []
    for t in thresholds:
        loops = [loop for loop in all_loops if loop[2] <= t]
        e_c, n_c = _apply_loop_corrections(e_orig, n_orig, loops)
        # Distance of the final position from (0, 0) in the DVL frame;
        # presumably the track starts at the origin -- confirm.
        end_dist = math.hypot(e_c[-1], n_c[-1])
        results.append((t, loops, e_c, n_c, end_dist))

    import matplotlib
    matplotlib.use('Agg')  # headless backend; must be set before importing pyplot
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    width, height = _bbox(e_orig, n_orig)
    _plot_track(axes[0, 0], e_orig, n_orig, '-b',
                f'ORIGINAL (no LC)\nbbox={width:.1f}×{height:.1f}m')

    positions = [(0, 1), (0, 2), (1, 0), (1, 1)]
    for pos, (t, loops, e_c, n_c, end_dist) in zip(positions, results):
        width, height = _bbox(e_c, n_c)
        _plot_track(axes[pos], e_c, n_c, '-r',
                    f'max_dist={t} ({t/n_bits*100:.0f}% bits)\n{len(loops)} loops '
                    f'bbox={width:.1f}×{height:.1f}m end={end_dist:.1f}m')
        print(f'[t={t}] loops={len(loops)} bbox={width:.1f}x{height:.1f} '
              f'end_dist={end_dist:.1f}', flush=True)

    # Summary panel: loop count and end distance versus threshold.
    ax = axes[1, 2]
    ts = [r[0] for r in results]
    counts = [len(r[1]) for r in results]
    end_dists = [r[4] for r in results]
    ax2 = ax.twinx()
    ax.plot(ts, counts, 'b-o', label='loop count')
    ax.set_ylabel('Loops found', color='b')
    ax2.plot(ts, end_dists, 'r-s', label='end_dist')
    ax2.set_ylabel('end_dist (m)', color='r')
    ax.set_xlabel('max_dist threshold')
    ax.set_title('Threshold sweep summary')
    ax.grid(True, alpha=0.3)

    plt.suptitle(f'Loop closure threshold sweep — {args.dataset_label} '
                 f'(pHash size {args.hash_size}, min_sep {args.min_sep})')
    plt.tight_layout()
    plt.savefig(args.out_plot, dpi=120, bbox_inches='tight')
    plt.close(fig)
    print(f'[plot] {args.out_plot}', flush=True)


if __name__ == '__main__':
    main()