Files
cosma-qc/scripts/loop_closure_phash.py

121 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Loop closure detection via perceptual hashing.
For each frame, compute pHash (DCT-based perceptual hash).
Find pairs (i, j) with |i-j| >= MIN_SEPARATION and hash distance <= THRESHOLD.
These are loop closures — the AUV revisited the same physical location.
Then correct DVL trajectory by snapping back at loop closures.
Usage:
python3 loop_closure_phash.py --frames-dir <dir> --dvl-csv <csv> \
--out-corrected /tmp/dvl_loopclosed.csv --plot /tmp/loop_closure.png \
--min-sep 60 --max-dist 8
"""
import argparse, csv, math
from pathlib import Path
import numpy as np
from PIL import Image
import imagehash
def _hash_frames(frames, hash_size):
    """Return the perceptual hash (pHash) of every image in *frames*, in order.

    Each image is opened in a ``with`` block so the underlying file is
    closed as soon as its hash has been computed.
    """
    hashes = []
    total = len(frames)
    for idx, path in enumerate(frames):
        with Image.open(path) as img:
            hashes.append(imagehash.phash(img, hash_size=hash_size))
        if idx % 200 == 0:
            print(f' hashed {idx}/{total}', flush=True)
    return hashes


def _find_loop_closures(hashes, min_sep, max_dist):
    """Return ``(i, j, distance)`` for every pair with ``j >= i + min_sep``
    whose hash difference is ``<= max_dist``.

    The search compares every candidate pair, so it is quadratic in the
    number of hashes.
    """
    loops = []
    n_hashes = len(hashes)
    for i in range(n_hashes):
        for j in range(i + min_sep, n_hashes):
            dist = hashes[i] - hashes[j]
            if dist <= max_dist:
                loops.append((i, j, dist))
        if i % 200 == 0:
            print(f' search at {i}, loops found so far: {len(loops)}', flush=True)
    return loops


def _load_dvl(csv_path):
    """Read the DVL CSV and return ``(rows, east, north)``.

    ``rows`` is the list of dicts produced by ``csv.DictReader``; ``east`` and
    ``north`` are float arrays built from the ``east_m`` / ``north_m`` columns.
    """
    with open(csv_path, newline='') as fh:
        rows = list(csv.DictReader(fh))
    east = np.array([float(row['east_m']) for row in rows])
    north = np.array([float(row['north_m']) for row in rows])
    return rows, east, north


def _apply_loop_corrections(east, north, loops):
    """Return ``(east_corr, north_corr, n_applied)`` after applying *loops*.

    Frame indices are used directly as trajectory row indices.  For each loop
    ``(i, j, _)`` the offset that moves point ``j`` onto point ``i`` is spread
    as a linear ramp over rows ``i+1 .. j`` and then added in full to every
    row after ``j``.  Loops are applied one after another, so each offset is
    measured on the trajectory as already corrected by the previous loops.
    Loops whose ``j`` lies beyond the trajectory are skipped.
    """
    east_corr = east.copy()
    north_corr = north.copy()
    n_rows = len(east_corr)
    n_applied = 0
    for i, j, _dist in loops:
        if j >= n_rows:
            continue
        dx = east_corr[i] - east_corr[j]
        dy = north_corr[i] - north_corr[j]
        nsteps = j - i
        # ramp[m] == (m + 1) / nsteps, i.e. the fraction applied at row i+1+m.
        ramp = np.arange(1, nsteps + 1) / nsteps
        east_corr[i + 1:j + 1] += dx * ramp
        north_corr[i + 1:j + 1] += dy * ramp
        # Carry the full offset forward to every row after j.
        east_corr[j + 1:] += dx
        north_corr[j + 1:] += dy
        n_applied += 1
    return east_corr, north_corr, n_applied


def _write_corrected(out_path, rows, east, north, east_corr, north_corr):
    """Write original and corrected coordinates side by side to *out_path*."""
    with open(out_path, 'w', newline='') as fh:
        writer = csv.writer(fh)
        writer.writerow(['frame_idx', 'ts_s', 'east_m_orig', 'north_m_orig', 'east_m_corr', 'north_m_corr'])
        for k, row in enumerate(rows):
            writer.writerow([row['frame_idx'], row['ts_s'], east[k], north[k], east_corr[k], north_corr[k]])


def _plot_results(plot_path, east, north, east_corr, north_corr, loops,
                  n_corrections, hash_size, min_sep, max_dist):
    """Save a 2x2 summary figure to *plot_path*.

    Panels: original trajectory, corrected trajectory, loop pairs drawn on
    the original trajectory, and a histogram of loop hash distances.
    matplotlib is imported here, with the 'Agg' backend selected before
    pyplot is loaded, so it is only required when a plot is requested.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    _fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    ax_orig, ax_corr, ax_pairs, ax_dist = axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]

    ax_orig.plot(east, north, '-b', linewidth=1.0)
    ax_orig.plot(east[0], north[0], 'go', markersize=10)
    ax_orig.plot(east[-1], north[-1], 'r^', markersize=10)
    ax_orig.set_title(f'DVL trajectory ORIGINAL (drift visible)\nbbox={max(east)-min(east):.1f}×{max(north)-min(north):.1f}m')
    ax_orig.set_xlabel('East (m)')
    ax_orig.set_ylabel('North (m)')
    ax_orig.set_aspect('equal')
    ax_orig.grid(True, alpha=0.3)

    ax_corr.plot(east_corr, north_corr, '-r', linewidth=1.0)
    ax_corr.plot(east_corr[0], north_corr[0], 'go', markersize=10)
    ax_corr.plot(east_corr[-1], north_corr[-1], 'r^', markersize=10)
    ax_corr.set_title(f'DVL trajectory + LOOP CLOSURE\nbbox={max(east_corr)-min(east_corr):.1f}×{max(north_corr)-min(north_corr):.1f}m\nLoops applied: {n_corrections}')
    ax_corr.set_xlabel('East (m)')
    ax_corr.set_ylabel('North (m)')
    ax_corr.set_aspect('equal')
    ax_corr.grid(True, alpha=0.3)

    # Plot loop pairs as lines on the original trajectory.  Pairs whose j is
    # past the end of the trajectory cannot be drawn (there may be more
    # frames than DVL rows), so they are filtered out before taking 200.
    ax_pairs.plot(east, north, '-', color='gray', linewidth=0.5, alpha=0.4)
    n_rows = len(east)
    drawable = [(i, j) for i, j, _dist in loops if j < n_rows]
    for i, j in drawable[:200]:
        ax_pairs.plot([east[i], east[j]], [north[i], north[j]], '-', color='orange', linewidth=0.4, alpha=0.3)
    ax_pairs.set_title(f'Loop closure pairs (first 200, of {len(loops)})')
    ax_pairs.set_xlabel('East')
    ax_pairs.set_ylabel('North')
    ax_pairs.set_aspect('equal')
    ax_pairs.grid(True, alpha=0.3)

    # Histogram of loop hash distances, one bin per integer distance.
    dists = [dist for _i, _j, dist in loops]
    if dists:
        ax_dist.hist(dists, bins=range(0, max(dists) + 2))
    ax_dist.set_xlabel('Hash Hamming distance')
    ax_dist.set_ylabel('Count')
    ax_dist.set_title('Loop closure hash distance distribution')
    ax_dist.grid(True, alpha=0.3)

    plt.suptitle(f'Loop closure detection (pHash {hash_size}, min_sep={min_sep}, max_dist={max_dist}) — GX039839')
    plt.tight_layout()
    plt.savefig(plot_path, dpi=130, bbox_inches='tight')


