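#!/usr/bin/env python3
"""Sweep loop-closure parameters on cached perceptual hashes.

Every frame_*.jpg in a directory is reduced to a pHash (cached on disk).
Frame pairs at least --min-sep frames apart whose hashes differ by at most a
Hamming-distance threshold are treated as loop closures; each closure snaps
the later DVL position onto the earlier one, ramping the offset linearly in
between. The sweep repeats this for several thresholds and saves a plot
comparing the corrected tracks.
"""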
import argparse
import csv
import math
import pickle
import sys
from pathlib import Path

import imagehash
import numpy as np
from PIL import Image


# Fractions of the total hash bit count tried as Hamming-distance thresholds.
# For hash_size=8 (64 bits) a distance of 8 is ~12%; for hash_size=16
# (256 bits) the comparable distance is ~32, hence scaling by bit count.
_THRESHOLD_FRACTIONS = (0.05, 0.08, 0.12, 0.18)

# Subplot cells for the per-threshold tracks; (0, 0) shows the uncorrected
# track and (1, 2) the summary, so at most four thresholds get a panel.
_PANEL_POSITIONS = ((0, 1), (0, 2), (1, 0), (1, 1))


def _parse_args():
    """Build the command-line parser and return validated arguments."""
    ap = argparse.ArgumentParser(
        description='Sweep pHash loop-closure thresholds over a DVL track.')
    ap.add_argument('--frames-dir', required=True)
    ap.add_argument('--dvl-csv', required=True)
    # The cache is a pickle file; only point this at files you created.
    ap.add_argument('--hash-cache', default='/tmp/phash_cache.pkl')
    ap.add_argument('--hash-size', type=int, default=16)  # bigger for finer discrimination
    ap.add_argument('--out-plot', required=True)
    ap.add_argument('--min-sep', type=int, default=60)
    ap.add_argument('--dataset-name', default='GX039839',
                    help='label shown in the plot title')
    args = ap.parse_args()
    if args.hash_size < 2:
        # imagehash.phash rejects hash sizes below 2.
        ap.error('--hash-size must be >= 2')
    if args.min_sep < 1:
        # Zero would pair a frame with itself; negative values would pair
        # frames backwards and invert the correction ramp.
        ap.error('--min-sep must be >= 1')
    return args


def _load_or_compute_hashes(frames, cache_path, hash_size, frames_dir):
    """Return one imagehash pHash per frame, reusing ``cache_path`` if valid.

    The cache is reused only when it was written for the same directory,
    hash size and sorted list of frame file names; otherwise every frame is
    re-hashed and the cache is rewritten.

    Args:
        frames: sorted list of frame image paths.
        cache_path: ``Path`` of the pickle cache file.
        hash_size: pHash side length (the hash has ``hash_size**2`` bits).
        frames_dir: directory the frames came from (stored as a cache key).
    """
    meta = {
        'frames_count': len(frames),
        'hash_size': hash_size,
        'frames_dir': str(frames_dir),
        # Names catch a directory whose frames changed but whose count did not.
        'frame_names': [f.name for f in frames],
    }

    if cache_path.exists():
        # SECURITY NOTE(review): pickle.load can execute arbitrary code from
        # a crafted file, and the default cache path is in world-writable
        # /tmp. Only load caches this script wrote for you.
        try:
            with open(cache_path, 'rb') as fh:
                cached = pickle.load(fh)
        except (OSError, EOFError, pickle.UnpicklingError) as err:
            print(f'[cache] ignoring unreadable {cache_path}: {err}', file=sys.stderr)
            cached = None
        if (isinstance(cached, dict)
                and all(cached.get(key) == value for key, value in meta.items())
                and len(cached.get('hashes', ())) == len(frames)):
            hashes = cached['hashes']
            print(f'[cache] loaded {len(hashes)} hashes from {cache_path}', flush=True)
            return hashes

    print(f'[hash] computing {len(frames)} pHashes (size={hash_size})...', flush=True)
    hashes = []
    for i, frame in enumerate(frames):
        # PIL keeps the file open after Image.open; the context manager
        # closes it once the hash is computed.
        with Image.open(frame) as img:
            hashes.append(imagehash.phash(img, hash_size=hash_size))
        if i % 200 == 0:
            print(f' {i}/{len(frames)}', flush=True)

    # Write to a sibling temp file and rename, so an interrupted run cannot
    # leave a truncated cache that a later run would try to load.
    tmp_path = cache_path.with_name(cache_path.name + '.tmp')
    with open(tmp_path, 'wb') as fh:
        pickle.dump({'hashes': hashes, **meta}, fh)
    tmp_path.replace(cache_path)
    print(f'[cache] saved to {cache_path}', flush=True)
    return hashes


def _load_dvl_track(csv_path):
    """Read the ``east_m`` and ``north_m`` columns of the DVL CSV as float arrays."""
    with open(csv_path, newline='') as fh:
        rows = list(csv.DictReader(fh))
    east = np.array([float(r['east_m']) for r in rows], dtype=float)
    north = np.array([float(r['north_m']) for r in rows], dtype=float)
    return east, north


def _hash_bits(hashes):
    """Stack every hash's bit array into an ``(n_frames, n_bits)`` bool matrix."""
    return np.array([np.asarray(h.hash, dtype=bool).ravel() for h in hashes])


def _find_loop_candidates(bits, min_sep, max_dist):
    """Return ``(i, j, d)`` for every frame pair with ``j - i >= min_sep`` and
    Hamming distance ``d <= max_dist``, ordered by ``i`` then ``j``.

    This is a row-vectorised form of ``hashes[i] - hashes[j]`` (imagehash
    counts differing bits), evaluated one anchor frame ``i`` at a time.
    """
    loops = []
    for i in range(len(bits) - min_sep):
        first_j = i + min_sep
        dists = np.count_nonzero(bits[first_j:] != bits[i], axis=1)
        for offset in np.flatnonzero(dists <= max_dist):
            loops.append((i, first_j + int(offset), int(dists[offset])))
    return loops


def _apply_loop_corrections(e_orig, n_orig, loops):
    """Return corrected copies of the east/north track.

    Closures are applied one after another, so each one sees the effect of
    the earlier ones. For a closure ``(i, j, _)``, the offset that moves
    position ``j`` onto position ``i`` is added with a linear ramp to
    positions ``i+1 .. j`` (reaching the full offset at ``j``), and in full
    to every position after ``j``. Closures whose ``j`` has no DVL row are
    skipped.
    """
    e_c = e_orig.copy()
    n_c = n_orig.copy()
    n_pts = len(e_c)
    for i, j, _ in loops:
        if j >= n_pts:
            continue
        dx = e_c[i] - e_c[j]
        dy = n_c[i] - n_c[j]
        span = j - i
        ramp = np.arange(1, span + 1) / span  # (k - i) / span for k in i+1..j
        e_c[i + 1:j + 1] += dx * ramp
        n_c[i + 1:j + 1] += dy * ramp
        e_c[j + 1:] += dx
        n_c[j + 1:] += dy
    return e_c, n_c


def _extent(east, north):
    """Return the (east, north) size of the track's bounding box."""
    return float(east.max() - east.min()), float(north.max() - north.min())


def _closure_error(east, north):
    """Distance from the last track position back to the first one.

    Measured from the first sample rather than from (0, 0). The two are the
    same when the track starts at the origin, but this form still makes
    sense when it does not.
    """
    return math.hypot(east[-1] - east[0], north[-1] - north[0])


def _plot_sweep(e_orig, n_orig, results, n_bits, args):
    """Draw the original track, one corrected track per threshold and a
    loop-count / end-distance summary, then save the figure to ``args.out_plot``.

    ``results`` is a list of ``(threshold, loops, e_c, n_c)`` tuples.
    """
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend, selected before pyplot
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))

    def draw_track(ax, east, north, style, title):
        # Green circle marks the start, red triangle the end.
        ax.plot(east, north, style, linewidth=1.0)
        ax.plot(east[0], north[0], 'go', markersize=10)
        ax.plot(east[-1], north[-1], 'r^', markersize=10)
        ax.set_title(title)
        ax.set_xlabel('East')
        ax.set_ylabel('North')
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)

    bw, bh = _extent(e_orig, n_orig)
    draw_track(axes[0, 0], e_orig, n_orig, '-b',
               f'ORIGINAL (no LC)\nbbox={bw:.1f}×{bh:.1f}m')

    for pos, (t, loops, e_c, n_c) in zip(_PANEL_POSITIONS, results):
        bw, bh = _extent(e_c, n_c)
        end_dist = _closure_error(e_c, n_c)
        draw_track(axes[pos], e_c, n_c, '-r',
                   f'max_dist={t} ({t/n_bits*100:.0f}% bits)\n'
                   f'{len(loops)} loops bbox={bw:.1f}×{bh:.1f}m end={end_dist:.1f}m')

    # Summary: loop count (left axis) and end distance (right axis) per threshold.
    ax = axes[1, 2]
    ts = [t for t, _, _, _ in results]
    counts = [len(loops) for _, loops, _, _ in results]
    ends = [_closure_error(e_c, n_c) for _, _, e_c, n_c in results]
    ax2 = ax.twinx()
    ax.plot(ts, counts, 'b-o', label='loop count')
    ax.set_ylabel('Loops found', color='b')
    ax2.plot(ts, ends, 'r-s', label='end_dist')
    ax2.set_ylabel('end_dist (m)', color='r')
    ax.set_xlabel('max_dist threshold')
    ax.set_title('Threshold sweep summary')
    ax.grid(True, alpha=0.3)

    fig.suptitle(f'Loop closure threshold sweep — {args.dataset_name} '
                 f'(pHash size {args.hash_size}, min_sep {args.min_sep})')
    fig.tight_layout()
    fig.savefig(args.out_plot, dpi=120, bbox_inches='tight')
    plt.close(fig)


def main():
    """Sweep pHash loop-closure thresholds and save a 2x3 comparison figure."""
    args = _parse_args()

    frames = sorted(Path(args.frames_dir).glob('frame_*.jpg'))
    if not frames:
        sys.exit(f'no frame_*.jpg files found in {args.frames_dir}')
    hashes = _load_or_compute_hashes(frames, Path(args.hash_cache),
                                     args.hash_size, args.frames_dir)

    n_bits = args.hash_size * args.hash_size
    thresholds = [int(n_bits * frac) for frac in _THRESHOLD_FRACTIONS]
    print(f'[loop] hash bits={n_bits}, sweep thresholds: {thresholds}', flush=True)

    e_orig, n_orig = _load_dvl_track(args.dvl_csv)
    if e_orig.size == 0:
        sys.exit(f'no rows in {args.dvl_csv}')
    if len(hashes) != len(e_orig):
        # Frame index k is used directly as DVL row index k below.
        print(f'[warn] {len(hashes)} frames but {len(e_orig)} DVL rows; frame k is '
              f'treated as DVL row k and loops past the last row are skipped',
              file=sys.stderr)

    # Distances are computed once at the loosest threshold. Each stricter
    # threshold keeps a subset of those candidates, in the same order.
    candidates = _find_loop_candidates(_hash_bits(hashes), args.min_sep, max(thresholds))

    results = []
    for t in thresholds:
        loops = [c for c in candidates if c[2] <= t]
        e_c, n_c = _apply_loop_corrections(e_orig, n_orig, loops)
        results.append((t, loops, e_c, n_c))
        bw, bh = _extent(e_c, n_c)
        print(f'[t={t}] loops={len(loops)} bbox=({bw:.1f}, {bh:.1f}) '
              f'end_dist={_closure_error(e_c, n_c):.1f}', flush=True)

    _plot_sweep(e_orig, n_orig, results, n_bits, args)
    print(f'[plot] {args.out_plot}', flush=True)


if __name__ == '__main__':
    main()