"""cosma-qc dashboard — FastAPI application for tracking reconstruction jobs."""
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import sqlite3
|
||
from contextlib import asynccontextmanager, closing
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
|
||
def _fmt_dur(seconds: float) -> str:
|
||
if seconds is None or seconds < 0:
|
||
return ""
|
||
s = int(seconds)
|
||
if s < 60:
|
||
return f"{s}s"
|
||
m, s = divmod(s, 60)
|
||
if m < 60:
|
||
return f"{m}m{s:02d}s" if s else f"{m}m"
|
||
h, m = divmod(m, 60)
|
||
if h < 24:
|
||
return f"{h}h{m:02d}m" if m else f"{h}h"
|
||
d, h = divmod(h, 24)
|
||
return f"{d}d{h:02d}h"
|
||
|
||
|
||
def _parse_ts(s: str | None) -> datetime | None:
|
||
if not s:
|
||
return None
|
||
try:
|
||
return datetime.fromisoformat(s.replace("Z", "+00:00"))
|
||
except Exception:
|
||
return None
|
||
|
||
|
||
def _job_duration_s(job: sqlite3.Row) -> int:
    """Elapsed seconds between a job's start and its finish (or now, if still running)."""
    begun = _parse_ts(job["started_at"])
    if begun is None:
        # Never started: no duration to report.
        return 0
    ended = _parse_ts(job["finished_at"]) or datetime.now(timezone.utc)
    # Rows may carry naive timestamps; pin both ends to UTC before subtracting.
    if begun.tzinfo is None:
        begun = begun.replace(tzinfo=timezone.utc)
    if ended.tzinfo is None:
        ended = ended.replace(tzinfo=timezone.utc)
    return int((ended - begun).total_seconds())
|
||
|
||
from fastapi import FastAPI, Form, HTTPException, Request
|
||
from fastapi.responses import HTMLResponse, JSONResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.templating import Jinja2Templates
|
||
|
||
# SQLite database shared with the dispatcher process; COSMA_QC_DB overrides the path.
DB_PATH = Path(os.environ.get("COSMA_QC_DB", "/var/lib/cosma-qc/jobs.db"))
# GPU worker inventory as a JSON list; COSMA_QC_WORKERS overrides the built-in
# two-host default. Each entry carries at least host + ssh_alias (+ gpu label).
WORKERS = json.loads(os.environ.get("COSMA_QC_WORKERS", json.dumps([
    {"host": "192.168.0.87", "ssh_alias": "gpu", "gpu": "RTX 3060 12GB"},
    {"host": "192.168.0.84", "ssh_alias": "cosma-vm","gpu": "RTX 3090 24GB"},
])))

