dispatcher — parallel workers (threading) + heartbeat daemon + atomic claim

- threading: main loop spawns run_one in a thread per queued job; up to len(WORKERS) concurrent.
- pick_worker: thread-safe VRAM reservation to avoid two threads picking the same GPU.
- pop_queued/pop_queued_stitch: atomic SELECT+UPDATE sous BEGIN IMMEDIATE (status claimed).
- Heartbeat daemon: thread qui ecrit dispatcher.heartbeat toutes les 5s (fini le faux dead pendant les jobs longs).
- run_one: libere la reservation VRAM sur finally (error/done/queued).
This commit is contained in:
Flag
2026-04-21 23:15:24 +00:00
parent 43e2e6836e
commit d4158b24bc

View File

@@ -25,6 +25,7 @@ import shlex
import sqlite3
import subprocess
import sys
import threading
import time
from contextlib import closing
from datetime import datetime, timezone
@@ -59,6 +60,9 @@ DEFAULT_WORKERS = [
]
WORKERS = json.loads(os.environ.get("COSMA_QC_WORKERS", json.dumps(DEFAULT_WORKERS)))
_worker_lock = threading.Lock()
_reserved_vram = {w["host"]: 0 for w in WORKERS}
def db() -> sqlite3.Connection:
conn = sqlite3.connect(DB_PATH, isolation_level=None)
@@ -84,12 +88,24 @@ def worker_free_vram_mib(worker: dict) -> int:
def pick_worker(estimated_vram_mib: int) -> dict | None:
    """Pick the worker with the most effective free VRAM and reserve it.

    Effective free VRAM = actual free VRAM reported by the host minus the
    local reservations other dispatcher threads have already placed on it.
    The whole select+reserve step runs under _worker_lock so two threads
    cannot both claim the same headroom on one GPU host.

    Note: this source is a rendered diff; the pre-commit lines
    (`free = worker_free_vram_mib(w)` without the reservation subtraction,
    and the early `return best[1] if best else None`) are stale residue and
    are dropped here — keeping them would skip the reservation entirely.

    Args:
        estimated_vram_mib: VRAM the job is expected to need, in MiB.

    Returns:
        The chosen worker dict (with its VRAM reserved), or None when no
        worker has enough free VRAM. A successful pick MUST be paired with
        a later release_worker() call to undo the reservation.
    """
    with _worker_lock:
        best = None
        for w in WORKERS:
            # Subtract local reservations so concurrent picks don't all see
            # the same "free" VRAM and oversubscribe a host.
            free = worker_free_vram_mib(w) - _reserved_vram.get(w["host"], 0)
            if free >= estimated_vram_mib and (best is None or free > best[0]):
                best = (free, w)
        if best:
            _reserved_vram[best[1]["host"]] = (
                _reserved_vram.get(best[1]["host"], 0) + estimated_vram_mib
            )
            return best[1]
        return None
def release_worker(worker: dict, estimated_vram_mib: int):
    """Undo a VRAM reservation made by pick_worker for this worker.

    The reservation counter is clamped at zero so a double release can
    never drive it negative.
    """
    host = worker["host"]
    with _worker_lock:
        remaining = _reserved_vram.get(host, 0) - estimated_vram_mib
        _reserved_vram[host] = remaining if remaining > 0 else 0
