121 lines
5.7 KiB
Python
121 lines
5.7 KiB
Python
#!/usr/bin/env python3
|
||
"""Loop closure detection via perceptual hashing.
|
||
|
||
For each frame, compute pHash (DCT-based perceptual hash).
|
||
Find pairs (i, j) with j - i >= MIN_SEP (--min-sep) and pHash Hamming distance <= MAX_DIST (--max-dist).
|
||
These are loop-closure candidates — frames that look alike, presumably because the AUV revisited the same physical location.
|
||
|
||
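Then correct the DVL trajectory: for each closure, snap frame j back onto frame i's position, ramping the offset linearly over frames i+1..j and carrying it forward to all later frames.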
|
||
|
||
Usage:
|
||
python3 loop_closure_phash.py --frames-dir <dir> --dvl-csv <csv> \
|
||
--out-corrected /tmp/dvl_loopclosed.csv --plot /tmp/loop_closure.png \
|
||
--min-sep 60 --max-dist 8
|
||
"""
|
||
import argparse, csv, math
|
||
from pathlib import Path
|
||
import numpy as np
|
||
from PIL import Image
|
||
import imagehash
|
||
|
||
def _parse_args(argv=None):
    """Parse and validate the command line (``argv=None`` means ``sys.argv[1:]``)."""
    ap = argparse.ArgumentParser()
    ap.add_argument('--frames-dir', required=True)
    ap.add_argument('--dvl-csv', required=True)
    ap.add_argument('--out-corrected', required=True)
    ap.add_argument('--plot', default=None)
    ap.add_argument('--min-sep', type=int, default=60, help='min frame separation to count as loop')
    ap.add_argument('--max-dist', type=int, default=10, help='max pHash Hamming distance for match')
    ap.add_argument('--hash-size', type=int, default=8)
    ap.add_argument('--title-tag', default='GX039839',
                    help='dataset label appended to the plot title')
    args = ap.parse_args(argv)
    # min_sep == 0 would pair every frame with itself (distance 0); negative values
    # would produce j < i and wrap-around numpy indexing in the correction step.
    if args.min_sep < 1:
        ap.error('--min-sep must be >= 1')
    return args


def _frame_sort_key(path):
    """Order ``frame_<N>.jpg`` by integer N so unpadded names sort numerically."""
    suffix = path.stem[len('frame_'):]
    if suffix.isdecimal():
        return (0, int(suffix), path.name)
    # Non-numeric suffixes go last, in plain name order.
    return (1, 0, path.name)


def _hash_frames(frames, hash_size):
    """Return an (n_frames, hash_size**2) bool matrix of flattened pHash bits."""
    rows = []
    for idx, path in enumerate(frames):
        with Image.open(path) as img:
            rows.append(imagehash.phash(img, hash_size=hash_size).hash.flatten())
        if idx % 200 == 0:
            print(f' hashed {idx}/{len(frames)}', flush=True)
    if not rows:
        return np.zeros((0, hash_size * hash_size), dtype=bool)
    return np.vstack(rows)


def _find_loop_closures(bits, min_sep, max_dist):
    """Return ``(i, j, distance)`` for all j >= i + min_sep with Hamming distance <= max_dist.

    Pairs are ordered by i, then by ascending j.
    """
    loops = []
    n_frames = len(bits)
    for i in range(n_frames):
        start = i + min_sep
        if start < n_frames:
            # Hamming distance from frame i to every candidate j in one vectorized pass.
            # This equals imagehash's ``h_i - h_j``, the count of differing flattened bits.
            dists = np.count_nonzero(bits[start:] != bits[i], axis=1)
            for off in np.flatnonzero(dists <= max_dist):
                loops.append((i, start + int(off), int(dists[off])))
        if i % 200 == 0:
            print(f' search at {i}, loops found so far: {len(loops)}', flush=True)
    return loops


def _apply_loop_corrections(e, n, loops):
    """Return corrected copies of the (east, north) arrays.

    Loops are applied in order. For each ``(i, j, _)`` the current offset from j to i
    is ramped linearly over frames i+1..j (so j lands exactly on i) and carried
    unchanged to every frame after j. Every j must be < len(e).
    """
    e_corr = e.copy()
    n_corr = n.copy()
    for i, j, _ in loops:
        dx = e_corr[i] - e_corr[j]
        dy = n_corr[i] - n_corr[j]
        nsteps = j - i
        ramp = np.arange(1, nsteps + 1) / nsteps  # (k - i) / nsteps for k in i+1..j
        e_corr[i + 1:j + 1] += dx * ramp
        n_corr[i + 1:j + 1] += dy * ramp
        e_corr[j + 1:] += dx
        n_corr[j + 1:] += dy
    return e_corr, n_corr


def _write_corrected_csv(path, dvl_rows, e, n, e_corr, n_corr):
    """Write the original and corrected positions for each DVL row to *path*."""
    with open(path, 'w', newline='') as ff:
        w = csv.writer(ff)
        w.writerow(['frame_idx', 'ts_s', 'east_m_orig', 'north_m_orig', 'east_m_corr', 'north_m_corr'])
        for k, r in enumerate(dvl_rows):
            w.writerow([r['frame_idx'], r['ts_s'], e[k], n[k], e_corr[k], n_corr[k]])


def _plot_results(args, e, n, e_corr, n_corr, loops, usable):
    """Save a 2x2 diagnostic figure to ``args.plot``.

    Panels: original track, corrected track, loop pairs, and the distance histogram.
    """
    if e.size == 0:
        print('[plot] skipped: DVL trajectory is empty', flush=True)
        return
    import matplotlib
    matplotlib.use('Agg')  # headless backend: render straight to file
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    ax_orig, ax_corr, ax_pairs, ax_dist = axes[0, 0], axes[0, 1], axes[1, 0], axes[1, 1]

    ax_orig.plot(e, n, '-b', linewidth=1.0)
    ax_orig.plot(e[0], n[0], 'go', markersize=10)
    ax_orig.plot(e[-1], n[-1], 'r^', markersize=10)
    ax_orig.set_title(f'DVL trajectory ORIGINAL (drift visible)\nbbox={e.max()-e.min():.1f}×{n.max()-n.min():.1f}m')
    ax_orig.set_xlabel('East (m)')
    ax_orig.set_ylabel('North (m)')
    ax_orig.set_aspect('equal')
    ax_orig.grid(True, alpha=0.3)

    ax_corr.plot(e_corr, n_corr, '-r', linewidth=1.0)
    ax_corr.plot(e_corr[0], n_corr[0], 'go', markersize=10)
    ax_corr.plot(e_corr[-1], n_corr[-1], 'r^', markersize=10)
    ax_corr.set_title(f'DVL trajectory + LOOP CLOSURE\nbbox={e_corr.max()-e_corr.min():.1f}×{n_corr.max()-n_corr.min():.1f}m\nLoops applied: {len(usable)}')
    ax_corr.set_xlabel('East (m)')
    ax_corr.set_ylabel('North (m)')
    ax_corr.set_aspect('equal')
    ax_corr.grid(True, alpha=0.3)

    # Loop pairs drawn on the original track. Only pairs inside the DVL range are
    # drawn, because indexing e[j] past the trajectory end would raise IndexError.
    ax_pairs.plot(e, n, '-', color='gray', linewidth=0.5, alpha=0.4)
    for i, j, _ in usable[:200]:  # show first 200 pairs
        ax_pairs.plot([e[i], e[j]], [n[i], n[j]], '-', color='orange', linewidth=0.4, alpha=0.3)
    ax_pairs.set_title(f'Loop closure pairs (first 200, of {len(loops)})')
    ax_pairs.set_xlabel('East')
    ax_pairs.set_ylabel('North')
    ax_pairs.set_aspect('equal')
    ax_pairs.grid(True, alpha=0.3)

    # Histogram of Hamming distances over all detected loops.
    dists = [d for _, _, d in loops]
    if dists:
        ax_dist.hist(dists, bins=range(0, max(dists) + 2))
    ax_dist.set_xlabel('Hash Hamming distance')
    ax_dist.set_ylabel('Count')
    ax_dist.set_title('Loop closure hash distance distribution')
    ax_dist.grid(True, alpha=0.3)

    plt.suptitle(f'Loop closure detection (pHash {args.hash_size}, min_sep={args.min_sep}, max_dist={args.max_dist}) — {args.title_tag}')
    plt.tight_layout()
    plt.savefig(args.plot, dpi=130, bbox_inches='tight')
    plt.close(fig)
    print(f'[plot] {args.plot}', flush=True)


def main(argv=None):
    """Detect pHash loop closures in a frame directory and loop-correct a DVL track.

    Steps:
      1. Hash every ``frame_*.jpg`` in --frames-dir.
      2. Find similar frame pairs.
      3. Load east_m/north_m from --dvl-csv.
      4. Apply the loop corrections and write the original and corrected
         positions to --out-corrected.
      5. Optionally render a diagnostic plot.

    Args:
        argv: argument list to parse; ``None`` (the default) uses ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    frames = sorted(Path(args.frames_dir).glob('frame_*.jpg'), key=_frame_sort_key)
    print(f'[loop] hashing {len(frames)} frames (pHash size {args.hash_size})...', flush=True)
    bits = _hash_frames(frames, args.hash_size)

    print(f'[loop] searching loop closures (min_sep={args.min_sep}, max_dist={args.max_dist})...', flush=True)
    loops = _find_loop_closures(bits, args.min_sep, args.max_dist)
    print(f'[loop] found {len(loops)} loop closures', flush=True)

    # Load the DVL trajectory.
    with open(args.dvl_csv, newline='') as fh:
        dvl_rows = list(csv.DictReader(fh))
    e = np.array([float(r['east_m']) for r in dvl_rows])
    n = np.array([float(r['north_m']) for r in dvl_rows])

    # NOTE(review): assumes DVL row k corresponds to the k-th sorted frame; the CSV's
    # frame_idx column is not used for alignment — confirm against the DVL export.
    usable = [loop for loop in loops if loop[1] < len(e)]
    if len(usable) < len(loops):
        print(f'[loop] warning: ignoring {len(loops) - len(usable)} loop closures '
              f'beyond the {len(e)} DVL rows', flush=True)

    e_corr, n_corr = _apply_loop_corrections(e, n, usable)
    print(f'[loop] applied {len(usable)} corrections to trajectory', flush=True)

    _write_corrected_csv(args.out_corrected, dvl_rows, e, n, e_corr, n_corr)
    print(f'[out] {args.out_corrected}', flush=True)

    if args.plot:
        _plot_results(args, e, n, e_corr, n_corr, loops, usable)
|
||
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|