dispatcher — parallel workers (threading) + heartbeat daemon + atomic claim

- threading: main loop spawns run_one in a thread per queued job; up to len(WORKERS) concurrent.
- pick_worker: thread-safe VRAM reservation to avoid two threads picking the same GPU.
- pop_queued/pop_queued_stitch: atomic SELECT+UPDATE under BEGIN IMMEDIATE (status → claimed).
- Heartbeat daemon: a thread writes dispatcher.heartbeat every 5s (no more false "dead" status during long jobs).
- run_one: releases the VRAM reservation in a finally block (error/done/queued).
This commit is contained in:
Flag
2026-04-21 23:15:24 +00:00
parent 43e2e6836e
commit d4158b24bc

View File

@@ -25,6 +25,7 @@ import shlex
import sqlite3 import sqlite3
import subprocess import subprocess
import sys import sys
import threading
import time import time
from contextlib import closing from contextlib import closing
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -59,6 +60,9 @@ DEFAULT_WORKERS = [
] ]
WORKERS = json.loads(os.environ.get("COSMA_QC_WORKERS", json.dumps(DEFAULT_WORKERS))) WORKERS = json.loads(os.environ.get("COSMA_QC_WORKERS", json.dumps(DEFAULT_WORKERS)))
_worker_lock = threading.Lock()
_reserved_vram = {w["host"]: 0 for w in WORKERS}
def db() -> sqlite3.Connection: def db() -> sqlite3.Connection:
conn = sqlite3.connect(DB_PATH, isolation_level=None) conn = sqlite3.connect(DB_PATH, isolation_level=None)
@@ -84,12 +88,24 @@ def worker_free_vram_mib(worker: dict) -> int:
def pick_worker(estimated_vram_mib: int) -> dict | None:
    """Pick the worker with the most *effective* free VRAM and reserve the estimate.

    Effective free = VRAM actually free on the GPU minus what this dispatcher has
    already reserved locally for in-flight jobs (``worker_free_vram_mib`` cannot
    see jobs we just launched that have not allocated yet). Selection and
    reservation happen under ``_worker_lock`` so two threads can never claim the
    same headroom on one GPU.

    Returns the chosen worker dict, or None if no worker fits the estimate.
    """
    with _worker_lock:
        best_free = -1
        best_worker = None
        for w in WORKERS:
            free = worker_free_vram_mib(w) - _reserved_vram.get(w["host"], 0)
            if free >= estimated_vram_mib and free > best_free:
                best_free = free
                best_worker = w
        if best_worker is not None:
            # reserve before releasing the lock so the next pick sees it
            _reserved_vram[best_worker["host"]] = (
                _reserved_vram.get(best_worker["host"], 0) + estimated_vram_mib
            )
        return best_worker
def release_worker(worker: dict, estimated_vram_mib: int):
    """Return ``estimated_vram_mib`` of reserved VRAM to ``worker``'s pool.

    Clamped at zero so a double release can never drive the bookkeeping
    negative. Thread-safe via ``_worker_lock``; counterpart of ``pick_worker``.
    """
    with _worker_lock:
        host = worker["host"]
        _reserved_vram[host] = max(0, _reserved_vram.get(host, 0) - estimated_vram_mib)
