780 lines
32 KiB
Python
780 lines
32 KiB
Python
#!/usr/bin/env python3
|
|
"""Dispatcher daemon: picks queued jobs/stitches and runs them on available workers.

Env:
    COSMA_QC_DB          : SQLite path (default /var/lib/cosma-qc/jobs.db)
    COSMA_QC_WORKERS     : JSON list of workers [{host, ssh_alias, gpu, vram_mib,
                           frames_dir, lingbot_path, viser_port_base}]
    COSMA_QC_FPS         : extraction fps (default 3)
    COSMA_QC_IMG_H       : image height (default 294)
    COSMA_QC_IMG_W       : image width (default 518)
    COSMA_QC_POLL_S      : idle polling interval of the main loop, in seconds (default 4)
    COSMA_QC_MIN_VIDEO_S : minimum total video duration for a job, in seconds (default 480)

Jobs lifecycle:
    queued → claimed → extracting → running → done → [triggers per_auv stitch]
                            ↘ error
                            ↘ skipped (too little usable footage)

Stitch lifecycle:
    queued → claimed → running → done → [triggers cross_auv stitch if all per_auv done]
                            ↘ error
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
import shlex
|
|
import sqlite3
|
|
import subprocess
|
|
import sys
|
|
import threading
|
|
import time
|
|
from contextlib import closing
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
|
|
def _now_iso() -> str:
|
|
return datetime.now(timezone.utc).isoformat(timespec="seconds")
|
|
|
|
DB_PATH = Path(os.environ.get("COSMA_QC_DB", "/var/lib/cosma-qc/jobs.db"))
|
|
FPS = int(os.environ.get("COSMA_QC_FPS", "3"))
|
|
IMG_H = int(os.environ.get("COSMA_QC_IMG_H", "294"))
|
|
IMG_W = int(os.environ.get("COSMA_QC_IMG_W", "518"))
|
|
POLL_S = int(os.environ.get("COSMA_QC_POLL_S", "4"))
|
|
STITCH_SCRIPT = Path(__file__).parent / "stitch.py"
|
|
|
|
DEFAULT_WORKERS = [
|
|
{
|
|
"host": "192.168.0.87", "ssh_alias": "gpu", "gpu": "RTX 3060 12GB",
|
|
"vram_mib": 11913,
|
|
"frames_dir": "/home/floppyrj45/cosma-qc-frames",
|
|
"lingbot_path": "/home/floppyrj45/ai-video/lingbot-map",
|
|
"viser_port_base": 8100,
|
|
},
|
|
{
|
|
"host": "192.168.0.84", "ssh_alias": "cosma-vm", "gpu": "RTX 3090 24GB",
|
|
"vram_mib": 24576,
|
|
"frames_dir": "/home/floppyrj45/cosma-qc-frames",
|
|
"lingbot_path": "/home/floppyrj45/ai-video/lingbot-map",
|
|
"viser_port_base": 8100,
|
|
},
|
|
]
|
|
WORKERS = json.loads(os.environ.get("COSMA_QC_WORKERS", json.dumps(DEFAULT_WORKERS)))
|
|
|
|
_worker_lock = threading.Lock()
|
|
_reserved_vram = {w["host"]: 0 for w in WORKERS}
|
|
|
|
|
|
def db() -> sqlite3.Connection:
    """Open a fresh connection to the jobs database.

    The connection is in autocommit mode (``isolation_level=None``), has WAL
    journaling requested, and returns rows as ``sqlite3.Row``. Callers are
    responsible for closing it.
    """
    connection = sqlite3.connect(DB_PATH, isolation_level=None)
    connection.row_factory = sqlite3.Row
    connection.execute("PRAGMA journal_mode=WAL")
    return connection
|
|
|
|
|
|
def _migrate():
    """Add any `jobs` columns introduced after the initial release; safe to run repeatedly."""
    wanted_columns = {
        "trimmed_head": "INTEGER DEFAULT 0",
        "trimmed_tail": "INTEGER DEFAULT 0",
        "video_duration_s": "REAL DEFAULT 0",
        # Human-readable phase so the dashboard can show "scp 2/3", "ffmpeg 45%",
        # "reconstruct 12/113 windows"...
        "step": "TEXT",
    }
    with closing(db()) as conn:
        present = {row["name"] for row in conn.execute("PRAGMA table_info(jobs)")}
        for name, ddl in wanted_columns.items():
            if name in present:
                continue
            conn.execute(f"ALTER TABLE jobs ADD COLUMN {name} {ddl}")


# Runs at import time so the schema is current before any job is touched.
_migrate()
|
|
|
|
|
|
def ssh(alias: str, cmd: str, timeout: int = 30) -> tuple[int, str, str]:
    """Run `cmd` on the host `alias` over non-interactive ssh.

    Returns ``(returncode, stdout, stderr)`` with both streams decoded as text.
    ``subprocess.run`` raises ``subprocess.TimeoutExpired`` if the command does
    not finish within `timeout` seconds; that exception is not caught here.
    """
    argv = ["ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=5", alias, cmd]
    completed = subprocess.run(argv, capture_output=True, text=True, timeout=timeout)
    return completed.returncode, completed.stdout, completed.stderr
|
|
|
|
|
|
def worker_free_vram_mib(worker: dict) -> int:
    """Return the free VRAM (MiB) reported by nvidia-smi for the worker's first GPU.

    Any failure — ssh unreachable, ssh timing out, nvidia-smi missing, or
    unparsable output — is reported as 0 so that the worker is simply treated
    as full instead of aborting worker selection for the healthy workers.
    """
    query = "nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits"
    try:
        rc, out, _ = ssh(worker["ssh_alias"], query)
    except (OSError, subprocess.SubprocessError):
        # ssh binary missing or the remote command hung past the timeout.
        return 0
    if rc != 0:
        return 0
    try:
        # One line per GPU; only the first GPU is considered.
        return int(out.strip().splitlines()[0])
    except (ValueError, IndexError):
        return 0
|
|
|
|
|
|
def _cleanup_stale_demo(worker: dict) -> bool:
    """Kill the oldest demo.py process found on `worker` to reclaim its VRAM.

    The caller decides whether the worker is idle; this function only looks up
    the oldest matching process (``pgrep -o``) and sends it the default signal.
    Returns True if a kill was issued, False if no process was found.
    """
    alias = worker["ssh_alias"]
    rc, out, _ = ssh(alias, "pgrep -o -f 'python3.*demo\\.py' 2>/dev/null", timeout=10)
    listing = out.strip()
    if rc != 0 or not listing:
        return False
    pid = listing.splitlines()[0].strip()
    if not pid.isdigit():
        return False
    ssh(alias, f"kill {pid} 2>/dev/null", timeout=10)
    print(f" cleanup: killed stale demo.py pid={pid} on {worker['host']}", flush=True)
    # Give the process a moment to exit before anyone re-checks free VRAM.
    time.sleep(3)
    return True
|
|
|
|
|
|
# Number of jobs this process currently has running on each worker host.
_jobs_per_worker: dict[str, int] = {}


def pick_worker(estimated_vram_mib: int) -> dict | None:
    """Reserve and return the least-loaded worker that can fit `estimated_vram_mib`.

    Among workers with enough unreserved free VRAM, the one with the fewest
    running jobs wins, then the one with the most free VRAM, then host name —
    otherwise the biggest GPU would always be chosen and the other would idle.
    When nothing fits, leftover demo.py processes are reaped on idle workers
    and None is returned so the caller can retry later.
    """
    with _worker_lock:
        eligible = []
        for w in WORKERS:
            host = w["host"]
            free = worker_free_vram_mib(w) - _reserved_vram.get(host, 0)
            print(f" pick_worker: {host} free={free} reserved={_reserved_vram.get(host, 0)} load={_jobs_per_worker.get(host, 0)}", flush=True)
            if free >= estimated_vram_mib:
                # The key is a plain tuple ending with the host, so dicts are never compared.
                eligible.append(((_jobs_per_worker.get(host, 0), -free, host), w))

        if not eligible:
            print(f" pick_worker: no candidate for {estimated_vram_mib} MiB", flush=True)
            # Free VRAM on idle workers by killing leftover demo.py (kept alive for viser).
            for w in WORKERS:
                if _jobs_per_worker.get(w["host"], 0) == 0:
                    _cleanup_stale_demo(w)
            return None

        _, chosen = min(eligible, key=lambda entry: entry[0])
        chosen_host = chosen["host"]
        _reserved_vram[chosen_host] = _reserved_vram.get(chosen_host, 0) + estimated_vram_mib
        _jobs_per_worker[chosen_host] = _jobs_per_worker.get(chosen_host, 0) + 1
        return chosen
|
|
|
|
|
|
def release_worker(worker: dict, estimated_vram_mib: int):
    """Give back a reservation made by `pick_worker`; counters never go below zero."""
    host = worker["host"]
    with _worker_lock:
        vram_left = _reserved_vram.get(host, 0) - estimated_vram_mib
        _reserved_vram[host] = vram_left if vram_left > 0 else 0
        jobs_left = _jobs_per_worker.get(host, 0) - 1
        _jobs_per_worker[host] = jobs_left if jobs_left > 0 else 0
|
|
|
|
|
|
def estimate_vram_mib(frame_count: int) -> int:
    """Return the VRAM budget (MiB) to reserve for one reconstruction job.

    `frame_count` is accepted for interface stability but does not affect the
    result: windowed mode + offload_to_cpu caps VRAM usage regardless of total
    frames. Observed: ~3.5 GB model + ~1.5 GB working set for window_size=16,
    hence a flat, safe 6 GB budget.
    """
    flat_budget_mib = 6000
    return flat_budget_mib
|
|
|
|
|
|
def set_status(job_id: int, **fields):
    """Update columns of job `job_id`; each keyword is a column name and its new value.

    When the job moves into a live state (extracting/running/done/queued) and
    the caller did not pass `error`, the error column is cleared so the
    dashboard stops showing a previous failure alongside a fresh run.
    """
    live_states = ("extracting", "running", "done", "queued")
    if "error" not in fields and fields.get("status") in live_states:
        fields["error"] = None
    assignments = ", ".join(f"{column}=?" for column in fields)
    statement = "UPDATE jobs SET " + assignments + " WHERE id=?"
    with closing(db()) as conn:
        conn.execute(statement, (*fields.values(), job_id))
|
|
|
|
|
|
def set_stitch_status(stitch_id: int, **fields):
    """Update columns of stitch `stitch_id`; each keyword is a column name and its new value."""
    assignments = ", ".join(f"{column}=?" for column in fields)
    statement = "UPDATE stitches SET " + assignments + " WHERE id=?"
    with closing(db()) as conn:
        conn.execute(statement, (*fields.values(), stitch_id))
|
|
|
|
|
|
def count_frames(worker: dict, frames_dir: str) -> int:
    """Return how many extracted ``frame_*.jpg`` files exist in `frames_dir` on the worker.

    Only frame files are counted. The job directory can also hold a staged
    ``src_*.MP4`` and a ``reconstruction.ply``; counting those would inflate
    the reported frame count and shift the next ffmpeg ``-start_number``.
    Returns 0 when the directory is missing or the output cannot be parsed.
    """
    cmd = (
        f"find {shlex.quote(frames_dir)} -maxdepth 1 -type f -name 'frame_*.jpg' "
        f"2>/dev/null | wc -l"
    )
    rc, out, _ = ssh(worker["ssh_alias"], cmd)
    if rc != 0:
        return 0
    try:
        return int(out.strip())
    except ValueError:
        return 0
|
|
|
|
|
|
# Record sessions always start with the AUV on deck or at the surface — these frames pollute
|
|
# reconstruction. Detect the first sustained underwater run (red channel absorbed by water, so
|
|
# mean_R < mean_G and mean_R < mean_B) and delete the hors-eau prefix before demo.py runs.
|
|
_AUTO_TRIM_SCRIPT = r"""
|
|
import cv2, glob, os, sys
|
|
frames_dir = sys.argv[1]
|
|
need_streak = 10 # consecutive underwater frames required to lock start/end
|
|
paths = sorted(glob.glob(os.path.join(frames_dir, 'frame_*.jpg')))
|
|
if not paths:
|
|
print('TRIM_RESULT 0 0 0'); sys.exit(0)
|
|
|
|
def is_underwater(path):
|
|
img = cv2.imread(path, cv2.IMREAD_REDUCED_COLOR_4)
|
|
if img is None:
|
|
return None
|
|
b, g, r = [float(c) for c in cv2.mean(img)[:3]]
|
|
# Red is absorbed by water → R < G and R < B on underwater shots.
|
|
return r < g - 5 and r < b - 5
|
|
|
|
# Scan from the start for the first sustained underwater run.
|
|
start = 0
|
|
streak = 0
|
|
for i, p in enumerate(paths):
|
|
uw = is_underwater(p)
|
|
if uw is None:
|
|
continue
|
|
if uw:
|
|
streak += 1
|
|
if streak >= need_streak:
|
|
start = i - need_streak + 1
|
|
break
|
|
else:
|
|
streak = 0
|
|
|
|
# Scan from the end for the last sustained underwater run.
|
|
end = len(paths)
|
|
streak = 0
|
|
for j in range(len(paths) - 1, -1, -1):
|
|
uw = is_underwater(paths[j])
|
|
if uw is None:
|
|
continue
|
|
if uw:
|
|
streak += 1
|
|
if streak >= need_streak:
|
|
end = j + need_streak # exclusive
|
|
break
|
|
else:
|
|
streak = 0
|
|
|
|
if end <= start:
|
|
# Sanity: never delete everything.
|
|
start = 0
|
|
end = len(paths)
|
|
|
|
removed_head = start
|
|
removed_tail = len(paths) - end
|
|
for p in paths[:start]:
|
|
try: os.remove(p)
|
|
except OSError: pass
|
|
for p in paths[end:]:
|
|
try: os.remove(p)
|
|
except OSError: pass
|
|
print(f'TRIM_RESULT {removed_head} {removed_tail} {end - start}')
|
|
"""
|
|
|
|
|
|
def _refresh_thumbnail(worker: dict, frames_dir: str, job_id: int) -> None:
    """Copy the latest extracted frame back to the dashboard host as the job preview.

    Best effort: the thumbnail is cosmetic, so a filesystem error, an ssh/scp
    timeout or a missing binary must never fail the job that is being
    processed. All such failures are swallowed and the old thumbnail is kept.
    """
    try:
        thumb_dir = DB_PATH.parent / "thumbnails"
        thumb_dir.mkdir(exist_ok=True)
        thumb_local = thumb_dir / f"job_{job_id}.jpg"
        # Frame names are zero-padded, so the last entry in sorted order is the newest.
        _, out, _ = ssh(
            worker["ssh_alias"],
            f"ls -1 {shlex.quote(frames_dir)}/frame_*.jpg 2>/dev/null | tail -1",
        )
        latest = out.strip()
        if not latest:
            return
        subprocess.run(
            ["scp", "-o", "BatchMode=yes", f"{worker['ssh_alias']}:{latest}", str(thumb_local)],
            capture_output=True, timeout=15,
        )
    except (OSError, subprocess.SubprocessError):
        # Deliberately silent: see docstring.
        return
|
|
|
|
|
|
def trim_above_water_prefix(worker: dict, frames_dir: str) -> tuple[int, int, int]:
    """Delete leading and trailing out-of-water frames on the worker.

    Uploads `_AUTO_TRIM_SCRIPT` to the worker, runs it inside the lingbot-map
    venv (which has cv2) and parses its ``TRIM_RESULT`` line.

    Returns ``(head, tail, remaining)``: frames removed at the start, frames
    removed at the end, and frames kept. If the script cannot be uploaded or
    produces no usable result, nothing is reported as trimmed and `remaining`
    is the current frame count on the worker — reporting 0 would make the
    caller discard the job as "too short" because of an infrastructure error.
    """
    alias = worker["ssh_alias"]
    # The pid is shared by every dispatcher thread; adding the thread id keeps two
    # concurrent jobs on the same worker from overwriting/removing each other's script.
    script_remote = f"/tmp/cosma-trim-{os.getpid()}-{threading.get_ident()}.py"
    quoted_script = shlex.quote(script_remote)

    # Write script on worker via a quoted heredoc (no shell expansion of its content).
    rc, _, err = ssh(
        alias,
        f"cat > {quoted_script} << 'PYEOF'\n{_AUTO_TRIM_SCRIPT}\nPYEOF",
        timeout=15,
    )
    if rc != 0:
        print(f" ↳ trim script upload failed: {err[:150]}")
        return (0, 0, count_frames(worker, frames_dir))

    rc, out, err = ssh(
        alias,
        f"source {shlex.quote(worker['lingbot_path'])}/.venv/bin/activate && "
        f"python3 {quoted_script} {shlex.quote(frames_dir)}; rm -f {quoted_script}",
        timeout=1200,
    )
    for line in out.splitlines():
        if not line.startswith("TRIM_RESULT"):
            continue
        try:
            head, tail, remaining = (int(token) for token in line.split()[1:4])
        except ValueError:
            # Truncated or non-numeric result line: treat like missing output.
            break
        return (head, tail, remaining)

    print(f" ↳ trim unexpected output: {out[:200]} / {err[:200]}")
    return (0, 0, count_frames(worker, frames_dir))
|
|
|
|
|
|
def scp_to_worker(local_path: str, worker: dict, remote_path: str):
    """Copy a file to `remote_path` on the worker.

    `local_path` is either a path on the dispatcher host (plain scp from here)
    or a ``host:abs_path`` spec, in which case the worker pulls the file from
    `host` itself so the bytes are not routed through the dispatcher.

    Raises RuntimeError when the copy exits with a non-zero status.
    """
    pulled_by_worker = ":" in local_path and not local_path.startswith("/")

    if not pulled_by_worker:
        result = subprocess.run(
            ["scp", "-o", "BatchMode=yes", local_path, f"{worker['ssh_alias']}:{remote_path}"],
            capture_output=True, timeout=1800,
        )
        if result.returncode != 0:
            raise RuntimeError(f"scp failed: {result.stderr.decode()[:200]}")
        return

    src_host, src_path = local_path.split(":", 1)
    # Executed on the worker: it fetches the file directly from the source host.
    pull_cmd = (
        f"scp -o BatchMode=yes {shlex.quote(src_host)}:{shlex.quote(src_path)} "
        f"{shlex.quote(remote_path)}"
    )
    rc, _, err = ssh(worker["ssh_alias"], pull_cmd, timeout=7200)
    if rc != 0:
        raise RuntimeError(f"remote scp ({src_host}→{worker['host']}) failed: {err[:200]}")
|
|
|
|
|
|
def _path_basename(p: str) -> str:
|
|
if ":" in p and not p.startswith("/"):
|
|
return Path(p.split(":", 1)[1]).name
|
|
return Path(p).name
|
|
|
|
|
|
def video_duration_s(worker: dict, worker_src: str) -> float:
    """Return the duration in seconds that ffprobe reports for `worker_src` on the worker.

    Returns 0.0 whenever the output is not a number (ffprobe failure prints "0").
    """
    probe = (
        f"ffprobe -v error -show_entries format=duration "
        f"-of csv=p=0 {shlex.quote(worker_src)} 2>/dev/null || echo 0"
    )
    out = ssh(worker["ssh_alias"], probe)[1]
    try:
        seconds = float(out.strip())
    except Exception:
        seconds = 0.0
    return seconds
|
|
|
|
|
|
def resolve_worker_video_source(worker: dict, video_path: str, frames_dir: str) -> tuple[str, bool]:
    """Return ``(path_on_worker, is_ephemeral)`` for a video.

    If `video_path` is a ``host:path`` spec whose host is this worker (matched
    against its ssh alias or host), the worker reads the file in place and the
    result is ``(path, False)``. In every other case the video has to be staged
    as ``src_<basename>`` inside `frames_dir`, and the result is
    ``(staged_path, True)`` so the caller knows to delete it afterwards.
    """
    is_host_spec = ":" in video_path and not video_path.startswith("/")
    if is_host_spec:
        src_host, src_path = video_path.split(":", 1)
        if src_host == worker["ssh_alias"] or src_host == worker["host"]:
            return src_path, False
    staged = f"{frames_dir}/src_{_path_basename(video_path)}"
    return staged, True
|
|
|
|
|
|
def ensure_worker_video_source(video_path: str, worker: dict, worker_src: str, is_ephemeral: bool, step_label: str):
    """Stage `video_path` at `worker_src` on the worker if it has to be copied.

    Nothing happens when the source is read in place (`is_ephemeral` false) or
    when `worker_src` already exists on the worker. Note that `step_label` is
    passed to `set_status` as the job id whose `step` column is updated.
    """
    if not is_ephemeral:
        return
    already_there = ssh(worker["ssh_alias"], f"test -f {shlex.quote(worker_src)}")[0] == 0
    if already_there:
        return
    name = _path_basename(video_path)
    print(f" scp {name} → {worker['host']}...")
    set_status(step_label, step=f"scp: {name}")
    scp_to_worker(video_path, worker, worker_src)
|
|
|
|
|
|
def do_extract(job: sqlite3.Row, worker: dict) -> str:
    """Extract, trim and validate the frames of `job` on `worker`.

    For each video of the job: make it readable by the worker (staging a copy
    if needed), run ffmpeg detached on the worker, and poll until it writes its
    exit code, updating progress and the preview thumbnail along the way.
    Afterwards the out-of-water frames are trimmed and too-short jobs are
    flagged.

    Returns the frames directory on the worker.

    Raises:
        RuntimeError("skipped_short"): the job was marked ``skipped`` because
            the video or the underwater footage is too short.
        RuntimeError: ffmpeg (or staging the video) failed.
    """
    alias = worker["ssh_alias"]
    job_id = job["id"]
    videos = json.loads(job["video_paths"])
    frames_dir = f"{worker['frames_dir']}/job_{job_id}"
    # Clean any frame_*.jpg from a prior run so count_frames reflects this extraction only
    # (retries/restarts otherwise inflate frame_count with duplicates).
    ssh(alias, f"mkdir -p {shlex.quote(frames_dir)} && rm -f {shlex.quote(frames_dir)}/frame_*.jpg")

    idx = 0  # number of frames already on disk == first frame number of the next video
    total_frames_est = 0  # grows as each video's duration becomes known
    total_duration_s = 0.0
    n_videos = len(videos)
    vf = f"fps={FPS},scale={IMG_W}:{IMG_H}"
    pattern = f"{frames_dir}/frame_%06d.jpg"

    for vid_num, v in enumerate(videos, start=1):
        worker_src, is_ephemeral = resolve_worker_video_source(worker, v, frames_dir)
        try:
            ensure_worker_video_source(v, worker, worker_src, is_ephemeral, job_id)
            dur = video_duration_s(worker, worker_src)
            total_duration_s += dur
            total_frames_est += max(1, int(dur * FPS))

            # ffmpeg runs detached (setsid) so it survives the ssh session; the wrapper
            # writes ffmpeg's exit code to exit_file once it has finished.
            exit_file = f"/tmp/cosma-ffmpeg-{job['id']}-{idx}.exit"
            bg = (
                f"rm -f {shlex.quote(exit_file)}; "
                f"ffmpeg -hide_banner -loglevel error -i {shlex.quote(worker_src)} "
                f"-vf {shlex.quote(vf)} -start_number {idx} -q:v 4 {shlex.quote(pattern)} "
                f"</dev/null >/tmp/cosma-ffmpeg-{job['id']}.log 2>&1; "
                f"echo $? > {shlex.quote(exit_file)}"
            )
            ssh(alias, f"setsid bash -c {shlex.quote(bg)} >/dev/null 2>&1 &")

            thumb_refresh_counter = 0
            while True:
                # Use -s (file exists AND size > 0) to avoid race: setsid bash writes the exit code
                # AFTER ffmpeg finishes; a plain -f can match a zero-byte placeholder mid-write.
                rc_done, _, _ = ssh(alias, f"test -s {shlex.quote(exit_file)}")
                current = count_frames(worker, frames_dir)
                pct = min(99, current * 100 // total_frames_est)
                set_status(job_id, frame_count=current, progress=pct,
                           step=f"ffmpeg {vid_num}/{n_videos}: {current} frames")
                # Refresh the preview thumbnail every few polls so the dashboard reflects what the
                # camera is seeing right now, not the very first (surface) frame.
                thumb_refresh_counter += 1
                if thumb_refresh_counter % 3 == 1 and current > 0:
                    _refresh_thumbnail(worker, frames_dir, job_id)
                if rc_done == 0:
                    break
                time.sleep(5)

            _, code_str, _ = ssh(alias, f"cat {shlex.quote(exit_file)} 2>/dev/null || echo 1")
            rc = int(code_str.strip()) if code_str.strip().isdigit() else 1
            if rc != 0:
                # The second tuple element is stdout, i.e. the tail of the ffmpeg log.
                _, err, _ = ssh(alias, f"cat /tmp/cosma-ffmpeg-{job['id']}.log 2>/dev/null | tail -5 || echo ''")
                raise RuntimeError(f"ffmpeg failed on {v}: {err[:200]}")
        finally:
            # Free the staged MP4 whether extraction succeeded or not: a leftover (possibly
            # partial) copy wastes disk and would be reused as-is by the next attempt.
            if is_ephemeral:
                try:
                    ssh(alias, f"rm -f {shlex.quote(worker_src)}")
                except (OSError, subprocess.SubprocessError) as exc:
                    print(f" ↳ job #{job_id}: could not remove staged source {worker_src}: {exc}")

        # Counted after the staged source is gone so only real frames drive the next
        # -start_number.
        idx = count_frames(worker, frames_dir)
        set_status(job_id, frame_count=idx, progress=min(99, idx * 100 // total_frames_est))

    # Persist the measured video duration so the dashboard shows real length (segment_label
    # from ingest is only the timestamp of the first MP4 and lies when a segment spans multiple).
    set_status(job_id, video_duration_s=total_duration_s, step="trimming hors-eau")

    # Skip segments that are too short to contain a meaningful dive.
    min_video_s = int(os.environ.get("COSMA_QC_MIN_VIDEO_S", "480"))  # 8 min default
    if total_duration_s < min_video_s:
        print(f" ↳ job #{job['id']}: video too short ({int(total_duration_s)}s < {min_video_s}s) — marking skipped")
        set_status(job_id, status="skipped", error=f"too short: {int(total_duration_s)}s of video")
        raise RuntimeError("skipped_short")

    # Drop the out-of-water prefix AND suffix before reconstruction — AUV is out-of-water at both ends.
    head, tail, remaining = trim_above_water_prefix(worker, frames_dir)
    if head or tail:
        print(f" ↳ job #{job['id']}: trimmed head={head} tail={tail} out-of-water, {remaining} kept")
        set_status(job_id, frame_count=remaining, trimmed_head=head, trimmed_tail=tail)

    # Skip jobs with too little underwater content to be worth reconstructing (e.g., brief
    # surface checks that the auto-segmentation picked up as a dive).
    min_frames = max(60, int(30 * FPS))  # need ~30 s of underwater footage minimum
    if remaining < min_frames:
        print(f" ↳ job #{job['id']}: only {remaining} underwater frames (<{min_frames}) — marking skipped")
        set_status(job_id, status="skipped", error=f"too short: {remaining} underwater frames")
        raise RuntimeError("skipped_short")

    # Snapshot the latest post-trim frame so the dashboard preview matches what demo.py will see.
    _refresh_thumbnail(worker, frames_dir, job_id)
    # Trim once per job so LVM thin pool on the host actually reclaims the freed blocks.
    ssh(alias, "sudo fstrim / 2>/dev/null || fstrim / 2>/dev/null", timeout=60)
    return frames_dir
|
|
|
|
|
|
def _reconstruction_params(frame_count: int, ram_gb: float) -> tuple[int, int, int]:
    """Return ``(stride, window_size, overlap_size)`` for demo.py.

    Conservative RAM policy: runs have died with rc=137 during image loading,
    so stride is raised early and windows are kept small, trading speed for
    survival. `ram_gb` must be positive.
    """
    ram_budget_gb = ram_gb * 0.22
    stride = 1
    # 3.15 is presumably the in-memory cost of one frame in MiB — TODO confirm.
    while frame_count * 3.15 / 1024 / stride > ram_budget_gb:
        stride += 1
    # Frame-count floors on top of the RAM-derived stride.
    if frame_count > 4000:
        stride = max(stride, 4)
    elif frame_count > 2500:
        stride = max(stride, 3)
    elif frame_count > 1500:
        stride = max(stride, 2)

    effective_frames = frame_count // stride
    if effective_frames > 2200:
        return stride, 24, 6
    if effective_frames > 900:
        return stride, 16, 4
    return stride, 12, 3


def do_reconstruct(job: sqlite3.Row, worker: dict, frames_dir: str) -> tuple[str, str, str]:
    """Run demo.py on the worker over `frames_dir` and wait for the PLY to be written.

    Returns ``(viser_url, remote_log_path, ply_path)``.

    Raises RuntimeError when no non-empty PLY exists after the run.
    """
    alias = worker["ssh_alias"]
    port = worker["viser_port_base"] + job["id"]
    log = f"/tmp/cosma-qc-job-{job['id']}.log"
    ckpt = f"{worker['lingbot_path']}/checkpoints/lingbot-map/lingbot-map-long.pt"
    ply_path = f"{frames_dir}/reconstruction.ply"

    # `job` was read when the job was claimed, i.e. before extraction, so its frame_count is
    # stale (typically empty on a first run). Measure what is really on disk and only fall
    # back to the stored value when that fails.
    frame_count = count_frames(worker, frames_dir) or (job["frame_count"] or 0)

    # Optional per-worker "ram_gb" setting; the previous host-based values remain the default.
    ram_gb = worker.get("ram_gb")
    if not isinstance(ram_gb, (int, float)) or ram_gb <= 0:
        ram_gb = 23 if worker["host"] == "192.168.0.87" else 62
    stride, window_size, overlap_size = _reconstruction_params(frame_count, ram_gb)

    # Success is decided by the presence of a non-empty PLY, so a file left by a previous
    # attempt must not be allowed to vouch for this one.
    ssh(alias, f"rm -f {shlex.quote(ply_path)}")

    cmd = (
        f"cd {shlex.quote(worker['lingbot_path'])} && source .venv/bin/activate && "
        f"setsid python3 demo.py --model_path {shlex.quote(ckpt)} "
        f"--image_folder {shlex.quote(frames_dir)} --port {port} "
        f"--stride {stride} --use_sdpa --mode windowed --window_size {window_size} --overlap_size {overlap_size} --offload_to_cpu "
        f"--save_ply {shlex.quote(ply_path)} > {log} 2>&1 & "
        f"DEMO_PID=$!; "
        f"for i in $(seq 1 3600); do "
        f" if ! kill -0 $DEMO_PID 2>/dev/null; then wait $DEMO_PID; exit $?; fi; "
        f" if grep -q 'PLY saved:' {log} 2>/dev/null; then "
        # Keep demo.py alive so its viser/PointCloudViewer (with camera frustums, per-frame
        # confidence filtering, animation) stays reachable. Standalone viser_ply.py only has
        # XYZ+RGB which gives a poor-looking cloud. The worker eats ~6GB VRAM per alive demo.py
        # until pick_worker can no longer fit a new job; _cleanup_stale_demo reaps the oldest.
        f" exit 0; "
        f" fi; "
        f" sleep 3; "
        f"done; "
        f"pkill -KILL -f \"demo.py.*{frames_dir}\" 2>/dev/null; exit 124"
    )
    set_status(job["id"], step=f"reconstruct demo.py (windowed w{window_size}, stride {stride})")
    rc, _, err = ssh(alias, cmd, timeout=3 * 3600)

    # Accept rc==0 OR PLY file exists with non-zero size (kill -TERM may return non-zero)
    _, ply_size, _ = ssh(alias, f"stat -c %s {shlex.quote(ply_path)} 2>/dev/null || echo 0")
    try:
        ply_ok = int(ply_size.strip()) > 0
    except ValueError:
        ply_ok = False
    if not ply_ok:
        tail = ssh(alias, f"tail -30 {log}")[1]
        raise RuntimeError(f"demo.py failed (rc={rc}): {err[:200]}\n---\n{tail[:800]}")

    viser_url = f"http://{worker['host']}:{port}"
    return viser_url, log, ply_path
|
|
|
|
|
|
def _insert_per_auv_stitch_if_ready(conn: sqlite3.Connection, job_id: int) -> tuple | None:
    """Insert the per_auv stitch for the job's (acquisition, auv) if it is due.

    Must be called inside a write transaction. Returns ``(auv, acquisition_id)``
    when a stitch row was inserted, otherwise None.
    """
    job = conn.execute("SELECT * FROM jobs WHERE id=?", (job_id,)).fetchone()
    if not job:
        return None
    acq_id, auv = job["acquisition_id"], job["auv"]
    # Skip jobs flagged 'skipped' (bad segments: GoPro on deck, no underwater content).
    total = conn.execute(
        "SELECT COUNT(*) FROM jobs WHERE acquisition_id=? AND auv=? AND status != 'skipped'", (acq_id, auv)
    ).fetchone()[0]
    job_ids = [r["id"] for r in conn.execute(
        "SELECT id FROM jobs WHERE acquisition_id=? AND auv=? AND status='done'", (acq_id, auv)
    ).fetchall()]
    if total == 0 or len(job_ids) < total:
        return None
    existing = conn.execute(
        "SELECT id FROM stitches WHERE acquisition_id=? AND level='per_auv' AND auv=?", (acq_id, auv)
    ).fetchone()
    if existing:
        return None
    conn.execute(
        "INSERT INTO stitches (acquisition_id, level, auv, input_job_ids) VALUES (?,?,?,?)",
        (acq_id, "per_auv", auv, json.dumps(job_ids))
    )
    return auv, acq_id


def _maybe_create_per_auv_stitch(job_id: int):
    """Queue the per_auv stitch once every non-skipped job of the job's AUV is done.

    Jobs finish on parallel threads, so the "does a stitch exist yet?" check
    and the insert run inside one ``BEGIN IMMEDIATE`` transaction; otherwise
    two jobs finishing together could both insert a stitch.
    """
    with closing(db()) as conn:
        conn.execute("BEGIN IMMEDIATE")
        try:
            created = _insert_per_auv_stitch_if_ready(conn, job_id)
            conn.execute("COMMIT")
        except Exception:
            if conn.in_transaction:
                conn.execute("ROLLBACK")
            raise
    if created:
        auv, acq_id = created
        print(f" → Stitch per_auv créé pour {auv} acq#{acq_id}")
|
|
|
|
|
|
def _maybe_create_cross_auv_stitch(stitch_id: int):
    """Queue the cross_auv stitch once every AUV of the acquisition has a done per_auv stitch.

    AUVs whose jobs were all ``skipped`` are not counted: no per_auv stitch is
    ever created for them, so counting them would block the cross stitch
    forever. Only per_auv stitches in status ``done`` are used as inputs.
    Nothing is created for acquisitions with fewer than two contributing AUVs
    or when a cross_auv stitch already exists.
    """
    with closing(db()) as conn:
        st = conn.execute("SELECT * FROM stitches WHERE id=?", (stitch_id,)).fetchone()
        if not st:
            return
        acq_id = st["acquisition_id"]
        n_auvs = conn.execute(
            "SELECT COUNT(DISTINCT auv) FROM jobs WHERE acquisition_id=? AND status != 'skipped'", (acq_id,)
        ).fetchone()[0]
        if n_auvs < 2:
            return
        done_stitch_ids = [r["id"] for r in conn.execute(
            "SELECT id FROM stitches WHERE acquisition_id=? AND level='per_auv' AND status='done'", (acq_id,)
        ).fetchall()]
        if len(done_stitch_ids) < n_auvs:
            return
        existing = conn.execute(
            "SELECT id FROM stitches WHERE acquisition_id=? AND level='cross_auv'", (acq_id,)
        ).fetchone()
        if existing:
            return
        conn.execute(
            "INSERT INTO stitches (acquisition_id, level, input_stitch_ids, input_job_ids) VALUES (?,?,?,?)",
            (acq_id, "cross_auv", json.dumps(done_stitch_ids), "[]")
        )
        print(f" → Stitch cross_auv créé pour acq#{acq_id}")
|
|
|
|
|
|
def deploy_stitch_script(worker: dict):
    """Copy stitch.py to ``/tmp/cosma-stitch.py`` on the worker.

    Returns True on success and False on failure (the reason is printed).
    Never raises for a failed or timed-out copy, because the caller runs on the
    dispatcher's main loop. BatchMode keeps scp from waiting on a password
    prompt, matching the other ssh/scp calls in this file.
    """
    try:
        result = subprocess.run(
            ["scp", "-o", "BatchMode=yes", str(STITCH_SCRIPT),
             f"{worker['ssh_alias']}:/tmp/cosma-stitch.py"],
            capture_output=True, text=True, timeout=30,
        )
    except (OSError, subprocess.SubprocessError) as exc:
        print(f" ↳ stitch script deploy to {worker['host']} failed: {exc}", flush=True)
        return False
    if result.returncode != 0:
        print(f" ↳ stitch script deploy to {worker['host']} failed: {result.stderr[:200]}", flush=True)
        return False
    return True
|
|
|
|
|
|
def _stitch_input_plys(stitch: sqlite3.Row) -> list[str]:
    """Return the non-empty PLY paths a stitch takes as input.

    A per_auv stitch reads ``jobs.ply_path`` of its input jobs; any other level
    reads ``stitches.output_ply`` of its input stitches.
    """
    if stitch["level"] == "per_auv":
        table, column = "jobs", "ply_path"
        ids = json.loads(stitch["input_job_ids"] or "[]")
    else:
        table, column = "stitches", "output_ply"
        ids = json.loads(stitch["input_stitch_ids"] or "[]")
    if not ids:
        return []
    # table/column come from the two literals above; only the ids are bound parameters.
    placeholders = ",".join("?" * len(ids))
    with closing(db()) as conn:
        rows = conn.execute(
            f"SELECT {column} FROM {table} WHERE id IN ({placeholders})", ids
        ).fetchall()
    return [r[column] for r in rows if r[column]]


def _run_stitch_on_worker(stitch: sqlite3.Row, worker: dict) -> None:
    """Produce the stitch output on `worker` and record the outcome in the database."""
    stitch_id = stitch["id"]
    alias = worker["ssh_alias"]
    ply_paths = _stitch_input_plys(stitch)
    if not ply_paths:
        set_stitch_status(stitch_id, status="error",
                          error="Aucun PLY disponible",
                          finished_at=_now_iso())
        return

    out_ply = f"{worker['frames_dir']}/stitch_{stitch_id}.ply"

    # Single PLY — no alignment needed, pass through directly.
    if len(ply_paths) == 1:
        rc, _, err = ssh(alias, f"cp {shlex.quote(ply_paths[0])} {shlex.quote(out_ply)}")
        if rc != 0:
            set_stitch_status(stitch_id, status="error", error=f"cp failed: {err[:200]}", finished_at=_now_iso())
            return
        set_stitch_status(stitch_id, status="done", output_ply=out_ply, finished_at=_now_iso())
        print(f" → stitch #{stitch_id} passthrough (1 PLY) → {out_ply}")
        _maybe_create_cross_auv_stitch(stitch_id)
        return

    deploy_stitch_script(worker)
    cmd = (
        f"source {shlex.quote(worker['lingbot_path'])}/.venv/bin/activate && "
        f"python3 /tmp/cosma-stitch.py {shlex.quote(out_ply)} "
        + " ".join(shlex.quote(p) for p in ply_paths)
        + f" > /tmp/cosma-stitch-{stitch_id}.log 2>&1"
    )
    set_stitch_status(stitch_id, status="running", worker_host=worker["host"], started_at=_now_iso())
    rc, _, err = ssh(alias, cmd, timeout=4 * 3600)
    if rc == 0:
        set_stitch_status(stitch_id, status="done", output_ply=out_ply, finished_at=_now_iso())
        _maybe_create_cross_auv_stitch(stitch_id)
        return
    tail = ssh(alias, f"tail -20 /tmp/cosma-stitch-{stitch_id}.log")[1]
    set_stitch_status(stitch_id, status="error",
                      error=f"{err[:200]}\n{tail[:600]}",
                      finished_at=_now_iso())


def run_one_stitch(stitch: sqlite3.Row):
    """Run one claimed stitch to completion and record the result.

    A worker with ~2000 MiB of free VRAM is preferred; if none is available the
    first configured worker is used without a reservation. Any reservation is
    released when the stitch ends, and unexpected failures (ssh timeouts,
    database errors) mark the stitch as ``error`` instead of propagating into
    the dispatcher's main loop.
    """
    stitch_id = stitch["id"]
    vram_mib = 2000
    worker = None
    reserved = False
    try:
        worker = pick_worker(vram_mib)
        reserved = worker is not None
        if not reserved:
            worker = WORKERS[0]
        _run_stitch_on_worker(stitch, worker)
    except Exception as e:
        set_stitch_status(stitch_id, status="error", error=str(e)[:500], finished_at=_now_iso())
    finally:
        # pick_worker bumps the worker's reserved VRAM and job count; without this the
        # worker would look a little fuller after every stitch.
        if reserved:
            release_worker(worker, vram_mib)
|
|
|
|
|
|
def run_one(job: sqlite3.Row) -> bool:
    """Process one claimed job end to end: extract frames, reconstruct, record the result.

    Returns False (after putting the job back to ``queued``) when no worker can
    take it, True once a worker was picked — whether the job then ended as
    ``done``, ``error`` or ``skipped``.
    """
    job_id = job["id"]
    estimated = estimate_vram_mib(job["frame_count"] or 150)
    worker = pick_worker(estimated)
    if not worker:
        # release claim so the job can be re-tried by main loop
        set_status(job_id, status="queued")
        return False
    try:
        # Inside the try so the reservation is released even if this first update fails.
        set_status(job_id, status="extracting", worker_host=worker["host"], started_at=_now_iso())
        frames_dir = do_extract(job, worker)
        frame_count = count_frames(worker, frames_dir)
        set_status(job_id, frames_dir=frames_dir, frame_count=frame_count,
                   status="running", progress=0)
        viser_url, log, ply_path = do_reconstruct(job, worker, frames_dir)
        set_status(job_id, status="done", viser_url=viser_url, ply_path=ply_path,
                   progress=100, log_tail=log, finished_at=_now_iso())
        _maybe_create_per_auv_stitch(job_id)
    except Exception as e:
        # do_extract raises "skipped_short" after flagging status='skipped' — don't override.
        if "skipped_short" not in str(e):
            set_status(job_id, status="error", error=str(e)[:2000], finished_at=_now_iso())
        else:
            set_status(job_id, finished_at=_now_iso())
            # A skipped job no longer counts towards its AUV's total, so it may have been
            # the last one the per_auv stitch was waiting for.
            try:
                _maybe_create_per_auv_stitch(job_id)
            except Exception as stitch_exc:
                print(f" ↳ job #{job_id}: per_auv stitch check failed: {stitch_exc}", flush=True)
    finally:
        release_worker(worker, estimated)
    return True
|
|
|
|
|
|
def _claim_oldest_queued(table: str) -> sqlite3.Row | None:
    """Atomically mark the oldest ``queued`` row of `table` as ``claimed`` and return it.

    `table` is always one of the two literals passed below, never user input.
    Returns None when nothing is queued or when the database cannot be written
    right now (e.g. it is locked); the caller simply polls again.
    """
    with closing(db()) as conn:
        try:
            conn.execute("BEGIN IMMEDIATE")
            row = conn.execute(
                f"SELECT * FROM {table} WHERE status='queued' ORDER BY created_at LIMIT 1"
            ).fetchone()
            if row:
                conn.execute(f"UPDATE {table} SET status='claimed' WHERE id=?", (row["id"],))
            conn.execute("COMMIT")
        except sqlite3.Error:
            # If BEGIN itself failed there is no transaction, and ROLLBACK would raise.
            if conn.in_transaction:
                conn.execute("ROLLBACK")
            return None
        return row


def pop_queued() -> sqlite3.Row | None:
    """Atomic claim: grab a queued job and mark it 'claimed' to prevent double-dispatch."""
    return _claim_oldest_queued("jobs")


def pop_queued_stitch() -> sqlite3.Row | None:
    """Atomic claim: grab a queued stitch and mark it 'claimed' to prevent double-dispatch."""
    return _claim_oldest_queued("stitches")
|
|
|
|
|
|
def write_heartbeat():
    """Write the current UTC timestamp to ``dispatcher.heartbeat`` next to the database.

    Best effort: any failure to write is ignored.
    """
    heartbeat_file = DB_PATH.parent / "dispatcher.heartbeat"
    try:
        heartbeat_file.write_text(_now_iso())
    except Exception:
        return
|
|
|
|
|
|
# Set this event to make the heartbeat thread exit.
_heartbeat_stop = threading.Event()


def _heartbeat_loop():
    """Refresh the heartbeat file about every 5 seconds until `_heartbeat_stop` is set."""
    while True:
        if _heartbeat_stop.is_set():
            return
        write_heartbeat()
        # Doubles as the sleep: returns early as soon as the stop event is set.
        _heartbeat_stop.wait(5)
|
|
|
|
|
|
def _run_one_thread(job: sqlite3.Row):
    """Thread entry point around `run_one`: any escaping exception marks the job as error."""
    try:
        started = run_one(job)
        if not started:
            print(f" ↳ job #{job['id']}: pas de worker dispo, remis en queued")
    except Exception as e:
        crash_text = f"run_one thread crashed: {str(e)[:1500]}"
        set_status(job["id"], status="error", error=crash_text, finished_at=_now_iso())
|
|
|
|
|
|
def main():
    """Dispatcher main loop; runs until the process is interrupted.

    Jobs run on their own threads, at most one per configured worker.
    Stitches run inline on this thread. When nothing is queued the loop
    sleeps for POLL_S seconds.
    """
    print(f"cosma-qc dispatcher · DB={DB_PATH} · workers={[w['host'] for w in WORKERS]}")
    heartbeat = threading.Thread(target=_heartbeat_loop, daemon=True)
    heartbeat.start()

    max_parallel = max(1, len(WORKERS))
    active: set[threading.Thread] = set()
    while True:
        # drain finished threads
        active = {t for t in active if t.is_alive()}

        # Only look for a job while there is a free slot.
        job = pop_queued() if len(active) < max_parallel else None
        if job:
            print(f"→ job #{job['id']} ({job['auv']}/{job['gopro_serial']}/{job['segment_label']}) [active={len(active) + 1}]")
            runner = threading.Thread(target=_run_one_thread, args=(job,), daemon=False)
            runner.start()
            active.add(runner)
            # brief pause so the thread can reserve its worker before we pop another
            time.sleep(0.5)
            continue

        stitch = pop_queued_stitch()
        if stitch:
            label = f"{stitch['level']} {stitch['auv'] or ''} acq#{stitch['acquisition_id']}"
            print(f"→ stitch #{stitch['id']} ({label})")
            run_one_stitch(stitch)
            continue

        time.sleep(POLL_S)


if __name__ == "__main__":
    # Ctrl-C is the normal way to stop the daemon: exit with status 0.
    try:
        main()
    except KeyboardInterrupt:
        sys.exit(0)
|