def estimate_vram_mib(frame_count: int) -> int:
@@ -399,6 +415,8 @@ def run_one(job: sqlite3.Row) -> bool:
estimated = estimate_vram_mib(job["frame_count"] or 150)
worker = pick_worker(estimated)
if not worker:
# release claim so the job can be re-tried by main loop
set_status(job_id, status="queued")
return False
set_status(job_id, status="extracting", worker_host=worker["host"], started_at=_now_iso())
try:
@@ -412,21 +430,42 @@ def run_one(job: sqlite3.Row) -> bool:
_maybe_create_per_auv_stitch(job_id)
except Exception as e:
set_status(job_id, status="error", error=str(e)[:2000], finished_at=_now_iso())
finally:
release_worker(worker, estimated)
return True
def pop_queued() -> sqlite3.Row | None:
    """Atomically claim the oldest queued job.

    BEGIN IMMEDIATE takes SQLite's write lock up front, so the SELECT and
    the UPDATE run as one unit and two dispatcher threads/processes cannot
    claim the same job. The claimed row is flipped to status='claimed'
    before being returned, preventing double-dispatch.

    Note: this source is a rendered diff; the stale pre-commit line
    (`return conn.execute(`) is residue of the old non-atomic version and
    is dropped here.

    Returns:
        The claimed job row, or None when the queue is empty or the
        transaction failed (e.g. the database was locked by another writer).
    """
    with closing(db()) as conn:
        try:
            conn.execute("BEGIN IMMEDIATE")
            row = conn.execute(
                "SELECT * FROM jobs WHERE status='queued' ORDER BY created_at LIMIT 1"
            ).fetchone()
            if row:
                conn.execute("UPDATE jobs SET status='claimed' WHERE id=?", (row["id"],))
            conn.execute("COMMIT")
        except Exception:
            # Best-effort rollback; treat any failure as "nothing claimed"
            # so the main loop simply retries later.
            conn.execute("ROLLBACK")
            return None
        return row
def pop_queued_stitch() -> sqlite3.Row | None:
    """Atomically claim the oldest queued stitch (same protocol as pop_queued).

    SELECT+UPDATE run under BEGIN IMMEDIATE so only one dispatcher can flip
    a given stitch from 'queued' to 'claimed'.

    Note: this source is a rendered diff; the stale pre-commit line
    (`return conn.execute(`) is residue of the old non-atomic version and
    is dropped here.

    Returns:
        The claimed stitch row, or None when none is queued or the
        transaction failed.
    """
    with closing(db()) as conn:
        try:
            conn.execute("BEGIN IMMEDIATE")
            row = conn.execute(
                "SELECT * FROM stitches WHERE status='queued' ORDER BY created_at LIMIT 1"
            ).fetchone()
            if row:
                conn.execute("UPDATE stitches SET status='claimed' WHERE id=?", (row["id"],))
            conn.execute("COMMIT")
        except Exception:
            # Best-effort rollback; a failed claim just means "try again".
            conn.execute("ROLLBACK")
            return None
        return row
def write_heartbeat():
@@ -437,16 +476,41 @@ def write_heartbeat():
pass
# Set at shutdown to stop the background heartbeat thread.
_heartbeat_stop = threading.Event()


def _heartbeat_loop():
    """Heartbeat pump: stamp liveness every 5 seconds until asked to stop."""
    stop = _heartbeat_stop
    while not stop.is_set():
        write_heartbeat()
        stop.wait(5)
def _run_one_thread(job: sqlite3.Row):
    """Thread entry point: run one job and record any crash on the job row.

    A crash in run_one must not kill the thread silently — the exception is
    captured and written to the job as status='error' so it stays visible.
    """
    try:
        if run_one(job):
            return
        print(f" ↳ job #{job['id']}: pas de worker dispo, remis en queued")
    except Exception as e:
        set_status(job["id"], status="error",
                   error=f"run_one thread crashed: {str(e)[:1500]}",
                   finished_at=_now_iso())
def main():
print(f"cosma-qc dispatcher · DB={DB_PATH} · workers={[w['host'] for w in WORKERS]}")
threading.Thread(target=_heartbeat_loop, daemon=True).start()
active: set[threading.Thread] = set()
max_parallel = max(1, len(WORKERS))
while True:
write_heartbeat()
# drain finished threads
active = {t for t in active if t.is_alive()}
if len(active) < max_parallel:
job = pop_queued()
if job:
print(f"→ job #{job['id']} ({job['auv']}/{job['gopro_serial']}/{job['segment_label']})")
if not run_one(job):
print(" ↳ pas de worker dispo, retry dans 30s")
time.sleep(30)
print(f"→ job #{job['id']} ({job['auv']}/{job['gopro_serial']}/{job['segment_label']}) [active={len(active) + 1}]")
t = threading.Thread(target=_run_one_thread, args=(job,), daemon=False)
t.start()
active.add(t)
# brief pause so the thread can reserve its worker before we pop another
time.sleep(0.5)
continue
stitch = pop_queued_stitch()
if stitch: