diff --git a/scripts/dispatcher.py b/scripts/dispatcher.py
index 0770513..c073485 100644
--- a/scripts/dispatcher.py
+++ b/scripts/dispatcher.py
@@ -25,6 +25,7 @@ import shlex
 import sqlite3
 import subprocess
 import sys
+import threading
 import time
 from contextlib import closing
 from datetime import datetime, timezone
@@ -59,6 +60,9 @@ DEFAULT_WORKERS = [
 ]
 WORKERS = json.loads(os.environ.get("COSMA_QC_WORKERS", json.dumps(DEFAULT_WORKERS)))
 
+_worker_lock = threading.Lock()
+_reserved_vram = {w["host"]: 0 for w in WORKERS}
+
 
 def db() -> sqlite3.Connection:
     conn = sqlite3.connect(DB_PATH, isolation_level=None)
@@ -84,12 +88,24 @@ def worker_free_vram_mib(worker: dict) -> int:
 
 
 def pick_worker(estimated_vram_mib: int) -> dict | None:
-    best = None
-    for w in WORKERS:
-        free = worker_free_vram_mib(w)
-        if free >= estimated_vram_mib and (best is None or free > best[0]):
-            best = (free, w)
-    return best[1] if best else None
+    """Pick worker with most effective free VRAM (actual free minus local reservations)
+    and reserve the estimated VRAM. Returns None if none fit."""
+    with _worker_lock:
+        best = None
+        for w in WORKERS:
+            free = worker_free_vram_mib(w) - _reserved_vram.get(w["host"], 0)
+            if free >= estimated_vram_mib and (best is None or free > best[0]):
+                best = (free, w)
+        if best:
+            _reserved_vram[best[1]["host"]] = _reserved_vram.get(best[1]["host"], 0) + estimated_vram_mib
+            return best[1]
+        return None
+
+
+def release_worker(worker: dict, estimated_vram_mib: int):
+    with _worker_lock:
+        h = worker["host"]
+        _reserved_vram[h] = max(0, _reserved_vram.get(h, 0) - estimated_vram_mib)
 
 
 def estimate_vram_mib(frame_count: int) -> int:
@@ -399,6 +415,8 @@ def run_one(job: sqlite3.Row) -> bool:
     estimated = estimate_vram_mib(job["frame_count"] or 150)
     worker = pick_worker(estimated)
     if not worker:
+        # release claim so the job can be re-tried by main loop
+        set_status(job_id, status="queued")
         return False
     set_status(job_id, status="extracting", worker_host=worker["host"], started_at=_now_iso())
     try:
@@ -412,21 +430,42 @@
         _maybe_create_per_auv_stitch(job_id)
     except Exception as e:
         set_status(job_id, status="error", error=str(e)[:2000], finished_at=_now_iso())
+    finally:
+        release_worker(worker, estimated)
     return True
 
 
 def pop_queued() -> sqlite3.Row | None:
+    """Atomic claim: grab a queued job and mark it 'claimed' to prevent double-dispatch."""
     with closing(db()) as conn:
-        return conn.execute(
-            "SELECT * FROM jobs WHERE status='queued' ORDER BY created_at LIMIT 1"
-        ).fetchone()
+        try:
+            conn.execute("BEGIN IMMEDIATE")
+            row = conn.execute(
+                "SELECT * FROM jobs WHERE status='queued' ORDER BY created_at LIMIT 1"
+            ).fetchone()
+            if row:
+                conn.execute("UPDATE jobs SET status='claimed' WHERE id=?", (row["id"],))
+            conn.execute("COMMIT")
+        except Exception:
+            conn.execute("ROLLBACK")
+            return None
+        return row
 
 
 def pop_queued_stitch() -> sqlite3.Row | None:
     with closing(db()) as conn:
-        return conn.execute(
-            "SELECT * FROM stitches WHERE status='queued' ORDER BY created_at LIMIT 1"
-        ).fetchone()
+        try:
+            conn.execute("BEGIN IMMEDIATE")
+            row = conn.execute(
+                "SELECT * FROM stitches WHERE status='queued' ORDER BY created_at LIMIT 1"
+            ).fetchone()
+            if row:
+                conn.execute("UPDATE stitches SET status='claimed' WHERE id=?", (row["id"],))
+            conn.execute("COMMIT")
+        except Exception:
+            conn.execute("ROLLBACK")
+            return None
+        return row
 
 
 def write_heartbeat():
@@ -437,23 +476,48 @@
         pass
 
 
+_heartbeat_stop = threading.Event()
+
+
+def _heartbeat_loop():
+    while not _heartbeat_stop.is_set():
+        write_heartbeat()
+        _heartbeat_stop.wait(5)
+
+
+def _run_one_thread(job: sqlite3.Row):
+    try:
+        if not run_one(job):
+            print(f" ↳ job #{job['id']}: pas de worker dispo, remis en queued")
+    except Exception as e:
+        set_status(job["id"], status="error", error=f"run_one thread crashed: {str(e)[:1500]}",
+                   finished_at=_now_iso())
+
+
 def main():
     print(f"cosma-qc dispatcher · DB={DB_PATH} · workers={[w['host'] for w in WORKERS]}")
+    threading.Thread(target=_heartbeat_loop, daemon=True).start()
+    active: set[threading.Thread] = set()
+    max_parallel = max(1, len(WORKERS))
     while True:
-        write_heartbeat()
-        job = pop_queued()
-        if job:
-            print(f"→ job #{job['id']} ({job['auv']}/{job['gopro_serial']}/{job['segment_label']})")
-            if not run_one(job):
-                print(" ↳ pas de worker dispo, retry dans 30s")
-                time.sleep(30)
-            continue
-        stitch = pop_queued_stitch()
-        if stitch:
-            label = f"{stitch['level']} {stitch['auv'] or ''} acq#{stitch['acquisition_id']}"
-            print(f"→ stitch #{stitch['id']} ({label})")
-            run_one_stitch(stitch)
-            continue
+        # drain finished threads
+        active = {t for t in active if t.is_alive()}
+        if len(active) < max_parallel:
+            job = pop_queued()
+            if job:
+                print(f"→ job #{job['id']} ({job['auv']}/{job['gopro_serial']}/{job['segment_label']}) [active={len(active) + 1}]")
+                t = threading.Thread(target=_run_one_thread, args=(job,), daemon=False)
+                t.start()
+                active.add(t)
+                # brief pause so the thread can reserve its worker before we pop another
+                time.sleep(0.5)
+                continue
+        stitch = pop_queued_stitch()
+        if stitch:
+            label = f"{stitch['level']} {stitch['auv'] or ''} acq#{stitch['acquisition_id']}"
+            print(f"→ stitch #{stitch['id']} ({label})")
+            run_one_stitch(stitch)
+            continue
         time.sleep(POLL_S)
 
 