# Job lifecycle states as written to jobs.status / stitches.status.
STATUSES = ("queued", "extracting", "running", "done", "error")
|
||
|
||
|
||
def db() -> sqlite3.Connection:
    """Open an autocommit connection to the jobs DB, creating its directory if needed.

    WAL journaling allows the dispatcher and the web app to read/write
    concurrently; row_factory=Row gives name-based column access.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(DB_PATH, isolation_level=None)
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        conn.execute(pragma)
    conn.row_factory = sqlite3.Row
    return conn
|
||
|
||
|
||
def init_schema() -> None:
    """Create the acquisitions / jobs / stitches tables and indexes if missing.

    Idempotent (CREATE ... IF NOT EXISTS everywhere); invoked once at app
    startup from the lifespan hook. Note this only creates objects — it does
    not migrate columns on existing tables.
    """
    with closing(db()) as conn:
        conn.executescript("""
        CREATE TABLE IF NOT EXISTS acquisitions (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            source_path TEXT NOT NULL,
            created_at TEXT NOT NULL DEFAULT (datetime('now'))
        );

        CREATE TABLE IF NOT EXISTS jobs (
            id INTEGER PRIMARY KEY,
            acquisition_id INTEGER NOT NULL REFERENCES acquisitions(id) ON DELETE CASCADE,
            auv TEXT NOT NULL,
            gopro_serial TEXT NOT NULL,
            segment_label TEXT NOT NULL,
            video_paths TEXT NOT NULL,
            frame_count INTEGER,
            frames_dir TEXT,
            status TEXT NOT NULL DEFAULT 'queued',
            worker_host TEXT,
            viser_url TEXT,
            ply_path TEXT,
            progress INTEGER NOT NULL DEFAULT 0,
            log_tail TEXT,
            error TEXT,
            started_at TEXT,
            finished_at TEXT,
            created_at TEXT NOT NULL DEFAULT (datetime('now'))
        );

        CREATE INDEX IF NOT EXISTS jobs_status_idx ON jobs(status);
        CREATE INDEX IF NOT EXISTS jobs_acq_idx ON jobs(acquisition_id);

        CREATE TABLE IF NOT EXISTS stitches (
            id INTEGER PRIMARY KEY,
            acquisition_id INTEGER NOT NULL REFERENCES acquisitions(id) ON DELETE CASCADE,
            level TEXT NOT NULL DEFAULT 'per_auv',
            auv TEXT,
            input_job_ids TEXT NOT NULL DEFAULT '[]',
            input_stitch_ids TEXT NOT NULL DEFAULT '[]',
            output_ply TEXT,
            status TEXT NOT NULL DEFAULT 'queued',
            worker_host TEXT,
            started_at TEXT,
            finished_at TEXT,
            error TEXT,
            created_at TEXT NOT NULL DEFAULT (datetime('now'))
        );

        CREATE INDEX IF NOT EXISTS stitches_acq_idx ON stitches(acquisition_id);
        """)
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(_: FastAPI):
    """App lifespan hook: ensure the SQLite schema exists before serving; no shutdown work."""
    init_schema()
    yield
|
||
|
||
|
||
app = FastAPI(title="cosma-qc", lifespan=lifespan)
templates = Jinja2Templates(directory=Path(__file__).parent / "templates")
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
# NOTE(review): mounting at /docs shadows FastAPI's default Swagger UI path —
# presumably intentional (project docs instead of API docs); confirm.
app.mount("/docs", StaticFiles(directory=Path("/app/docs"), html=True), name="docs")


# url -> (probe timestamp, alive?) cache consulted by _viser_alive.
_viser_probe_cache: dict[str, tuple[float, bool]] = {}
_VISER_PROBE_TTL = 8.0  # seconds
|
||
|
||
|
||
def _viser_alive(url: str) -> bool:
    """Fast TCP check with short cache so we never surface a dead link in the dashboard."""
    import time as _t
    import socket
    from urllib.parse import urlparse

    now = _t.time()
    hit = _viser_probe_cache.get(url)
    if hit is not None and now - hit[0] < _VISER_PROBE_TTL:
        return hit[1]
    parsed = urlparse(url)
    if not parsed.hostname or not parsed.port:
        # URL without an explicit host:port — nothing to probe; not cached,
        # matching the previous early-return behavior.
        return False
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(0.4)
            sock.connect((parsed.hostname, parsed.port))
        up = True
    except OSError:
        up = False
    _viser_probe_cache[url] = (now, up)
    return up
|
||
|
||
|
||
def _build_acquisitions():
    """Assemble the template context: acquisitions with enriched job and stitch rows.

    Returns a list of dicts, one per acquisition, each carrying its jobs
    (annotated with durations, GP labels, thumbnail info, viewer URLs) and
    its stitches, newest acquisition first.

    NOTE(review): this reads columns not present in init_schema's DDL
    (video_duration_s, trimmed_head, trimmed_tail, glb_path) — presumably
    added by a migration elsewhere; only `step` access is guarded. Confirm.
    """
    with closing(db()) as conn:
        acqs = conn.execute(
            "SELECT * FROM acquisitions ORDER BY created_at DESC"
        ).fetchall()
        jobs = conn.execute(
            "SELECT * FROM jobs ORDER BY auv, gopro_serial, segment_label"
        ).fetchall()
        stitches = conn.execute(
            "SELECT * FROM stitches ORDER BY level DESC, auv"
        ).fetchall()

    # Assign GP1/GP2 labels per AUV by enumerating distinct serials in fixed order.
    # Maps (acquisition_id, auv) -> list of distinct GoPro serials seen.
    gp_by_serial: dict[tuple[int, str], list[str]] = {}
    for j in jobs:
        key = (j["acquisition_id"], j["auv"])
        serials = gp_by_serial.setdefault(key, [])
        if j["gopro_serial"] not in serials:
            serials.append(j["gopro_serial"])
    # Maps (acquisition_id, auv, serial) -> "GP1"/"GP2"/... (sorted-serial order).
    gp_label: dict[tuple[int, str, str], str] = {}
    for (acq_id, auv), serials in gp_by_serial.items():
        for idx, ser in enumerate(sorted(serials)):
            gp_label[(acq_id, auv, ser)] = f"GP{idx + 1}"

    by_acq: dict[int, list[dict]] = {}        # acquisition_id -> enriched job dicts
    by_acq_total: dict[int, int] = {}         # acquisition_id -> summed job seconds
    for j in jobs:
        d = dict(j)
        dur_s = _job_duration_s(j)
        d["_duration"] = _fmt_dur(dur_s)
        d["gp_label"] = gp_label.get((j["acquisition_id"], j["auv"], j["gopro_serial"]), "?")
        vid_dur_s = float(j["video_duration_s"] or 0)
        d["video_duration_fmt"] = _fmt_dur(int(vid_dur_s)) if vid_dur_s > 0 else "—"
        # Fix segment_label end-time when exiftool returned duration=0 at ingest.
        # Labels look like "HH:MM–HH:MM" (en dash); recompute the end from video duration.
        seg_label = j["segment_label"] or ""
        if vid_dur_s > 0 and "–" in seg_label:
            try:
                from datetime import datetime as _dt, timedelta as _td
                start_str = seg_label.split("–")[0].strip()
                t0 = _dt.strptime(start_str, "%H:%M")
                t1 = t0 + _td(seconds=vid_dur_s)
                seg_label = f"{start_str}–{t1.strftime('%H:%M')}"
            except Exception:
                # Best-effort cosmetic fix only — keep the original label on any parse error.
                pass
        d["segment_label"] = seg_label
        d["trimmed_total"] = (j["trimmed_head"] or 0) + (j["trimmed_tail"] or 0)
        thumb_path = DB_PATH.parent / "thumbnails" / f"job_{j['id']}.jpg"
        d["has_thumbnail"] = thumb_path.is_file()
        # Bust the browser cache on the mtime so the preview refreshes as the dispatcher re-copies it.
        d["thumb_ts"] = int(thumb_path.stat().st_mtime) if d["has_thumbnail"] else 0
        # Try the new column; fall back silently on old rows.
        try:
            d["step"] = j["step"]
        except (KeyError, IndexError):
            d["step"] = None
        # Mask the viser link when the demo.py that was serving it has since died.
        if j["status"] == "done" and j["viser_url"] and not _viser_alive(j["viser_url"]):
            d["viser_url"] = None
        # GLB download link: only when glb_path is set in DB (conversion confirmed)
        glb_url = None
        glb_path = d.get("glb_path")
        if glb_path and d.get("worker_host"):
            # Rebuild the URL from the last two path components, served by the
            # worker's static file server on port 8300.
            dir_name = glb_path.rstrip("/").split("/")[-2]
            file_name = glb_path.rstrip("/").split("/")[-1]
            glb_url = f"http://{d['worker_host']}:8300/{dir_name}/{file_name}"
        d["glb_url"] = glb_url
        by_acq.setdefault(j["acquisition_id"], []).append(d)
        by_acq_total[j["acquisition_id"]] = by_acq_total.get(j["acquisition_id"], 0) + dur_s

    stitches_by_acq: dict[int, list[dict]] = {}
    for s in stitches:
        d = dict(s)
        start = _parse_ts(s["started_at"])
        # Unfinished-but-running stitches get a live duration against "now".
        end = _parse_ts(s["finished_at"]) or (
            datetime.now(timezone.utc) if s["status"] == "running" else None
        )
        if start and end:
            # Pin naive timestamps to UTC, mirroring _job_duration_s.
            if start.tzinfo is None:
                start = start.replace(tzinfo=timezone.utc)
            if end.tzinfo is None:
                end = end.replace(tzinfo=timezone.utc)
            d["_duration"] = _fmt_dur(int((end - start).total_seconds()))
        else:
            d["_duration"] = ""
        stitches_by_acq.setdefault(s["acquisition_id"], []).append(d)

    return [
        {
            "id": acq["id"],
            "name": acq["name"],
            "source_path": acq["source_path"],
            "jobs": by_acq.get(acq["id"], []),
            "stitches": stitches_by_acq.get(acq["id"], []),
            "total_duration": _fmt_dur(by_acq_total.get(acq["id"], 0)),
        }
        for acq in acqs
    ]
|
||
|
||
|
||
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the full dashboard page."""
    context = {
        "request": request,
        "acquisitions": _build_acquisitions(),
        "workers": WORKERS,
    }
    return templates.TemplateResponse("index.html", context)
|
||
|
||
|
||
@app.get("/api/jobs")
async def list_jobs():
    """Return the 500 most recent jobs as raw JSON rows."""
    with closing(db()) as conn:
        cursor = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC LIMIT 500")
        return [dict(row) for row in cursor.fetchall()]
|
||
|
||
|
||
@app.get("/partials/jobs", response_class=HTMLResponse)
async def partial_jobs(request: Request):
    """Render only the jobs-table fragment (for partial page refreshes)."""
    context = {"request": request, "acquisitions": _build_acquisitions()}
    return templates.TemplateResponse("_jobs_table.html", context)
|
||
|
||
|
||
@app.get("/partials/monitor", response_class=HTMLResponse)
async def partial_monitor(request: Request):
    """Render the worker/dispatcher monitoring fragment, probing all workers concurrently."""
    stats = await asyncio.gather(*(_worker_stats(w) for w in WORKERS))
    context = {
        "request": request,
        "workers": stats,
        "dispatcher": _dispatcher_status(),
    }
    return templates.TemplateResponse("_monitor.html", context)
|
||
|
||
|
||
def _dispatcher_status() -> dict:
    """Report dispatcher liveness from its heartbeat file.

    The dispatcher writes an ISO timestamp to DB_PATH.parent/dispatcher.heartbeat;
    it is considered alive when that timestamp is under 30 seconds old.

    Returns {"alive": bool, "age_s": int | None} — age_s is None when the
    heartbeat file is missing/unreadable or its content is unparseable.
    """
    hb = DB_PATH.parent / "dispatcher.heartbeat"
    try:
        raw = hb.read_text()
    except (OSError, ValueError):
        # Missing/unreadable (or non-UTF-8) heartbeat file: dispatcher not running.
        # (Narrowed from a blanket `except Exception: pass` that hid real bugs.)
        return {"alive": False, "age_s": None}
    ts = _parse_ts(raw.strip())
    if ts is None:
        return {"alive": False, "age_s": None}
    if ts.tzinfo is None:
        # Legacy naive timestamps are treated as UTC.
        ts = ts.replace(tzinfo=timezone.utc)
    age = int((datetime.now(timezone.utc) - ts).total_seconds())
    return {"alive": age < 30, "age_s": age}