def estimate_vram_mib(frame_count: int) -> int: def estimate_vram_mib(frame_count: int) -> int:
@@ -399,6 +415,8 @@ def run_one(job: sqlite3.Row) -> bool:
estimated = estimate_vram_mib(job["frame_count"] or 150) estimated = estimate_vram_mib(job["frame_count"] or 150)
worker = pick_worker(estimated) worker = pick_worker(estimated)
if not worker: if not worker:
# release claim so the job can be re-tried by main loop
set_status(job_id, status="queued")
return False return False
set_status(job_id, status="extracting", worker_host=worker["host"], started_at=_now_iso()) set_status(job_id, status="extracting", worker_host=worker["host"], started_at=_now_iso())
try: try:
@@ -412,21 +430,42 @@ def run_one(job: sqlite3.Row) -> bool:
_maybe_create_per_auv_stitch(job_id) _maybe_create_per_auv_stitch(job_id)
except Exception as e: except Exception as e:
set_status(job_id, status="error", error=str(e)[:2000], finished_at=_now_iso()) set_status(job_id, status="error", error=str(e)[:2000], finished_at=_now_iso())
finally:
release_worker(worker, estimated)
return True return True
def pop_queued() -> sqlite3.Row | None:
    """Atomically claim the oldest queued job.

    BEGIN IMMEDIATE takes the write lock up front, so the SELECT and the flip
    to ``status='claimed'`` are one atomic step — two dispatcher threads can
    never grab the same job.

    Returns the claimed row, or None when the queue is empty or the
    transaction could not be taken (e.g. database busy).
    """
    with closing(db()) as conn:
        try:
            conn.execute("BEGIN IMMEDIATE")
            row = conn.execute(
                "SELECT * FROM jobs WHERE status='queued' ORDER BY created_at LIMIT 1"
            ).fetchone()
            if row:
                conn.execute("UPDATE jobs SET status='claimed' WHERE id=?", (row["id"],))
            conn.execute("COMMIT")
        except Exception:
            # Roll back only if a transaction is actually open: an unguarded
            # ROLLBACK after a failed BEGIN would itself raise and mask the
            # original error.
            if conn.in_transaction:
                conn.execute("ROLLBACK")
            return None
        return row
def pop_queued_stitch() -> sqlite3.Row | None:
    """Atomically claim the oldest queued stitch (same protocol as ``pop_queued``).

    Returns the claimed row, or None when nothing is queued or the
    transaction could not be taken.
    """
    with closing(db()) as conn:
        try:
            conn.execute("BEGIN IMMEDIATE")
            row = conn.execute(
                "SELECT * FROM stitches WHERE status='queued' ORDER BY created_at LIMIT 1"
            ).fetchone()
            if row:
                conn.execute("UPDATE stitches SET status='claimed' WHERE id=?", (row["id"],))
            conn.execute("COMMIT")
        except Exception:
            # guard ROLLBACK: raising here would mask the original failure
            if conn.in_transaction:
                conn.execute("ROLLBACK")
            return None
        return row
def write_heartbeat(): def write_heartbeat():
@@ -437,23 +476,48 @@ def write_heartbeat():
pass pass
_heartbeat_stop = threading.Event()
def _heartbeat_loop():
while not _heartbeat_stop.is_set():
write_heartbeat()
_heartbeat_stop.wait(5)
def _run_one_thread(job: sqlite3.Row):
    """Thread target wrapping ``run_one``: never let an exception kill the thread.

    On a False return (no worker available) ``run_one`` has already re-queued
    the job; on a crash the job is marked 'error' so it cannot stay stuck in
    'claimed' forever.
    """
    try:
        if not run_one(job):
            print(f" ↳ job #{job['id']}: pas de worker dispo, remis en queued")
    except Exception as e:
        set_status(
            job["id"],
            status="error",
            error=f"run_one thread crashed: {str(e)[:1500]}",
            finished_at=_now_iso(),
        )
def main():
    """Dispatcher main loop: fan jobs out to worker threads; stitches stay serial.

    At most one in-flight job per configured worker. Heartbeats are written by
    a dedicated daemon thread, not by this loop.
    """
    print(f"cosma-qc dispatcher · DB={DB_PATH} · workers={[w['host'] for w in WORKERS]}")
    threading.Thread(target=_heartbeat_loop, daemon=True).start()
    active: set[threading.Thread] = set()
    max_parallel = max(1, len(WORKERS))
    while True:
        # drain finished threads
        active = {t for t in active if t.is_alive()}
        if len(active) < max_parallel:
            job = pop_queued()
            if job:
                print(f"→ job #{job['id']} ({job['auv']}/{job['gopro_serial']}/{job['segment_label']}) [active={len(active) + 1}]")
                # non-daemon so an in-flight job is not killed on interpreter exit
                t = threading.Thread(target=_run_one_thread, args=(job,), daemon=False)
                t.start()
                active.add(t)
                # brief pause so the thread can reserve its worker before we pop another
                time.sleep(0.5)
                continue
        stitch = pop_queued_stitch()
        if stitch:
            label = f"{stitch['level']} {stitch['auv'] or ''} acq#{stitch['acquisition_id']}"
            print(f"→ stitch #{stitch['id']} ({label})")
            run_one_stitch(stitch)
            continue
        time.sleep(POLL_S)