def main():
    """Command-line entry point: detect loop closures and correct a DVL track.

    Steps:
      1. Hash every ``frame_*.jpg`` in ``--frames-dir`` (sorted by path).
      2. Collect frame pairs at least ``--min-sep`` apart whose hash
         distance is at most ``--max-dist``.
      3. Load the DVL CSV (``east_m`` / ``north_m`` columns) and apply one
         ramped correction per loop closure.
      4. Write original and corrected coordinates to ``--out-corrected``.
      5. Optionally save a summary figure to ``--plot``.

    Exits through ``argparse`` with a usage error when ``--min-sep`` is
    below 1: a separation of 0 would pair every frame with itself, and a
    negative one would make the search index frames from the end of the list.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--frames-dir', required=True)
    ap.add_argument('--dvl-csv', required=True)
    ap.add_argument('--out-corrected', required=True)
    ap.add_argument('--plot', default=None)
    ap.add_argument('--min-sep', type=int, default=60, help='min frame separation to count as loop')
    ap.add_argument('--max-dist', type=int, default=10, help='max pHash Hamming distance for match')
    ap.add_argument('--hash-size', type=int, default=8)
    args = ap.parse_args()
    if args.min_sep < 1:
        ap.error('--min-sep must be >= 1')

    frames = sorted(Path(args.frames_dir).glob('frame_*.jpg'))
    print(f'[loop] hashing {len(frames)} frames (pHash size {args.hash_size})...', flush=True)
    hashes = _hash_frames(frames, args.hash_size)

    print(f'[loop] searching loop closures (min_sep={args.min_sep}, max_dist={args.max_dist})...', flush=True)
    loops = _find_loop_closures(hashes, args.min_sep, args.max_dist)
    print(f'[loop] found {len(loops)} loop closures', flush=True)

    # Load the DVL trajectory and pull each loop's end point back onto its start.
    dvl_rows, e, n = _load_dvl(args.dvl_csv)
    e_corr, n_corr, n_corrections = _apply_loop_corrections(e, n, loops)
    print(f'[loop] applied {n_corrections} corrections to trajectory', flush=True)

    _write_corrected(args.out_corrected, dvl_rows, e, n, e_corr, n_corr)
    print(f'[out] {args.out_corrected}', flush=True)

    if args.plot:
        if len(e) == 0:
            # Start/end markers and bounding boxes need at least one point.
            print('[plot] skipped: DVL trajectory is empty', flush=True)
        else:
            _plot_results(args.plot, e, n, e_corr, n_corr, loops,
                          n_corrections, args.hash_size, args.min_sep, args.max_dist)
            print(f'[plot] {args.plot}', flush=True)
# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()