|
||
|
||
|
||
async def _worker_stats(worker: dict) -> dict:
    """Collect GPU and root-disk stats from one worker over SSH.

    Returns the worker dict augmented with live metrics and online=True,
    or with online=False plus a truncated error string when the probe fails
    (unreachable host, timeout, unexpected output).
    """
    alias = worker["ssh_alias"]
    try:
        proc = await asyncio.create_subprocess_exec(
            "ssh", "-o", "ConnectTimeout=3", "-o", "BatchMode=yes", alias,
            "nvidia-smi --query-gpu=memory.used,memory.total,utilization.gpu,temperature.gpu,power.draw --format=csv,noheader,nounits && df -h / | tail -1",
            stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
        )
        out, _ = await asyncio.wait_for(proc.communicate(), timeout=5)
        lines = out.decode().strip().splitlines()
        # Line 1: comma-separated nvidia-smi fields; line 2: `df -h /` summary row.
        gpu = [field.strip() for field in lines[0].split(",")] if lines else ["?"] * 5
        disk = lines[1].split() if len(lines) > 1 else ["?"] * 6

        def _int(v: str):
            # Parse "123" / "123.45" to int; anything else (e.g. "?") -> None.
            try:
                return int(float(v))
            except Exception:
                return None

        def _gpu_field(i: int):
            # Bounds-checked numeric extraction from the nvidia-smi line.
            return _int(gpu[i]) if len(gpu) > i else None

        return {
            **worker,
            "online": True,
            "vram_used_mib": _gpu_field(0),
            "vram_total_mib": _gpu_field(1),
            "gpu_util_pct": _gpu_field(2),
            "gpu_temp_c": _gpu_field(3),
            "gpu_power_w": _gpu_field(4),
            "disk_used_pct": disk[4] if len(disk) > 4 else "?",
            "disk_avail": disk[3] if len(disk) > 3 else "?",
        }
    except Exception as e:
        return {**worker, "online": False, "error": str(e)[:80]}
