#!/usr/bin/env python3
"""Sweep loop closure params on cached hashes.

For every ``frame_*.jpg`` in ``--frames-dir`` a perceptual hash (pHash) is
computed, or loaded from a pickle cache. A loop closure is any pair of frames
``(i, j)`` with ``j - i >= --min-sep`` whose pHash Hamming distance is within
a threshold. For each threshold in a small sweep, a naive linear drift
correction is applied to the DVL east/north track, and the result is drawn in
a 2x3 grid next to the uncorrected track and a summary panel.

Frame index ``k`` and DVL CSV row ``k`` are treated as the same instant (both
are indexed with the same integer). NOTE(review): confirm the frames and the
DVL log are sampled at the same rate.
"""
import argparse
import csv
import math
import pickle
from pathlib import Path

import numpy as np
from PIL import Image
import imagehash

# Fractions of the hash's bit count used as max Hamming distance in the sweep.
# e.g. 12% is a distance of 7 for hash_size=8 (64 bits) and 30 for
# hash_size=16 (256 bits).
_THRESHOLD_FRACTIONS = (0.05, 0.08, 0.12, 0.18)

# Exceptions pickle.load can raise for a truncated, corrupt or incompatible
# cache file (e.g. one left behind by an interrupted run).
_CACHE_LOAD_ERRORS = (OSError, EOFError, pickle.UnpicklingError,
                      AttributeError, ImportError, ValueError, IndexError)

# Subplot slots for the per-threshold panels; (0, 0) holds the uncorrected
# track and (1, 2) the summary panel.
_THRESHOLD_AXES = ((0, 1), (0, 2), (1, 0), (1, 1))


def _load_cached_hashes(cache_path, frames_count, hash_size, frames_dir):
    """Return the cached hash list if *cache_path* matches this run, else None.

    The cache is reused only when its frame count, hash size and frames
    directory string all equal the current run's. An unreadable cache is
    reported and ignored so the caller recomputes the hashes.

    SECURITY: unpickling can execute arbitrary code, and the default cache
    path is under the shared /tmp directory. Only point --hash-cache at a
    file you trust.
    """
    if not cache_path.exists():
        return None
    try:
        with open(cache_path, 'rb') as fh:
            cached = pickle.load(fh)
    except _CACHE_LOAD_ERRORS as err:
        print(f'[cache] ignoring unreadable {cache_path}: {err}', flush=True)
        return None
    if not isinstance(cached, dict) or 'hashes' not in cached:
        return None
    if (cached.get('frames_count') == frames_count
            and cached.get('hash_size') == hash_size
            and cached.get('frames_dir') == frames_dir):
        return cached['hashes']
    return None


def _load_or_compute_hashes(args, frames):
    """Return one imagehash.phash per frame, using/refreshing the pickle cache."""
    cache_path = Path(args.hash_cache)
    frames_dir = str(args.frames_dir)
    hashes = _load_cached_hashes(cache_path, len(frames), args.hash_size, frames_dir)
    if hashes is not None:
        print(f'[cache] loaded {len(hashes)} hashes from {cache_path}', flush=True)
        return hashes

    print(f'[hash] computing {len(frames)} pHashes (size={args.hash_size})...', flush=True)
    hashes = []
    for i, frame_path in enumerate(frames):
        # The context manager closes the file handle PIL opens for each frame.
        with Image.open(frame_path) as img:
            hashes.append(imagehash.phash(img, hash_size=args.hash_size))
        if i % 200 == 0:
            print(f' {i}/{len(frames)}', flush=True)

    with open(cache_path, 'wb') as fh:
        pickle.dump({'hashes': hashes, 'frames_count': len(frames),
                     'hash_size': args.hash_size, 'frames_dir': frames_dir}, fh)
    print(f'[cache] saved to {args.hash_cache}', flush=True)
    return hashes


def _find_loop_candidates(hashes, min_sep, max_dist):
    """Return every (i, j, d) with j - i >= min_sep and Hamming distance d <= max_dist.

    Gives the same pairs as evaluating ``hashes[i] - hashes[j]`` (ImageHash
    subtraction counts differing bits) for all such pairs, but each row of
    the comparison is a single vectorized numpy call. Pairs come back in
    ascending (i, j) order, which is the order corrections are applied in.
    Assumes min_sep >= 1 (validated in main).
    """
    if len(hashes) <= min_sep:
        return []
    bits = np.stack([np.asarray(h.hash, dtype=bool).ravel() for h in hashes])
    pairs = []
    for i in range(len(bits) - min_sep):
        dists = np.count_nonzero(bits[i + min_sep:] != bits[i], axis=1)
        for off in np.flatnonzero(dists <= max_dist):
            pairs.append((i, i + min_sep + int(off), int(dists[off])))
    return pairs


def _apply_loop_corrections(e_orig, n_orig, loops):
    """Return corrected copies of the east/north tracks after applying *loops*.

    For each loop (i, j, _), points i+1..j are shifted by a linear ramp of
    the offset (track[i] - track[j]) that reaches the full offset at j, so
    point j lands on point i. Every point after j is shifted by the full
    offset. Loops are applied one after another, each to the already
    corrected track. Loops whose j is past the end of the DVL track are
    skipped.
    """
    e_c = e_orig.copy()
    n_c = n_orig.copy()
    for i, j, _ in loops:
        if j >= len(e_c):
            continue
        dx = e_c[i] - e_c[j]
        dy = n_c[i] - n_c[j]
        span = j - i
        ramp = np.arange(1, span + 1) / span  # (k - i) / (j - i) for k = i+1..j
        e_c[i + 1:j + 1] += dx * ramp
        n_c[i + 1:j + 1] += dy * ramp
        e_c[j + 1:] += dx
        n_c[j + 1:] += dy
    return e_c, n_c


