99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
import asyncio
|
|
import gzip
|
|
import json
|
|
from pathlib import Path
|
|
from typing import AsyncGenerator
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
from .config import OUTPUT_DIR
|
|
from .runner import run_pipeline, scan_sorties
|
|
|
|
app = FastAPI(title="COSMA Pipeline Runner")

# Fully permissive CORS policy: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Registry of pipeline jobs, keyed by sortie id. Each value is the
# asyncio.Queue handed to run_pipeline and read by the SSE endpoint.
_jobs: dict[str, asyncio.Queue] = {}
|
|
|
|
|
|
@app.get("/sorties")
async def list_sorties():
    """Return the sortie listing produced by ``scan_sorties()``."""
    sorties = await scan_sorties()
    return sorties
|
|
|
|
|
|
# Strong references to running pipeline tasks. The event loop keeps only weak
# references to tasks, so a task nobody holds on to can be garbage-collected
# before it finishes. Each task removes itself from this set when done.
_background_tasks: set[asyncio.Task] = set()


@app.post("/run/{sortie_id:path}")
async def run_sortie(sortie_id: str):
    """Start the pipeline for *sortie_id* in the background.

    Returns ``{"status": "already_running"}`` if a job for this sortie is
    still registered in ``_jobs``. Otherwise it registers a fresh event
    queue, schedules ``_run_and_cleanup`` and returns
    ``{"status": "started"}``.
    """
    if sortie_id in _jobs:
        return {"status": "already_running"}

    queue: asyncio.Queue = asyncio.Queue()
    # There is no await between the membership test above and this insert,
    # so two concurrent requests cannot both start a job for the same sortie.
    _jobs[sortie_id] = queue

    task = asyncio.create_task(_run_and_cleanup(sortie_id, queue))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {"status": "started"}
|
|
|
|
|
|
async def _run_and_cleanup(sortie_id: str, queue: asyncio.Queue):
    """Run the pipeline for *sortie_id*, then retire its job entry.

    ``run_pipeline`` receives *queue* to report events on. If it raises, a
    terminal ``{"step": "error", "pct": 0, ...}`` event is enqueued before
    the exception propagates. The SSE endpoint in this module only stops on
    an "error"/"write" step with pct 0 or 100, so without this event its
    subscribers would keep receiving pings indefinitely.

    Whatever the outcome, the entry stays in ``_jobs`` for another 30 s
    (presumably so late SSE subscribers can still attach and drain the
    queue) and is then removed. Until then, ``run_sortie`` reports the
    sortie as already running.
    """
    try:
        await run_pipeline(sortie_id, queue)
    except Exception as exc:
        # NOTE(review): the "message" key is chosen here; confirm it matches
        # what consumers expect on error events.
        await queue.put({"step": "error", "pct": 0, "message": str(exc)})
        raise
    finally:
        await asyncio.sleep(30)
        _jobs.pop(sortie_id, None)
|
|
|
|
|
|
@app.get("/events/{sortie_id:path}")
async def sse_events(sortie_id: str):
    """Stream events of an active pipeline job as Server-Sent Events.

    Each queued event is JSON-encoded into a ``data:`` frame. If no event
    arrives within 60 s, a ``{"step":"ping"}`` frame is sent and waiting
    continues.

    Raises:
        HTTPException: 404 if no job is registered for *sortie_id*.
    """
    # Resolve the queue now rather than inside the generator. The generator
    # body only runs once the response starts streaming, and by then
    # _run_and_cleanup may have removed the entry. That would surface as a
    # KeyError mid-response instead of a clean 404.
    queue = _jobs.get(sortie_id)
    if queue is None:
        raise HTTPException(404, "No active job for this sortie")

    async def generate() -> AsyncGenerator[str, None]:
        while True:
            try:
                event = await asyncio.wait_for(queue.get(), timeout=60)
            except asyncio.TimeoutError:
                yield "data: {\"step\":\"ping\"}\n\n"
                continue
            yield f"data: {json.dumps(event)}\n\n"
            # NOTE(review): this also ends on ("write", 0) and ("error", 100),
            # and an "error" event with any other pct does NOT end the
            # stream. Confirm this matches what run_pipeline emits.
            if event.get("step") in ("error", "write") and event.get("pct") in (0, 100):
                break

    return StreamingResponse(generate(), media_type="text/event-stream")