|
||
|
||
|
||
@app.post("/jobs/{job_id}/cancel")
async def cancel_job(job_id: int):
    """Mark an in-flight job as cancelled; no-op when it already finished."""
    sql = (
        "UPDATE jobs SET status='error', error='cancelled by user', finished_at=datetime('now') "
        "WHERE id=? AND status IN ('queued','extracting','running')"
    )
    with closing(db()) as conn:
        conn.execute(sql, (job_id,))
    return {"ok": True}
|
||
|
||
|
||
@app.post("/jobs/{job_id}/retry")
async def retry_job(job_id: int):
    """Re-queue a failed job by clearing its error, progress, and worker assignment."""
    sql = (
        "UPDATE jobs SET status='queued', error=NULL, progress=0, started_at=NULL, "
        "finished_at=NULL, worker_host=NULL WHERE id=? AND status='error'"
    )
    with closing(db()) as conn:
        conn.execute(sql, (job_id,))
    return {"ok": True}
|
||
|
||
|
||
@app.post("/stitches/{stitch_id}/cancel")
async def cancel_stitch(stitch_id: int):
    """Mark an in-flight stitch as cancelled; no-op when it already finished."""
    sql = (
        "UPDATE stitches SET status='error', error='cancelled by user', finished_at=datetime('now') "
        "WHERE id=? AND status IN ('queued','running')"
    )
    with closing(db()) as conn:
        conn.execute(sql, (stitch_id,))
    return {"ok": True}
|
||
|
||
|
||
@app.post("/stitches/{stitch_id}/retry")
async def retry_stitch(stitch_id: int):
    """Re-queue a failed stitch, clearing its output and worker assignment."""
    sql = (
        "UPDATE stitches SET status='queued', error=NULL, output_ply=NULL, "
        "started_at=NULL, finished_at=NULL, worker_host=NULL WHERE id=? AND status='error'"
    )
    with closing(db()) as conn:
        conn.execute(sql, (stitch_id,))
    return {"ok": True}
|
||
|
||
|
||
# Per-job viewer port = VIEWER_PORT_BASE + job_id (stitch viewers add 1000 more).
VIEWER_PORT_BASE = 8200
# Destination path on each worker where the standalone viewer script is scp'd.
VIEWER_SCRIPT_REMOTE = "/tmp/cosma-viser_ply.py"
|
||
|
||
|
||
def _worker_by_host(host: str) -> dict | None:
    """Find the worker entry whose host matches; fall back to the first worker.

    Returns None only when no workers are configured at all.
    """
    fallback = WORKERS[0] if WORKERS else None
    return next((w for w in WORKERS if w["host"] == host), fallback)
|
||
|
||
|
||
async def _launch_viewer(worker: dict, ply_path: str, port: int) -> None:
    """Start (or restart) a standalone viser PLY viewer on *worker* at *port*.

    Three SSH round-trips: (1) scp the viewer script over, (2) write a launch
    wrapper script, (3) start it detached so it outlives the SSH session.

    Raises HTTPException(500) when the scp or the wrapper write fails; the
    final launch step is fire-and-forget.

    NOTE(review): reads worker['lingbot_path'], which the built-in WORKERS
    entries do not define — COSMA_QC_WORKERS presumably supplies it; confirm.
    """
    alias = worker["ssh_alias"]
    local_script = Path(__file__).parent.parent / "scripts" / "viser_ply.py"
    # Step 1: push the viewer script to the worker.
    proc = await asyncio.create_subprocess_exec(
        "scp", "-o", "BatchMode=yes", str(local_script), f"{alias}:{VIEWER_SCRIPT_REMOTE}",
        stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE,
    )
    _, err = await proc.communicate()
    if proc.returncode != 0:
        raise HTTPException(500, f"scp viser_ply.py failed: {err.decode()[:200]}")
    # Step 2: wrapper kills any previous viewer on this port, then execs a fresh
    # one inside the worker's project venv.
    wrapper = (
        f"#!/bin/bash\n"
        f"pkill -f 'viser_ply.py.*--port {port}' 2>/dev/null\n"
        f"sleep 1\n"
        f"cd {worker['lingbot_path']}\n"
        f"source .venv/bin/activate\n"
        f"exec python3 {VIEWER_SCRIPT_REMOTE} {ply_path!r} --port {port} --downsample 0\n"
    )
    wrapper_path = f"/tmp/cosma-viser-launch-{port}.sh"
    # The wrapper body is streamed over stdin to avoid shell-quoting issues.
    proc = await asyncio.create_subprocess_exec(
        "ssh", "-o", "BatchMode=yes", alias,
        f"cat > {wrapper_path} && chmod +x {wrapper_path}",
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE,
    )
    _, err = await proc.communicate(wrapper.encode())
    if proc.returncode != 0:
        raise HTTPException(500, f"wrapper write failed: {err.decode()[:200]}")
    # Step 3: setsid + nohup + disown so the viewer survives the SSH disconnect.
    proc = await asyncio.create_subprocess_exec(
        "ssh", "-o", "BatchMode=yes", alias,
        f"setsid nohup {wrapper_path} </dev/null >/tmp/cosma-viser-{port}.log 2>&1 & disown; sleep 0.3",
        stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE,
    )
    await proc.communicate()
    # Give the viewer time to bind its port before the caller hands out the URL.
    await asyncio.sleep(5)
