Files
cosma-qc/scripts/loop_closure_lightglue.py

253 lines
10 KiB
Python
Executable File

#!/usr/bin/env python3
"""Loop closure detection via LightGlue (SuperPoint + LightGlue matcher).
Pipeline:
1. Read DVL trajectory CSV (raw east_m,north_m per frame).
2. Build candidate pairs (i, j) with |i-j| > min_sep.
If there are more than max_pairs candidates, draw a stratified sample instead.
3. Send pairs + frames to GPU host (.87) via SSH; LightGlue runs there.
4. Filter pairs with n_high > match_threshold = loop closures.
5. Apply linear-ramp correction (same algo as pHash variant): for each LC,
pull frame j back to frame i, distribute drift across [i+1..j] linearly
and carry offset forward for k > j.
Usage:
python3 loop_closure_lightglue.py \
--frames-dir /tmp/frames_GX019818/ \
--dvl-csv /tmp/dvl_full_GX019818.csv \
--out-corrected /tmp/dvl_lightglue_GX019818.csv \
--plot /tmp/loop_closure_lightglue.png \
--min-sep 60 --match-threshold 50 --max-pairs 30000 \
--gpu-host 192.168.0.87 --gpu-user floppyrj45 \
--gpu-frames-dir /home/floppyrj45/lightglue-test/frames_GX019818 \
--gpu-venv /home/floppyrj45/lightglue-test/venv \
--gpu-worker /home/floppyrj45/lightglue-test/lightglue_pairs_worker.py
"""
import argparse
import csv
import math
import os
import random
import subprocess
import sys
import tempfile
from pathlib import Path
import numpy as np
def stratified_pairs(n_frames, min_sep, max_pairs, seed=42):
    """Return a sorted list of candidate frame pairs ``(i, j)`` with ``j - i > min_sep``.

    When the number of eligible pairs does not exceed ``max_pairs`` every
    pair is returned. Otherwise pairs are sampled at random, stratified by
    separation: the separations ``min_sep + 1 .. n_frames`` are cut into
    buckets whose lower bounds grow by roughly 1.7x, and each bucket gets an
    equal share (``max_pairs // n_buckets``) of the budget.

    Args:
        n_frames: Number of frames; valid indices are ``0 .. n_frames - 1``.
        min_sep: Minimum index separation; only pairs with ``|i - j| > min_sep``
            are produced. Expected to be non-negative.
        max_pairs: Upper bound on the number of sampled pairs. The sampled
            result may be smaller, because the per-bucket share is rounded
            down and each bucket gives up after ``20 * share`` attempts.
        seed: Seed of the private random generator; the same arguments
            always yield the same pairs.

    Returns:
        A list of unique ``(i, j)`` tuples with ``i < j``, sorted ascending.
    """
    # Count eligible pairs without enumerating them: row ``i`` contributes
    # len(range(i + min_sep + 1, n_frames)) pairs, and len() of a range is O(1).
    full_count = sum(
        len(range(i + min_sep + 1, n_frames)) for i in range(n_frames)
    )
    if full_count <= max_pairs:
        # Small problem: return every pair. The nested iteration already
        # yields unique tuples in ascending order, so no set/sort is needed.
        return [
            (i, j)
            for i in range(n_frames)
            for j in range(i + min_sep + 1, n_frames)
        ]

    # Bucket boundaries over the separation, growing geometrically (~1.7x).
    deltas = []
    d = min_sep + 1
    while d < n_frames:
        deltas.append(d)
        d = int(d * 1.7) + 1
    deltas.append(n_frames)
    buckets = list(zip(deltas[:-1], deltas[1:]))
    if not buckets:
        buckets = [(min_sep + 1, n_frames)]

    rng = random.Random(seed)
    pairs = set()
    per_bucket = max_pairs // len(buckets)
    for lo, hi in buckets:
        attempts = 0
        added = 0
        # Rejection sampling with a bounded number of attempts so that a
        # bucket with few valid pairs cannot stall the loop.
        while added < per_bucket and attempts < per_bucket * 20:
            attempts += 1
            i = rng.randrange(n_frames)
            delta = rng.randint(lo, max(lo + 1, hi - 1))
            j = i + delta
            if j >= n_frames:
                # Forward partner falls off the end: mirror it backwards.
                j = i - delta
            if 0 <= j < n_frames and abs(i - j) > min_sep:
                pair = (min(i, j), max(i, j))
                if pair not in pairs:
                    pairs.add(pair)
                    added += 1
    return sorted(pairs)
