dispatcher — parallel workers (threading) + heartbeat daemon + atomic claim
- threading: the main loop spawns run_one in a thread per queued job, up to len(WORKERS) concurrent. - pick_worker: thread-safe VRAM reservation so two threads can never pick the same GPU. - pop_queued/pop_queued_stitch: atomic SELECT+UPDATE under BEGIN IMMEDIATE (status 'claimed'). - Heartbeat daemon: a thread writes dispatcher.heartbeat every 5 s (no more false "dead" state during long jobs). - run_one: releases the VRAM reservation in a finally block (error/done/queued).
This commit is contained in:
@@ -25,6 +25,7 @@ import shlex
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from contextlib import closing
|
||||
from datetime import datetime, timezone
|
||||
@@ -59,6 +60,9 @@ DEFAULT_WORKERS = [
|
||||
]
|
||||
# Worker pool: JSON list of worker descriptors; override DEFAULT_WORKERS via
# the COSMA_QC_WORKERS environment variable (same JSON shape).
WORKERS = json.loads(os.environ.get("COSMA_QC_WORKERS", json.dumps(DEFAULT_WORKERS)))

# Guards _reserved_vram: pick_worker/release_worker run from several job
# threads at once.
_worker_lock = threading.Lock()
# host -> MiB of VRAM currently reserved by in-flight jobs on that host.
_reserved_vram = {w["host"]: 0 for w in WORKERS}
|
||||
|
||||
def db() -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(DB_PATH, isolation_level=None)
|
||||
@@ -84,12 +88,24 @@ def worker_free_vram_mib(worker: dict) -> int:
|
||||
|
||||
|
||||
def pick_worker(estimated_vram_mib: int) -> dict | None:
    """Pick the worker with the most effective free VRAM (actual free minus
    local reservations) and reserve the estimated VRAM. Returns None if none fit.

    Thread-safe: selection and reservation happen under _worker_lock so two
    dispatcher threads cannot reserve the same VRAM twice.
    """
    with _worker_lock:
        # Effective headroom per worker = reported free minus what our own
        # in-flight jobs have already reserved locally.
        scored = [
            (worker_free_vram_mib(w) - _reserved_vram.get(w["host"], 0), w)
            for w in WORKERS
        ]
        eligible = [(headroom, w) for headroom, w in scored if headroom >= estimated_vram_mib]
        if not eligible:
            return None
        # max() keeps the first of equals, matching a strict ">" scan.
        chosen = max(eligible, key=lambda pair: pair[0])[1]
        host = chosen["host"]
        _reserved_vram[host] = _reserved_vram.get(host, 0) + estimated_vram_mib
        return chosen
|
||||
|
||||
|
||||
def release_worker(worker: dict, estimated_vram_mib: int):
    """Return a job's VRAM reservation for *worker*'s host to the pool.

    Clamped at zero so a double release can never drive the bookkeeping
    negative.
    """
    host = worker["host"]
    with _worker_lock:
        remaining = _reserved_vram.get(host, 0) - estimated_vram_mib
        _reserved_vram[host] = max(0, remaining)