def _extent(values):
    """Return max - min of a non-empty array as a float."""
    return float(values.max() - values.min())


def _end_dist(e, n):
    """Distance of the track's final point from (0, 0).

    This equals the start-to-end gap only if the DVL track starts at the
    origin. NOTE(review): confirm that is the intended metric.
    """
    return math.hypot(e[-1], n[-1])


def _draw_track(ax, e, n, fmt, title):
    """Plot a track with start (green circle) and end (red triangle) markers."""
    ax.plot(e, n, fmt, linewidth=1.0)
    ax.plot(e[0], n[0], 'go', markersize=10)
    ax.plot(e[-1], n[-1], 'r^', markersize=10)
    ax.set_title(title)
    ax.set_xlabel('East')
    ax.set_ylabel('North')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)


def main():
    """Parse CLI args, run the loop closure threshold sweep and save the plot."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--frames-dir', required=True)
    ap.add_argument('--dvl-csv', required=True)
    ap.add_argument('--hash-cache', default='/tmp/phash_cache.pkl')
    ap.add_argument('--hash-size', type=int, default=16)  # bigger for finer discrimination
    ap.add_argument('--out-plot', required=True)
    ap.add_argument('--min-sep', type=int, default=60)
    ap.add_argument('--label', default='GX039839',
                    help='dataset name shown in the plot title')
    args = ap.parse_args()
    if args.min_sep < 1:
        ap.error('--min-sep must be >= 1')

    frames = sorted(Path(args.frames_dir).glob('frame_*.jpg'))
    if not frames:
        ap.error(f'no frame_*.jpg files found in {args.frames_dir}')

    hashes = _load_or_compute_hashes(args, frames)

    n_bits = args.hash_size * args.hash_size
    thresholds = [int(n_bits * frac) for frac in _THRESHOLD_FRACTIONS]
    print(f'[loop] hash bits={n_bits}, sweep thresholds: {thresholds}', flush=True)

    with open(args.dvl_csv, newline='') as fh:
        dvl_rows = list(csv.DictReader(fh))
    if not dvl_rows:
        ap.error(f'no data rows in {args.dvl_csv}')
    e_orig = np.array([float(r['east_m']) for r in dvl_rows])
    n_orig = np.array([float(r['north_m']) for r in dvl_rows])
    if len(dvl_rows) != len(hashes):
        print(f'[warn] {len(hashes)} frames vs {len(dvl_rows)} DVL rows; '
              'loops past the end of the DVL track are skipped', flush=True)

    # Search pairs once at the loosest threshold; every tighter threshold's
    # loop set is a subset, filtered below with (i, j) order preserved.
    candidates = _find_loop_candidates(hashes, args.min_sep, max(thresholds))
    sweep = []
    for t in thresholds:
        loops = [pair for pair in candidates if pair[2] <= t]
        e_c, n_c = _apply_loop_corrections(e_orig, n_orig, loops)
        sweep.append((t, loops, e_c, n_c))

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    # original
    bbox = (_extent(e_orig), _extent(n_orig))
    _draw_track(axes[0, 0], e_orig, n_orig, '-b',
                f'ORIGINAL (no LC)\nbbox={bbox[0]:.1f}×{bbox[1]:.1f}m')

    # corrected for each threshold
    for pos, (t, loops, e_c, n_c) in zip(_THRESHOLD_AXES, sweep):
        bbox = (_extent(e_c), _extent(n_c))
        end_dist = _end_dist(e_c, n_c)
        _draw_track(axes[pos], e_c, n_c, '-r',
                    f'max_dist={t} ({t/n_bits*100:.0f}% bits)\n'
                    f'{len(loops)} loops bbox={bbox[0]:.1f}×{bbox[1]:.1f}m end={end_dist:.1f}m')
        print(f'[t={t}] loops={len(loops)} bbox=({bbox[0]:.1f}, {bbox[1]:.1f}) '
              f'end_dist={end_dist:.1f}', flush=True)

    # summary: loop count and end_dist vs threshold (reuses the sweep results)
    ax = axes[1, 2]
    ts = [t for t, _, _, _ in sweep]
    counts = [len(loops) for _, loops, _, _ in sweep]
    ed = [_end_dist(e_c, n_c) for _, _, e_c, n_c in sweep]
    ax2 = ax.twinx()
    ax.plot(ts, counts, 'b-o', label='loop count')
    ax.set_ylabel('Loops found', color='b')
    ax2.plot(ts, ed, 'r-s', label='end_dist')
    ax2.set_ylabel('end_dist (m)', color='r')
    ax.set_xlabel('max_dist threshold')
    ax.set_title('Threshold sweep summary')
    ax.grid(True, alpha=0.3)

    plt.suptitle(f'Loop closure threshold sweep — {args.label} '
                 f'(pHash size {args.hash_size}, min_sep {args.min_sep})')
    plt.tight_layout()
    plt.savefig(args.out_plot, dpi=120, bbox_inches='tight')
    plt.close(fig)
    print(f'[plot] {args.out_plot}', flush=True)


if __name__ == '__main__':
    main()