def _frame_name_to_dvl_idx(name):
    """Return the 0-based DVL row index encoded in a frame file name.

    Frame files are numbered from 1 (``frame_0001.jpg``) while DVL rows are
    numbered from 0, so ``frame_0001.jpg`` maps to row 0.
    """
    stem = Path(name).stem  # e.g. "frame_0001"
    return int(stem.split('_')[1]) - 1


def _run_remote_worker(args, pairs_pos, local_matches):
    """Upload the candidate pairs, run the LightGlue worker over SSH, fetch its output.

    Args:
        args: Parsed command-line namespace; the ``gpu_*`` and ``remote_*``
            options are read.
        pairs_pos: Iterable of ``(i, j)`` position pairs to score.
        local_matches: Local path the worker's matches CSV is copied to.

    Raises:
        subprocess.CalledProcessError: If one of the ``scp`` transfers fails.
        SystemExit: With the worker's return code if the remote worker fails.
    """
    import shlex  # local import: only this helper assembles a shell command

    remote = f'{args.gpu_user}@{args.gpu_host}'
    ssh_opts = ['-o', 'StrictHostKeyChecking=no']

    # delete=False so the file survives the ``with`` block and scp can read
    # it; it is removed explicitly once the upload has been attempted.
    with tempfile.NamedTemporaryFile('w', delete=False, suffix='.txt') as f:
        pairs_local_path = f.name
        f.writelines(f'{i},{j}\n' for i, j in pairs_pos)
    try:
        print(f'[lc] wrote pairs file {pairs_local_path}', flush=True)
        subprocess.run(['scp', *ssh_opts, pairs_local_path,
                        f'{remote}:{args.remote_pairs_path}'], check=True)
    finally:
        os.unlink(pairs_local_path)
    print(f'[lc] uploaded pairs to {args.gpu_host}:{args.remote_pairs_path}', flush=True)

    # ssh hands the command string to the remote shell, so every value is
    # quoted; paths made of ordinary characters are passed through unchanged.
    activate = shlex.quote(f'{args.gpu_venv}/bin/activate')
    worker_argv = [
        'python3', args.gpu_worker,
        '--frames-dir', args.gpu_frames_dir,
        '--pairs-file', args.remote_pairs_path,
        '--out-file', args.remote_out_path,
        '--score-thr', '0.5',
    ]
    remote_cmd = f'source {activate} && ' + ' '.join(shlex.quote(a) for a in worker_argv)
    print(f'[lc] invoking worker remotely ...', flush=True)
    r = subprocess.run(['ssh', *ssh_opts, remote, remote_cmd])
    if r.returncode != 0:
        print(f'[lc] remote worker failed rc={r.returncode}', file=sys.stderr)
        sys.exit(r.returncode)

    subprocess.run(['scp', *ssh_opts, f'{remote}:{args.remote_out_path}', local_matches],
                   check=True)
    print(f'[lc] pulled matches to {local_matches}', flush=True)