|
||||
|
||||
def estimate_vram_mib(frame_count: int) -> int:
|
||||
@@ -399,6 +415,8 @@ def run_one(job: sqlite3.Row) -> bool:
|
||||
estimated = estimate_vram_mib(job["frame_count"] or 150)
|
||||
worker = pick_worker(estimated)
|
||||
if not worker:
|
||||
# release claim so the job can be re-tried by main loop
|
||||
set_status(job_id, status="queued")
|
||||
return False
|
||||
set_status(job_id, status="extracting", worker_host=worker["host"], started_at=_now_iso())
|
||||
try:
|
||||
@@ -412,21 +430,42 @@ def run_one(job: sqlite3.Row) -> bool:
|
||||
_maybe_create_per_auv_stitch(job_id)
|
||||
except Exception as e:
|
||||
set_status(job_id, status="error", error=str(e)[:2000], finished_at=_now_iso())
|
||||
finally:
|
||||
release_worker(worker, estimated)
|
||||
return True
|
||||
|
||||
|
||||
def pop_queued() -> sqlite3.Row | None:
|
||||
"""Atomic claim: grab a queued job and mark it 'claimed' to prevent double-dispatch."""
|
||||
with closing(db()) as conn:
|
||||
return conn.execute(
|
||||
"SELECT * FROM jobs WHERE status='queued' ORDER BY created_at LIMIT 1"
|
||||
).fetchone()
|
||||
try:
|
||||
conn.execute("BEGIN IMMEDIATE")
|
||||
row = conn.execute(
|
||||
"SELECT * FROM jobs WHERE status='queued' ORDER BY created_at LIMIT 1"
|
||||
).fetchone()
|
||||
if row:
|
||||
conn.execute("UPDATE jobs SET status='claimed' WHERE id=?", (row["id"],))
|
||||
conn.execute("COMMIT")
|
||||
except Exception:
|
||||
conn.execute("ROLLBACK")
|
||||
return None
|
||||
return row
|
||||
|
||||
|
||||
def pop_queued_stitch() -> sqlite3.Row | None:
|
||||
with closing(db()) as conn:
|
||||
return conn.execute(
|
||||
"SELECT * FROM stitches WHERE status='queued' ORDER BY created_at LIMIT 1"
|
||||
).fetchone()
|
||||
try:
|
||||
conn.execute("BEGIN IMMEDIATE")
|
||||
row = conn.execute(
|
||||
"SELECT * FROM stitches WHERE status='queued' ORDER BY created_at LIMIT 1"
|
||||
).fetchone()
|
||||
if row:
|
||||
conn.execute("UPDATE stitches SET status='claimed' WHERE id=?", (row["id"],))
|
||||
conn.execute("COMMIT")
|
||||
except Exception:
|
||||
conn.execute("ROLLBACK")
|
||||
return None
|
||||
return row
|
||||
|
||||
|
||||
def write_heartbeat():
|
||||
@@ -437,23 +476,48 @@ def write_heartbeat():
|
||||
pass
|
||||
|
||||
|
||||
_heartbeat_stop = threading.Event()
|
||||
|
||||
|
||||
def _heartbeat_loop():
    """Daemon loop: refresh the dispatcher heartbeat roughly every 5 s.

    Exits as soon as _heartbeat_stop is set; the wait() doubles as an
    interruptible sleep.
    """
    while True:
        if _heartbeat_stop.is_set():
            break
        write_heartbeat()
        _heartbeat_stop.wait(5)
|
||||
|
||||
|
||||
def _run_one_thread(job: sqlite3.Row):
    """Thread entry point wrapping run_one; never lets an exception escape.

    Any crash is recorded on the job row as an error so the dispatcher's
    worker threads die silently instead of taking the process down.
    """
    job_id = job["id"]
    try:
        dispatched = run_one(job)
    except Exception as e:
        set_status(job_id, status="error",
                   error=f"run_one thread crashed: {str(e)[:1500]}",
                   finished_at=_now_iso())
        return
    if not dispatched:
        # run_one put the job back to 'queued' (no worker had enough VRAM).
        print(f" ↳ job #{job_id}: pas de worker dispo, remis en queued")
||||
|
||||
|
||||
def main():
    """Dispatcher main loop: start the heartbeat daemon, then alternate
    between dispatching queued jobs into worker threads and running stitches
    inline."""
    print(f"cosma-qc dispatcher · DB={DB_PATH} · workers={[w['host'] for w in WORKERS]}")
    # Heartbeat runs in its own daemon thread so long-running jobs no longer
    # make the dispatcher look dead.
    threading.Thread(target=_heartbeat_loop, daemon=True).start()
    active: set[threading.Thread] = set()
    # At most one concurrent job per configured worker host.
    max_parallel = max(1, len(WORKERS))
    while True:
        # Belt and braces: the daemon also writes this every 5 s.
        write_heartbeat()
        # drain finished threads
        active = {t for t in active if t.is_alive()}
        if len(active) < max_parallel:
            job = pop_queued()
            if job:
                print(f"→ job #{job['id']} ({job['auv']}/{job['gopro_serial']}/{job['segment_label']}) [active={len(active) + 1}]")
                # daemon=False: let in-flight jobs finish on interpreter exit.
                t = threading.Thread(target=_run_one_thread, args=(job,), daemon=False)
                t.start()
                active.add(t)
                # brief pause so the thread can reserve its worker before we pop another
                time.sleep(0.5)
                continue
        # Stitches run inline (sequentially) in the main loop.
        stitch = pop_queued_stitch()
        if stitch:
            label = f"{stitch['level']} {stitch['auv'] or ''} acq#{stitch['acquisition_id']}"
            print(f"→ stitch #{stitch['id']} ({label})")
            run_one_stitch(stitch)
            continue
        time.sleep(POLL_S)
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user