# cosma-qc/app/main.py — FastAPI dashboard for the cosma-qc processing pipeline.
# (GitHub web-UI chrome from a copy/paste — "Files", "Raw Blame History", the
# ambiguous-Unicode warning — replaced with this comment header.)
from __future__ import annotations
import asyncio
import json
import os
import sqlite3
from contextlib import asynccontextmanager, closing
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
def _fmt_dur(seconds: float) -> str:
if seconds is None or seconds < 0:
return ""
s = int(seconds)
if s < 60:
return f"{s}s"
m, s = divmod(s, 60)
if m < 60:
return f"{m}m{s:02d}s" if s else f"{m}m"
h, m = divmod(m, 60)
if h < 24:
return f"{h}h{m:02d}m" if m else f"{h}h"
d, h = divmod(h, 24)
return f"{d}d{h:02d}h"
def _parse_ts(s: str | None) -> datetime | None:
if not s:
return None
try:
return datetime.fromisoformat(s.replace("Z", "+00:00"))
except Exception:
return None
def _job_duration_s(job: sqlite3.Row) -> int:
start = _parse_ts(job["started_at"])
end = _parse_ts(job["finished_at"]) or datetime.now(timezone.utc)
if not start:
return 0
if start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
if end.tzinfo is None:
end = end.replace(tzinfo=timezone.utc)
return int((end - start).total_seconds())
from fastapi import FastAPI, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
# SQLite job database location; override with COSMA_QC_DB.
DB_PATH = Path(os.environ.get("COSMA_QC_DB", "/var/lib/cosma-qc/jobs.db"))

# GPU worker fleet; override with COSMA_QC_WORKERS (a JSON list of dicts).
_DEFAULT_WORKERS = [
    {"host": "192.168.0.87", "ssh_alias": "gpu", "gpu": "RTX 3060 12GB"},
    {"host": "192.168.0.84", "ssh_alias": "cosma-vm", "gpu": "RTX 3090 24GB"},
]
WORKERS = json.loads(os.environ.get("COSMA_QC_WORKERS", json.dumps(_DEFAULT_WORKERS)))