def _load_loop_closures(matches_path, pos_to_dvl, match_threshold, min_sep):
    """Read the worker's matches CSV and return the accepted loop closures.

    Rows are read as ``pos_i, pos_j, n_total, n_high`` after a header line
    (column meaning inferred from the variable names of the consumer; the
    worker itself is not part of this file). A row is kept when
    ``n_high > match_threshold`` and the two mapped DVL indices are more than
    ``min_sep`` apart.

    Returns:
        A list of ``(dvl_i, dvl_j, n_high)`` tuples with ``dvl_i < dvl_j``.
    """
    loops = []
    with open(matches_path, newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header; tolerate an empty file
        for parts in reader:
            if len(parts) < 4:
                continue
            pi, pj, n_high = int(parts[0]), int(parts[1]), int(parts[3])
            if n_high <= match_threshold:
                continue
            di, dj = sorted((pos_to_dvl[pi], pos_to_dvl[pj]))
            if dj - di > min_sep:
                loops.append((di, dj, n_high))
    return loops


def _apply_linear_ramp(e_full, n_full, loops):
    """Apply the linear-ramp loop-closure correction to a trajectory.

    For each loop ``(i, j)``, processed in ascending ``(i, j)`` order, frame
    ``j`` is pulled onto frame ``i``: the offset is spread linearly over
    frames ``i+1 .. j`` and applied in full to every frame after ``j``.
    Loops whose indices fall outside the trajectory are skipped.

    Args:
        e_full: 1-D NumPy array of east coordinates (not modified).
        n_full: 1-D NumPy array of north coordinates (not modified).
        loops: Iterable of ``(i, j, n_high)`` tuples with ``i < j``.

    Returns:
        ``(e_corr, n_corr, n_applied)``: corrected copies of both arrays and
        the number of loops actually applied.
    """
    e_corr = e_full.copy()
    n_corr = n_full.copy()
    n_applied = 0
    for i, j, _n_high in sorted(loops, key=lambda lc: (lc[0], lc[1])):
        # A negative index would silently wrap around to the end of the
        # array; an index past the end has no DVL row to correct.
        if i < 0 or j >= len(e_corr):
            continue
        dx = e_corr[i] - e_corr[j]
        dy = n_corr[i] - n_corr[j]
        nsteps = j - i
        ramp = np.arange(1, nsteps + 1) / nsteps  # 1/n, 2/n, ..., 1
        e_corr[i + 1:j + 1] += dx * ramp
        n_corr[i + 1:j + 1] += dy * ramp
        # Carry the full offset forward to everything after the closure.
        e_corr[j + 1:] += dx
        n_corr[j + 1:] += dy
        n_applied += 1
    return e_corr, n_corr, n_applied


def _plot_trajectories(plot_path, raw, corrected, n_applied, match_threshold):
    """Save a side-by-side plot of the raw and corrected trajectories.

    Args:
        plot_path: Output image path.
        raw: ``(east, north)`` arrays of the uncorrected track.
        corrected: ``(east, north)`` arrays of the corrected track.
        n_applied: Number of loop closures applied (shown in the title).
        match_threshold: Match threshold used (shown in the super-title).
    """
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend, selected before pyplot is imported
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(14, 7))
    panels = (
        (axes[0], raw, '-b', 'RAW DVL'),
        (axes[1], corrected, '-r', f'LightGlue LC ({n_applied} loops)'),
    )
    for ax, (east, north), line_style, label in panels:
        ax.plot(east, north, line_style, lw=1)
        ax.plot(east[0], north[0], 'go', ms=10)    # start marker
        ax.plot(east[-1], north[-1], 'r^', ms=10)  # end marker
        ax.set_title(f'{label}\nbbox={east.max()-east.min():.1f}x{north.max()-north.min():.1f}m')
        ax.set_xlabel('East m')
        ax.set_ylabel('North m')
        ax.set_aspect('equal')
        ax.grid(alpha=0.3)
    # NOTE(review): the recording id in the title is hard-coded.
    plt.suptitle(f'LightGlue loop closure — GX019818 (thr={match_threshold})')
    plt.tight_layout()
    plt.savefig(plot_path, dpi=130, bbox_inches='tight')
    plt.close(fig)
    print(f'[plot] {plot_path}', flush=True)