|
|
|
|
|
|
def _read_gz(path: Path) -> dict:
|
|
with gzip.open(path) as f:
|
|
return json.loads(f.read())
|
|
|
|
|
|
@app.get("/sorties/{sortie_id:path}/usv")
async def get_usv(sortie_id: str):
    """Return the processed USV data for *sortie_id* as JSON.

    Reads ``OUTPUT_DIR/<sortie_id>/processed/usv.json.gz``. Decompression and
    parsing run in a worker thread via ``asyncio.to_thread``, so a large file
    does not block the event loop that also serves the SSE streams.

    Raises:
        HTTPException: 404 if the file (or a parent directory) is missing.
    """
    p = OUTPUT_DIR / sortie_id / "processed" / "usv.json.gz"
    # NOTE(review): sortie_id is client-supplied and may contain ".." or start
    # with "/" (which makes pathlib discard OUTPUT_DIR). Consider a
    # containment check against OUTPUT_DIR.
    try:
        data = await asyncio.to_thread(_read_gz, p)
    except (FileNotFoundError, NotADirectoryError):
        raise HTTPException(404, "USV data not found — run pipeline first") from None
    return JSONResponse(data)
|
|
|
|
|
|
@app.get("/sorties/{sortie_id:path}/auvs")
async def list_auvs(sortie_id: str):
    """List the AUV ids that have processed data for *sortie_id*.

    Matches ``auv_<id>.json.gz`` files in ``OUTPUT_DIR/<sortie_id>/processed``
    and returns the ``<id>`` parts sorted (an empty list if nothing matches).
    """
    processed_dir = OUTPUT_DIR / sortie_id / "processed"
    return sorted(
        gz_file.name.removesuffix(".json.gz").removeprefix("auv_")
        for gz_file in processed_dir.glob("auv_*.json.gz")
    )
|
|
|
|
|
|
@app.get("/sorties/{sortie_id:path}/auv/{auv_id}")
async def get_auv(sortie_id: str, auv_id: str):
    """Return the processed data for one AUV of *sortie_id* as JSON.

    Reads ``OUTPUT_DIR/<sortie_id>/processed/auv_<auv_id>.json.gz`` in a
    worker thread via ``asyncio.to_thread``, so decompression and parsing do
    not block the event loop.

    Raises:
        HTTPException: 404 if the file (or a parent directory) is missing.
    """
    p = OUTPUT_DIR / sortie_id / "processed" / f"auv_{auv_id}.json.gz"
    # NOTE(review): sortie_id is client-supplied and may contain ".." or start
    # with "/". Consider a containment check against OUTPUT_DIR.
    try:
        data = await asyncio.to_thread(_read_gz, p)
    except (FileNotFoundError, NotADirectoryError):
        raise HTTPException(404, f"AUV {auv_id} data not found") from None
    return JSONResponse(data)
|
|
|
|
|
|
@app.get("/sorties/{sortie_id:path}/tracks")
async def get_tracks(sortie_id: str):
    """Return ``OUTPUT_DIR/<sortie_id>/processed/tracks.geojson`` as JSON.

    The file is read and parsed in a worker thread via ``asyncio.to_thread``,
    so a large track file does not block the event loop.

    Raises:
        HTTPException: 404 if the file (or a parent directory) is missing.
    """
    p = OUTPUT_DIR / sortie_id / "processed" / "tracks.geojson"

    def _load() -> dict:
        # JSON interchange is UTF-8 (RFC 8259 §8.1). Don't depend on the
        # platform's locale-dependent default text encoding.
        with open(p, encoding="utf-8") as f:
            return json.load(f)

    try:
        data = await asyncio.to_thread(_load)
    except (FileNotFoundError, NotADirectoryError):
        raise HTTPException(404, "tracks.geojson not found") from None
    return JSONResponse(data)
|