# Job lifecycle states, in pipeline order.
STATUSES = ("queued", "extracting", "running", "done", "error")
def db() -> sqlite3.Connection:
    """Open a fresh autocommit connection to the jobs database.

    Ensures the parent directory exists, enables WAL journaling and foreign
    keys, and configures rows to come back as sqlite3.Row mappings.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(DB_PATH, isolation_level=None)  # autocommit mode
    conn.row_factory = sqlite3.Row
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        conn.execute(pragma)
    return conn
def init_schema() -> None:
    """Create the acquisitions/jobs/stitches tables and indexes (idempotent)."""
    ddl = """
    CREATE TABLE IF NOT EXISTS acquisitions (
        id INTEGER PRIMARY KEY,
        name TEXT NOT NULL,
        source_path TEXT NOT NULL,
        created_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE TABLE IF NOT EXISTS jobs (
        id INTEGER PRIMARY KEY,
        acquisition_id INTEGER NOT NULL REFERENCES acquisitions(id) ON DELETE CASCADE,
        auv TEXT NOT NULL,
        gopro_serial TEXT NOT NULL,
        segment_label TEXT NOT NULL,
        video_paths TEXT NOT NULL,
        frame_count INTEGER,
        frames_dir TEXT,
        status TEXT NOT NULL DEFAULT 'queued',
        worker_host TEXT,
        viser_url TEXT,
        ply_path TEXT,
        progress INTEGER NOT NULL DEFAULT 0,
        log_tail TEXT,
        error TEXT,
        started_at TEXT,
        finished_at TEXT,
        created_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE INDEX IF NOT EXISTS jobs_status_idx ON jobs(status);
    CREATE INDEX IF NOT EXISTS jobs_acq_idx ON jobs(acquisition_id);
    CREATE TABLE IF NOT EXISTS stitches (
        id INTEGER PRIMARY KEY,
        acquisition_id INTEGER NOT NULL REFERENCES acquisitions(id) ON DELETE CASCADE,
        level TEXT NOT NULL DEFAULT 'per_auv',
        auv TEXT,
        input_job_ids TEXT NOT NULL DEFAULT '[]',
        input_stitch_ids TEXT NOT NULL DEFAULT '[]',
        output_ply TEXT,
        status TEXT NOT NULL DEFAULT 'queued',
        worker_host TEXT,
        started_at TEXT,
        finished_at TEXT,
        error TEXT,
        created_at TEXT NOT NULL DEFAULT (datetime('now'))
    );
    CREATE INDEX IF NOT EXISTS stitches_acq_idx ON stitches(acquisition_id);
    """
    with closing(db()) as conn:
        conn.executescript(ddl)
@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook: guarantee the DB schema exists before serving."""
    init_schema()
    yield
# Application wiring: templates, static assets and the bundled docs site.
# NOTE(review): mounting "/docs" shadows FastAPI's default Swagger UI path —
# presumably intentional here; confirm if interactive API docs are wanted.
app = FastAPI(title="cosma-qc", lifespan=lifespan)
_APP_DIR = Path(__file__).parent
templates = Jinja2Templates(directory=_APP_DIR / "templates")
app.mount("/static", StaticFiles(directory=_APP_DIR / "static"), name="static")
app.mount("/docs", StaticFiles(directory=Path("/app/docs"), html=True), name="docs")
_viser_probe_cache: dict[str, tuple[float, bool]] = {}
_VISER_PROBE_TTL = 8.0 # seconds
def _viser_alive(url: str) -> bool:
"""Fast TCP check with short cache so we never surface a dead link in the dashboard."""
import time as _t
import socket
now = _t.time()
cached = _viser_probe_cache.get(url)
if cached and now - cached[0] < _VISER_PROBE_TTL:
return cached[1]
try:
from urllib.parse import urlparse
p = urlparse(url)
host, port = p.hostname, p.port
if not host or not port:
return False
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(0.4)
s.connect((host, port))
alive = True
except OSError:
alive = False
_viser_probe_cache[url] = (now, alive)
return alive
# Separator between the start and end time inside segment_label ("HH:MM–HH:MM").
# NOTE(review): the original character was lost to Unicode mangling (the code
# read `split("")`, which always raises and was silently swallowed, making the
# end-time fix dead code). An en dash is assumed — confirm against real labels.
_SEG_SEP = "\u2013"


def _row_get(row: sqlite3.Row, key: str, default: Any = None) -> Any:
    """Read an optional column from a row, tolerating schemas that predate it."""
    try:
        return row[key]
    except (KeyError, IndexError):
        return default


def _build_acquisitions():
    """Build the dashboard view-model: every acquisition with its jobs and
    stitches plus presentation fields (GP labels, durations, viewer links).

    Fixes over the previous version: the segment-label end-time repair now
    actually runs (separator restored); optional columns (video_duration_s,
    trimmed_head/tail, step) no longer crash on old rows; the GLB URL is only
    built when the path has enough components.
    """
    with closing(db()) as conn:
        acqs = conn.execute(
            "SELECT * FROM acquisitions ORDER BY created_at DESC"
        ).fetchall()
        jobs = conn.execute(
            "SELECT * FROM jobs ORDER BY auv, gopro_serial, segment_label"
        ).fetchall()
        stitches = conn.execute(
            "SELECT * FROM stitches ORDER BY level DESC, auv"
        ).fetchall()

    # Assign GP1/GP2... labels per (acquisition, AUV): distinct serials, sorted.
    serials_by_auv: dict[tuple[int, str], list[str]] = {}
    for j in jobs:
        seen = serials_by_auv.setdefault((j["acquisition_id"], j["auv"]), [])
        if j["gopro_serial"] not in seen:
            seen.append(j["gopro_serial"])
    gp_label: dict[tuple[int, str, str], str] = {}
    for (acq_id, auv), serials in serials_by_auv.items():
        for idx, ser in enumerate(sorted(serials)):
            gp_label[(acq_id, auv, ser)] = f"GP{idx + 1}"

    by_acq: dict[int, list[dict]] = {}
    by_acq_total: dict[int, int] = {}
    for j in jobs:
        d = dict(j)
        dur_s = _job_duration_s(j)
        d["_duration"] = _fmt_dur(dur_s)
        d["gp_label"] = gp_label.get(
            (j["acquisition_id"], j["auv"], j["gopro_serial"]), "?"
        )
        # video_duration_s is populated by a later migration — default to 0.
        vid_dur_s = float(_row_get(j, "video_duration_s") or 0)
        d["video_duration_fmt"] = _fmt_dur(int(vid_dur_s)) if vid_dur_s > 0 else ""
        # Fix segment_label end-time when exiftool returned duration=0 at ingest.
        seg_label = j["segment_label"] or ""
        if vid_dur_s > 0 and _SEG_SEP in seg_label:
            try:
                from datetime import timedelta

                start_str = seg_label.split(_SEG_SEP)[0].strip()
                t0 = datetime.strptime(start_str, "%H:%M")
                t1 = t0 + timedelta(seconds=vid_dur_s)
                seg_label = f"{start_str}{_SEG_SEP}{t1.strftime('%H:%M')}"
            except Exception:
                pass  # best-effort: keep the stored label on any parse failure
        d["segment_label"] = seg_label
        d["trimmed_total"] = (_row_get(j, "trimmed_head") or 0) + (
            _row_get(j, "trimmed_tail") or 0
        )
        thumb_path = DB_PATH.parent / "thumbnails" / f"job_{j['id']}.jpg"
        d["has_thumbnail"] = thumb_path.is_file()
        # Bust the browser cache on the mtime so the preview refreshes as the
        # dispatcher re-copies it.
        d["thumb_ts"] = int(thumb_path.stat().st_mtime) if d["has_thumbnail"] else 0
        # Optional column added after the original schema; None on old rows.
        d["step"] = _row_get(j, "step")
        # Mask the viser link when the demo.py that was serving it has since died.
        if j["status"] == "done" and j["viser_url"] and not _viser_alive(j["viser_url"]):
            d["viser_url"] = None
        # GLB download link: only when glb_path is set in DB (conversion confirmed).
        glb_url = None
        glb_path = d.get("glb_path")
        if glb_path and d.get("worker_host"):
            parts = glb_path.rstrip("/").split("/")
            if len(parts) >= 2:  # need dir + file to build the URL
                glb_url = f"http://{d['worker_host']}:8300/{parts[-2]}/{parts[-1]}"
        d["glb_url"] = glb_url
        by_acq.setdefault(j["acquisition_id"], []).append(d)
        by_acq_total[j["acquisition_id"]] = by_acq_total.get(j["acquisition_id"], 0) + dur_s

    stitches_by_acq: dict[int, list[dict]] = {}
    for s in stitches:
        d = dict(s)
        start = _parse_ts(s["started_at"])
        end = _parse_ts(s["finished_at"]) or (
            datetime.now(timezone.utc) if s["status"] == "running" else None
        )
        if start and end:
            if start.tzinfo is None:
                start = start.replace(tzinfo=timezone.utc)
            if end.tzinfo is None:
                end = end.replace(tzinfo=timezone.utc)
            d["_duration"] = _fmt_dur(int((end - start).total_seconds()))
        else:
            d["_duration"] = ""
        stitches_by_acq.setdefault(s["acquisition_id"], []).append(d)

    return [
        {
            "id": acq["id"],
            "name": acq["name"],
            "source_path": acq["source_path"],
            "jobs": by_acq.get(acq["id"], []),
            "stitches": stitches_by_acq.get(acq["id"], []),
            "total_duration": _fmt_dur(by_acq_total.get(acq["id"], 0)),
        }
        for acq in acqs
    ]
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the main dashboard page."""
    context = {
        "request": request,
        "acquisitions": _build_acquisitions(),
        "workers": WORKERS,
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/api/jobs")
async def list_jobs():
with closing(db()) as conn:
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC LIMIT 500").fetchall()
return [dict(r) for r in rows]
@app.get("/partials/jobs", response_class=HTMLResponse)
async def partial_jobs(request: Request):
    """Partial template: the jobs table fragment, for periodic refresh."""
    context = {"request": request, "acquisitions": _build_acquisitions()}
    return templates.TemplateResponse("_jobs_table.html", context)
@app.get("/partials/monitor", response_class=HTMLResponse)
async def partial_monitor(request: Request):
    """Partial template: per-worker GPU/disk stats plus dispatcher liveness.

    Worker probes run concurrently so one slow host does not stall the page.
    """
    worker_stats = await asyncio.gather(*(_worker_stats(w) for w in WORKERS))
    context = {
        "request": request,
        "workers": worker_stats,
        "dispatcher": _dispatcher_status(),
    }
    return templates.TemplateResponse("_monitor.html", context)
def _dispatcher_status() -> dict:
    """Report dispatcher liveness from its heartbeat file.

    Alive means the heartbeat timestamp is under 30 seconds old; a missing,
    unreadable or unparsable heartbeat reports the dispatcher as down.
    """
    heartbeat = DB_PATH.parent / "dispatcher.heartbeat"
    try:
        stamp = _parse_ts(heartbeat.read_text().strip())
    except Exception:
        stamp = None
    if stamp is not None:
        if stamp.tzinfo is None:
            stamp = stamp.replace(tzinfo=timezone.utc)
        age = int((datetime.now(timezone.utc) - stamp).total_seconds())
        return {"alive": age < 30, "age_s": age}
    return {"alive": False, "age_s": None}
async def _worker_stats(worker: dict) -> dict:
    """Collect GPU and root-disk stats from one worker over SSH.

    Returns the worker dict augmented with "online": True and the parsed
    nvidia-smi / df fields; on any failure (timeout, ssh error, bad output)
    returns it with "online": False and a truncated error string.
    """
    alias = worker["ssh_alias"]

    def to_int(value: str):
        # nvidia-smi can emit non-numeric placeholders; map those to None.
        try:
            return int(float(value))
        except Exception:
            return None

    try:
        proc = await asyncio.create_subprocess_exec(
            "ssh", "-o", "ConnectTimeout=3", "-o", "BatchMode=yes", alias,
            "nvidia-smi --query-gpu=memory.used,memory.total,utilization.gpu,temperature.gpu,power.draw --format=csv,noheader,nounits && df -h / | tail -1",
            stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
        )
        out, _ = await asyncio.wait_for(proc.communicate(), timeout=5)
        lines = out.decode().strip().splitlines()
        # Line 1: CSV of GPU fields; line 2: df columns for the root filesystem.
        gpu = [field.strip() for field in lines[0].split(",")] if lines else ["?"] * 5
        disk = lines[1].split() if len(lines) > 1 else ["?"] * 6
        return {
            **worker,
            "online": True,
            "vram_used_mib": to_int(gpu[0]) if len(gpu) > 0 else None,
            "vram_total_mib": to_int(gpu[1]) if len(gpu) > 1 else None,
            "gpu_util_pct": to_int(gpu[2]) if len(gpu) > 2 else None,
            "gpu_temp_c": to_int(gpu[3]) if len(gpu) > 3 else None,
            "gpu_power_w": to_int(gpu[4]) if len(gpu) > 4 else None,
            "disk_used_pct": disk[4] if len(disk) > 4 else "?",
            "disk_avail": disk[3] if len(disk) > 3 else "?",
        }
    except Exception as e:
        return {**worker, "online": False, "error": str(e)[:80]}
@app.post("/jobs/{job_id}/cancel")
async def cancel_job(job_id: int):
    """Mark a not-yet-finished job as cancelled (stored as an error state)."""
    sql = (
        "UPDATE jobs SET status='error', error='cancelled by user', finished_at=datetime('now') "
        "WHERE id=? AND status IN ('queued','extracting','running')"
    )
    with closing(db()) as conn:
        conn.execute(sql, (job_id,))
    return {"ok": True}
@app.post("/jobs/{job_id}/retry")
async def retry_job(job_id: int):
    """Re-queue a failed job, clearing its error, progress and worker fields."""
    sql = (
        "UPDATE jobs SET status='queued', error=NULL, progress=0, started_at=NULL, "
        "finished_at=NULL, worker_host=NULL WHERE id=? AND status='error'"
    )
    with closing(db()) as conn:
        conn.execute(sql, (job_id,))
    return {"ok": True}
@app.post("/stitches/{stitch_id}/cancel")
async def cancel_stitch(stitch_id: int):
    """Mark a queued or running stitch as cancelled (stored as an error state)."""
    sql = (
        "UPDATE stitches SET status='error', error='cancelled by user', finished_at=datetime('now') "
        "WHERE id=? AND status IN ('queued','running')"
    )
    with closing(db()) as conn:
        conn.execute(sql, (stitch_id,))
    return {"ok": True}
@app.post("/stitches/{stitch_id}/retry")
async def retry_stitch(stitch_id: int):
    """Re-queue a failed stitch, clearing its error, output and worker fields."""
    sql = (
        "UPDATE stitches SET status='queued', error=NULL, output_ply=NULL, "
        "started_at=NULL, finished_at=NULL, worker_host=NULL WHERE id=? AND status='error'"
    )
    with closing(db()) as conn:
        conn.execute(sql, (stitch_id,))
    return {"ok": True}
# Standalone PLY viewers listen on VIEWER_PORT_BASE + id (stitch viewers are
# offset by a further 1000).
VIEWER_PORT_BASE = 8200
VIEWER_SCRIPT_REMOTE = "/tmp/cosma-viser_ply.py"


def _worker_by_host(host: str) -> dict | None:
    """Find the configured worker for *host*; unknown hosts fall back to the
    first worker, and None is returned only when no workers are configured."""
    fallback = WORKERS[0] if WORKERS else None
    return next((w for w in WORKERS if w["host"] == host), fallback)
async def _launch_viewer(worker: dict, ply_path: str, port: int) -> None:
    """Copy viser_ply.py to the worker and launch it detached on *port*.

    Any viewer already bound to the same port is killed first. Raises
    HTTPException(500) when the scp of the script or the wrapper write fails.

    NOTE(review): reads worker["lingbot_path"] (venv location), which the
    default WORKERS config in this file does not define — confirm deployment
    config supplies it.
    """
    alias = worker["ssh_alias"]
    local_script = Path(__file__).parent.parent / "scripts" / "viser_ply.py"

    # Step 1: ship the viewer script to the worker.
    proc = await asyncio.create_subprocess_exec(
        "scp", "-o", "BatchMode=yes", str(local_script), f"{alias}:{VIEWER_SCRIPT_REMOTE}",
        stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE,
    )
    _, err = await proc.communicate()
    if proc.returncode != 0:
        raise HTTPException(500, f"scp viser_ply.py failed: {err.decode()[:200]}")

    # Step 2: write a launch wrapper that replaces any viewer on this port.
    wrapper = (
        f"#!/bin/bash\n"
        f"pkill -f 'viser_ply.py.*--port {port}' 2>/dev/null\n"
        f"sleep 1\n"
        f"cd {worker['lingbot_path']}\n"
        f"source .venv/bin/activate\n"
        f"exec python3 {VIEWER_SCRIPT_REMOTE} {ply_path!r} --port {port} --downsample 0\n"
    )
    wrapper_path = f"/tmp/cosma-viser-launch-{port}.sh"
    proc = await asyncio.create_subprocess_exec(
        "ssh", "-o", "BatchMode=yes", alias,
        f"cat > {wrapper_path} && chmod +x {wrapper_path}",
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE,
    )
    _, err = await proc.communicate(wrapper.encode())
    if proc.returncode != 0:
        raise HTTPException(500, f"wrapper write failed: {err.decode()[:200]}")

    # Step 3: start it fully detached so it survives the ssh session.
    proc = await asyncio.create_subprocess_exec(
        "ssh", "-o", "BatchMode=yes", alias,
        f"setsid nohup {wrapper_path} </dev/null >/tmp/cosma-viser-{port}.log 2>&1 & disown; sleep 0.3",
        stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE,
    )
    await proc.communicate()
    # Give the viewer time to bind before the caller hands out its URL.
    await asyncio.sleep(5)
@app.post("/jobs/{job_id}/view")
async def view_job(job_id: int):
    """Launch a standalone PLY viewer for a finished job and return its URL."""
    with closing(db()) as conn:
        row = conn.execute(
            "SELECT ply_path, worker_host FROM jobs WHERE id=? AND status='done'",
            (job_id,),
        ).fetchone()
    if not row or not row["ply_path"]:
        raise HTTPException(404, "PLY non disponible")
    worker = _worker_by_host(row["worker_host"]) or WORKERS[0]
    port = VIEWER_PORT_BASE + job_id
    await _launch_viewer(worker, row["ply_path"], port)
    return {"url": f"http://{worker['host']}:{port}"}
@app.get("/jobs/{job_id}/thumbnail")
async def job_thumbnail(job_id: int):
    """Serve the cached thumbnail the dispatcher scp'd after trimming out-of-water frames."""
    from fastapi.responses import FileResponse

    path = DB_PATH.parent / "thumbnails" / f"job_{job_id}.jpg"
    if not path.exists():
        raise HTTPException(404, "no thumbnail yet")
    return FileResponse(path, media_type="image/jpeg")
@app.post("/jobs/{job_id}/live")
async def live_job(job_id: int):
    """Return the URL of demo.py's native viser (PointCloudViewer with camera frustums,
    confidence filtering, animation) if it's still listening. Otherwise 404 so the UI falls
    back to /view (viser_ply.py standalone).

    NOTE(review): reads worker["viser_port_base"], which the default WORKERS
    config in this file does not define — confirm deployment config supplies it.
    """
    with closing(db()) as conn:
        row = conn.execute(
            "SELECT viser_url, worker_host FROM jobs WHERE id=? AND status='done'",
            (job_id,),
        ).fetchone()
    if not row or not row["viser_url"]:
        raise HTTPException(404, "viser natif jamais démarré")
    worker = _worker_by_host(row["worker_host"]) or WORKERS[0]
    native_port = worker["viser_port_base"] + job_id
    # Probe the port from the worker itself — the viser may have exited since.
    probe = await asyncio.create_subprocess_exec(
        "ssh", "-o", "BatchMode=yes", "-o", "ConnectTimeout=3", worker["ssh_alias"],
        f"nc -z -w2 127.0.0.1 {native_port}",
        stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.DEVNULL,
    )
    await probe.wait()
    if probe.returncode != 0:
        raise HTTPException(410, "viser natif fermé — utilise le bouton PLY")
    return {"url": row["viser_url"]}
@app.post("/stitches/{stitch_id}/view")
async def view_stitch(stitch_id: int):
    """Launch a standalone PLY viewer for a finished stitch and return its URL."""
    with closing(db()) as conn:
        row = conn.execute(
            "SELECT output_ply, worker_host FROM stitches WHERE id=? AND status='done'",
            (stitch_id,),
        ).fetchone()
    if not row or not row["output_ply"]:
        raise HTTPException(404, "PLY stitch non disponible")
    worker = _worker_by_host(row["worker_host"]) or WORKERS[0]
    # Stitch viewers live in their own port range above the per-job viewers.
    port = VIEWER_PORT_BASE + 1000 + stitch_id
    await _launch_viewer(worker, row["output_ply"], port)
    return {"url": f"http://{worker['host']}:{port}"}