def main():
    """Command-line entry point: find loop closures and write a corrected DVL track.

    Steps: list the local frames, read the DVL CSV, build candidate pairs,
    have the remote worker score them, keep pairs above the match threshold,
    apply the linear-ramp correction, write the corrected CSV and, if
    requested, a comparison plot.

    Raises:
        subprocess.CalledProcessError: If an ``scp`` transfer fails.
        SystemExit: If the remote worker exits with a non-zero status.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument('--frames-dir', required=True)
    ap.add_argument('--dvl-csv', required=True)
    ap.add_argument('--out-corrected', required=True)
    ap.add_argument('--plot', default=None)
    ap.add_argument('--min-sep', type=int, default=60)
    ap.add_argument('--match-threshold', type=int, default=50)
    ap.add_argument('--max-pairs', type=int, default=30000)
    ap.add_argument('--gpu-host', default='192.168.0.87')
    ap.add_argument('--gpu-user', default='floppyrj45')
    ap.add_argument('--gpu-frames-dir', default='/home/floppyrj45/lightglue-test/frames_GX019818')
    ap.add_argument('--gpu-venv', default='/home/floppyrj45/lightglue-test/venv')
    ap.add_argument('--gpu-worker', default='/home/floppyrj45/lightglue-test/lightglue_pairs_worker.py')
    ap.add_argument('--remote-pairs-path', default='/tmp/lg_pairs.txt')
    ap.add_argument('--remote-out-path', default='/tmp/lg_matches.csv')
    ap.add_argument('--n-positions-cap', type=int, default=0,
                    help='if >0, cap n_positions used for pair generation (must match GPU frames count)')
    args = ap.parse_args()

    # Pair indices are *positions* in the sorted list of frame files; the
    # worker is expected to sort its own copy of the frames the same way so
    # that positions line up across hosts. Each position is mapped back to a
    # DVL row through the number in the file name, because the DVL CSV may
    # contain more rows than there are extracted frames.
    frames_dir = Path(args.frames_dir)
    local_frames = sorted(p.name for p in frames_dir.iterdir()
                          if p.suffix.lower() in ('.jpg', '.jpeg', '.png'))
    pos_to_dvl = [_frame_name_to_dvl_idx(n) for n in local_frames]
    n_positions = len(local_frames)
    # Only a strictly positive cap is meaningful; a negative value would
    # otherwise slice from the end of the list.
    if 0 < args.n_positions_cap < n_positions:
        n_positions = args.n_positions_cap
        pos_to_dvl = pos_to_dvl[:n_positions]
    print(f'[lc] positions used for pairs: {n_positions}', flush=True)

    # DVL CSV
    with open(args.dvl_csv, newline='') as f:
        dvl_rows = list(csv.DictReader(f))
    e_full = np.array([float(r['east_m']) for r in dvl_rows])
    n_full = np.array([float(r['north_m']) for r in dvl_rows])
    n_full_rows = len(dvl_rows)
    print(f'[lc] dvl rows: {n_full_rows}', flush=True)

    # Candidate pairs over positions, scored on the GPU host.
    pairs_pos = stratified_pairs(n_positions, args.min_sep, args.max_pairs)
    print(f'[lc] candidate pairs: {len(pairs_pos)}', flush=True)
    local_matches = '/tmp/lg_matches_local.csv'
    _run_remote_worker(args, pairs_pos, local_matches)

    # Keep strong matches only and translate positions to DVL rows.
    loops = _load_loop_closures(local_matches, pos_to_dvl,
                                args.match_threshold, args.min_sep)
    print(f'[lc] kept {len(loops)} loop closures (n_high > {args.match_threshold})', flush=True)

    # Linear-ramp correction (same algorithm as the pHash variant).
    e_corr, n_corr, n_applied = _apply_linear_ramp(e_full, n_full, loops)
    print(f'[lc] applied {n_applied} corrections', flush=True)

    with open(args.out_corrected, 'w', newline='') as f:
        w = csv.writer(f)
        w.writerow(['frame_idx', 'ts_s', 'east_m_orig', 'north_m_orig', 'east_m_corr', 'north_m_corr', 'n_loops'])
        for k, r in enumerate(dvl_rows):
            # The loop count is stored once, on the first data row.
            w.writerow([r['frame_idx'], r['ts_s'], e_full[k], n_full[k], e_corr[k], n_corr[k],
                        n_applied if k == 0 else ''])
    print(f'[out] {args.out_corrected}', flush=True)

    if args.plot:
        _plot_trajectories(args.plot, (e_full, n_full), (e_corr, n_corr),
                           n_applied, args.match_threshold)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()