|
||
|
||
|
||
@app.post("/jobs/{job_id}/view")
async def view_job(job_id: int):
    """Launch a standalone PLY viewer for a finished job and return its URL."""
    with closing(db()) as conn:
        row = conn.execute(
            "SELECT ply_path, worker_host FROM jobs WHERE id=? AND status='done'",
            (job_id,),
        ).fetchone()
    if row is None or not row["ply_path"]:
        raise HTTPException(404, "PLY non disponible")
    worker = _worker_by_host(row["worker_host"]) or WORKERS[0]
    port = VIEWER_PORT_BASE + job_id
    await _launch_viewer(worker, row["ply_path"], port)
    return {"url": f"http://{worker['host']}:{port}"}
|
||
|
||
|
||
@app.get("/jobs/{job_id}/thumbnail")
async def job_thumbnail(job_id: int):
    """Serve the cached thumbnail the dispatcher scp'd after trimming out-of-water frames."""
    from fastapi.responses import FileResponse

    path = DB_PATH.parent / "thumbnails" / f"job_{job_id}.jpg"
    if not path.exists():
        raise HTTPException(404, "no thumbnail yet")
    return FileResponse(path, media_type="image/jpeg")
|
||
|
||
|
||
@app.post("/jobs/{job_id}/live")
async def live_job(job_id: int):
    """Return the URL of demo.py's native viser (PointCloudViewer with camera frustums,
    confidence filtering, animation) if it's still listening. Otherwise 404 so the UI falls
    back to /view (viser_ply.py standalone)."""
    with closing(db()) as conn:
        row = conn.execute(
            "SELECT viser_url, worker_host FROM jobs WHERE id=? AND status='done'",
            (job_id,),
        ).fetchone()
    # No recorded viser URL -> the native viewer was never started for this job.
    if not row or not row["viser_url"]:
        raise HTTPException(404, "viser natif jamais démarré")
    worker = _worker_by_host(row["worker_host"]) or WORKERS[0]
    # NOTE(review): assumes worker['viser_port_base'] is set via COSMA_QC_WORKERS;
    # the built-in default entries do not define it — confirm.
    native_port = worker["viser_port_base"] + job_id
    # Probe the port from the worker itself (loopback) over SSH; nc exits
    # non-zero when nothing is listening.
    proc = await asyncio.create_subprocess_exec(
        "ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=3", worker["ssh_alias"],
        f"nc -z -w2 127.0.0.1 {native_port}",
        stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL,
    )
    await proc.wait()
    # 410 Gone tells the UI the viewer existed but has shut down since.
    if proc.returncode != 0:
        raise HTTPException(410, "viser natif fermé — utilise le bouton PLY")
    return {"url": row["viser_url"]}
|
||
|
||
|
||
@app.post("/stitches/{stitch_id}/view")
async def view_stitch(stitch_id: int):
    """Launch a standalone PLY viewer for a finished stitch and return its URL."""
    with closing(db()) as conn:
        row = conn.execute(
            "SELECT output_ply, worker_host FROM stitches WHERE id=? AND status='done'",
            (stitch_id,),
        ).fetchone()
    if row is None or not row["output_ply"]:
        raise HTTPException(404, "PLY stitch non disponible")
    worker = _worker_by_host(row["worker_host"]) or WORKERS[0]
    # Stitch viewers live 1000 ports above job viewers to avoid collisions.
    port = VIEWER_PORT_BASE + 1000 + stitch_id
    await _launch_viewer(worker, row["output_ply"], port)
    return {"url": f"http://{worker['host']}:{port}"}