feat: pipeline-runner — FastAPI endpoints + SSE
This commit is contained in:
98
pipeline_runner/main.py
Normal file
98
pipeline_runner/main.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import asyncio
|
||||
import gzip
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
from .config import OUTPUT_DIR
|
||||
from .runner import run_pipeline, scan_sorties
|
||||
|
||||
app = FastAPI(title="COSMA Pipeline Runner")
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
# Active pipeline jobs: sortie_id → asyncio.Queue
|
||||
_jobs: dict[str, asyncio.Queue] = {}
|
||||
|
||||
|
||||
@app.get("/sorties")
|
||||
async def list_sorties():
|
||||
return await scan_sorties()
|
||||
|
||||
|
||||
@app.post("/run/{sortie_id:path}")
|
||||
async def run_sortie(sortie_id: str):
|
||||
if sortie_id in _jobs:
|
||||
return {"status": "already_running"}
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
_jobs[sortie_id] = queue
|
||||
asyncio.create_task(_run_and_cleanup(sortie_id, queue))
|
||||
return {"status": "started"}
|
||||
|
||||
|
||||
async def _run_and_cleanup(sortie_id: str, queue: asyncio.Queue):
|
||||
try:
|
||||
await run_pipeline(sortie_id, queue)
|
||||
finally:
|
||||
await asyncio.sleep(30)
|
||||
_jobs.pop(sortie_id, None)
|
||||
|
||||
|
||||
@app.get("/events/{sortie_id:path}")
|
||||
async def sse_events(sortie_id: str):
|
||||
if sortie_id not in _jobs:
|
||||
raise HTTPException(404, "No active job for this sortie")
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
queue = _jobs[sortie_id]
|
||||
while True:
|
||||
try:
|
||||
event = await asyncio.wait_for(queue.get(), timeout=60)
|
||||
yield f"data: {json.dumps(event)}\n\n"
|
||||
if event.get("step") in ("error", "write") and event.get("pct") in (0, 100):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
yield "data: {\"step\":\"ping\"}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
|
||||
def _read_gz(path: Path) -> dict:
|
||||
with gzip.open(path) as f:
|
||||
return json.loads(f.read())
|
||||
|
||||
|
||||
@app.get("/sorties/{sortie_id:path}/usv")
|
||||
async def get_usv(sortie_id: str):
|
||||
p = OUTPUT_DIR / sortie_id / "processed" / "usv.json.gz"
|
||||
if not p.exists():
|
||||
raise HTTPException(404, "USV data not found — run pipeline first")
|
||||
return JSONResponse(_read_gz(p))
|
||||
|
||||
|
||||
@app.get("/sorties/{sortie_id:path}/auvs")
|
||||
async def list_auvs(sortie_id: str):
|
||||
proc = OUTPUT_DIR / sortie_id / "processed"
|
||||
auvs = [p.name.removesuffix(".json.gz").removeprefix("auv_")
|
||||
for p in proc.glob("auv_*.json.gz")]
|
||||
return sorted(auvs)
|
||||
|
||||
|
||||
@app.get("/sorties/{sortie_id:path}/auv/{auv_id}")
|
||||
async def get_auv(sortie_id: str, auv_id: str):
|
||||
p = OUTPUT_DIR / sortie_id / "processed" / f"auv_{auv_id}.json.gz"
|
||||
if not p.exists():
|
||||
raise HTTPException(404, f"AUV {auv_id} data not found")
|
||||
return JSONResponse(_read_gz(p))
|
||||
|
||||
|
||||
@app.get("/sorties/{sortie_id:path}/tracks")
|
||||
async def get_tracks(sortie_id: str):
|
||||
p = OUTPUT_DIR / sortie_id / "processed" / "tracks.geojson"
|
||||
if not p.exists():
|
||||
raise HTTPException(404, "tracks.geojson not found")
|
||||
with open(p) as f:
|
||||
return JSONResponse(json.load(f))
|
||||
Reference in New Issue
Block a user