222 lines
7.2 KiB
Python
222 lines
7.2 KiB
Python
import asyncio
import bisect
import csv
import gzip
import json
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import AsyncGenerator

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

from .config import OUTPUT_DIR
from .runner import run_pipeline, scan_sorties, scan_sorties_local
|
|
|
|
app = FastAPI(title="COSMA Pipeline Runner")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Active pipeline jobs: sortie_id -> event queue read by the SSE endpoint.
_jobs: dict[str, asyncio.Queue] = {}

# Sortie listing cache (from scan_sorties), considered fresh for _SORTIES_TTL seconds.
_SORTIES_TTL = 10 * 60.0  # 10 minutes
_sorties_cache: list | None = None
_sorties_cache_ts: float = 0.0  # time.monotonic() value of the last completed scan
_sorties_refresh_lock: asyncio.Lock | None = None  # created lazily by _get_lock()
|
|
|
|
|
|
def _get_lock() -> asyncio.Lock:
    """Return the module-wide sorties refresh lock, creating it on first use."""
    global _sorties_refresh_lock
    lock = _sorties_refresh_lock
    if lock is None:
        lock = _sorties_refresh_lock = asyncio.Lock()
    return lock
|
|
|
|
|
|
async def _refresh_sorties_cache() -> None:
    """Rescan sorties (scan_sorties, in a worker thread) and store the result in the cache.

    Calls are serialized through the shared lock so that concurrent callers trigger at
    most one scan (avoids parallel rclone calls). A caller that waited on the lock returns
    without rescanning when the cache has been populated and is still within the TTL.
    Exceptions from scan_sorties propagate and leave the cache unchanged.
    """
    global _sorties_cache, _sorties_cache_ts
    async with _get_lock():
        # Double-check after acquiring the lock. The `is not None` test matters:
        # _sorties_cache_ts starts at 0.0 and time.monotonic() has an undefined
        # reference point (often system boot), so "now - 0.0 < TTL" can hold before
        # any scan ran, which previously skipped the very first scan entirely.
        if (
            _sorties_cache is not None
            and time.monotonic() - _sorties_cache_ts < _SORTIES_TTL
        ):
            return
        result = await asyncio.to_thread(scan_sorties)
        _sorties_cache = result
        _sorties_cache_ts = time.monotonic()
|
|
|
|
|
|
# Handle on the in-flight background refresh. asyncio keeps only weak references to
# tasks, so an unreferenced task may be garbage-collected before it finishes; holding
# it here also stops every request in a stale window from spawning yet another task.
_sorties_refresh_task: asyncio.Task | None = None


@app.get("/sorties")
async def list_sorties():
    """Return the cached sortie listing, refreshing it stale-while-revalidate style.

    With an empty cache the call waits for a scan. Once the cache is older than
    _SORTIES_TTL, the stale listing is returned immediately and a single background
    refresh is scheduled. Returns [] when no listing is available.
    """
    global _sorties_refresh_task
    if _sorties_cache is None:
        # First call: nothing cached yet, so block on the scan.
        await _refresh_sorties_cache()
    elif time.monotonic() - _sorties_cache_ts >= _SORTIES_TTL:
        # Stale: serve the old listing now and refresh in the background,
        # unless a refresh is already running.
        if _sorties_refresh_task is None or _sorties_refresh_task.done():
            _sorties_refresh_task = asyncio.create_task(_refresh_sorties_cache())
    return _sorties_cache or []
|
|
|
|
|
|
@app.get("/sorties/local")
async def list_sorties_local():
    """List sorties from the local /data/sorties scan (NAS, instant), without rclone.

    scan_sorties_local runs in a worker thread so the event loop is not blocked.
    """
    return await asyncio.to_thread(scan_sorties_local)
|
|
|
|
|
|
# Strong references to running pipeline tasks. asyncio keeps only weak references to
# tasks, so a fire-and-forget task could otherwise be garbage-collected mid-run.
_pipeline_tasks: set[asyncio.Task] = set()


@app.post("/run/{sortie_id:path}")
async def run_sortie(sortie_id: str):
    """Start the pipeline for sortie_id in the background.

    Returns {"status": "already_running"} while a job for this sortie is still
    registered in _jobs (running, or in its post-run grace period). Otherwise it
    registers a fresh event queue, schedules the run and returns {"status": "started"}.
    """
    if sortie_id in _jobs:
        return {"status": "already_running"}
    queue: asyncio.Queue = asyncio.Queue()
    _jobs[sortie_id] = queue
    task = asyncio.create_task(_run_and_cleanup(sortie_id, queue))
    _pipeline_tasks.add(task)
    task.add_done_callback(_pipeline_tasks.discard)
    return {"status": "started"}
|
|
|
|
|
|
async def _run_and_cleanup(sortie_id: str, queue: asyncio.Queue):
    """Run the pipeline for sortie_id, then unregister the job after 30 s.

    If run_pipeline raises, an {"step": "error", "pct": 0, ...} event is queued first,
    so SSE listeners receive a terminal event instead of keep-alive pings forever.
    The exception is then re-raised. The _jobs entry is kept for 30 s after the run,
    presumably so late SSE clients can still attach.
    """
    try:
        await run_pipeline(sortie_id, queue)
    except Exception as exc:
        # NOTE(review): "msg" is a guess at the runner's event schema; align it with
        # the key run_pipeline uses for its own error events.
        queue.put_nowait({"step": "error", "pct": 0, "msg": str(exc)})
        raise
    finally:
        await asyncio.sleep(30)
        _jobs.pop(sortie_id, None)
|
|
|
|
|
|
@app.get("/events/{sortie_id:path}")
async def sse_events(sortie_id: str):
    """Stream an active job's progress events as Server-Sent Events.

    Each event from the job's queue is sent as a JSON ``data:`` line. After 60 s
    without events a ``{"step":"ping"}`` keep-alive is sent. The stream ends after
    an ``error`` event or a ``write`` event at 100 %.

    Raises:
        HTTPException: 404 if no job is registered for sortie_id.
    """
    # Grab the queue now: the job entry is dropped from _jobs some time after the run
    # finishes, which could happen before the response body starts streaming.
    queue = _jobs.get(sortie_id)
    if queue is None:
        raise HTTPException(404, "No active job for this sortie")

    async def generate() -> AsyncGenerator[str, None]:
        while True:
            try:
                event = await asyncio.wait_for(queue.get(), timeout=60)
            except asyncio.TimeoutError:
                yield "data: {\"step\":\"ping\"}\n\n"
                continue
            yield f"data: {json.dumps(event)}\n\n"
            step = event.get("step")
            # Terminal events: any error, or the write step reaching 100 %. The previous
            # check `pct in (0, 100)` also ended the stream on a write event at 0 %.
            if step == "error" or (step == "write" and event.get("pct") == 100):
                break

    return StreamingResponse(generate(), media_type="text/event-stream")
|
|
|
|
|
|
def _read_gz(path: Path) -> dict:
|
|
with gzip.open(path) as f:
|
|
return json.loads(f.read())
|
|
|
|
|
|
@app.get("/sorties/{sortie_id:path}/usv")
async def get_usv(sortie_id: str):
    """Return the processed USV dataset (processed/usv.json.gz) of a sortie.

    Raises:
        HTTPException: 400 if sortie_id would escape OUTPUT_DIR,
            404 if the processed file does not exist.
    """
    rel = Path(sortie_id)
    # sortie_id comes from the URL (`:path` allows '/'): an anchored value (e.g. a
    # leading '/') would replace OUTPUT_DIR in the join, and '..' would climb out of it.
    if rel.anchor or ".." in rel.parts:
        raise HTTPException(400, "Invalid sortie id")
    p = OUTPUT_DIR / rel / "processed" / "usv.json.gz"
    if not p.exists():
        raise HTTPException(404, "USV data not found — run pipeline first")
    # Decompression + JSON parsing is blocking work; keep it off the event loop.
    return JSONResponse(await asyncio.to_thread(_read_gz, p))
|
|
|
|
|
|
@app.get("/sorties/{sortie_id:path}/auvs")
async def list_auvs(sortie_id: str):
    """Return the sorted ids of the AUVs that have a processed/auv_<id>.json.gz file.

    Raises:
        HTTPException: 400 if sortie_id would escape OUTPUT_DIR.
    """
    rel = Path(sortie_id)
    # sortie_id comes from the URL (`:path` allows '/'): an anchored value would
    # replace OUTPUT_DIR in the join, and '..' would climb out of it.
    if rel.anchor or ".." in rel.parts:
        raise HTTPException(400, "Invalid sortie id")
    proc = OUTPUT_DIR / rel / "processed"
    auvs = [
        p.name.removesuffix(".json.gz").removeprefix("auv_")
        for p in proc.glob("auv_*.json.gz")
    ]
    return sorted(auvs)
|
|
|
|
|
|
@app.get("/sorties/{sortie_id:path}/auv/{auv_id}")
async def get_auv(sortie_id: str, auv_id: str):
    """Return one AUV's processed dataset (processed/auv_<auv_id>.json.gz).

    Raises:
        HTTPException: 400 if sortie_id would escape OUTPUT_DIR,
            404 if the processed file does not exist.
    """
    rel = Path(sortie_id)
    # sortie_id comes from the URL (`:path` allows '/'): an anchored value would
    # replace OUTPUT_DIR in the join, and '..' would climb out of it.
    if rel.anchor or ".." in rel.parts:
        raise HTTPException(400, "Invalid sortie id")
    p = OUTPUT_DIR / rel / "processed" / f"auv_{auv_id}.json.gz"
    if not p.exists():
        raise HTTPException(404, f"AUV {auv_id} data not found")
    # Decompression + JSON parsing is blocking work; keep it off the event loop.
    return JSONResponse(await asyncio.to_thread(_read_gz, p))
|
|
|
|
|
|
@app.get("/sorties/{sortie_id:path}/tracks")
async def get_tracks(sortie_id: str):
    """Return the processed GeoJSON tracks (processed/tracks.geojson) of a sortie.

    Raises:
        HTTPException: 400 if sortie_id would escape OUTPUT_DIR,
            404 if the file does not exist.
    """
    rel = Path(sortie_id)
    # sortie_id comes from the URL (`:path` allows '/'): an anchored value would
    # replace OUTPUT_DIR in the join, and '..' would climb out of it.
    if rel.anchor or ".." in rel.parts:
        raise HTTPException(400, "Invalid sortie id")
    p = OUTPUT_DIR / rel / "processed" / "tracks.geojson"
    if not p.exists():
        raise HTTPException(404, "tracks.geojson not found")

    def _load() -> dict:
        # GeoJSON text is UTF-8 (RFC 7946); don't depend on the locale's default encoding.
        with open(p, encoding="utf-8") as f:
            return json.load(f)

    # File read + JSON parsing is blocking work; keep it off the event loop.
    return JSONResponse(await asyncio.to_thread(_load))
|
|
|
|
|
|
def _ts_nav(ts_str: str) -> float:
|
|
for fmt in ("%Y-%m-%d %H:%M:%S.%f", "%Y-%m-%d %H:%M:%S"):
|
|
try:
|
|
return datetime.strptime(ts_str.strip(), fmt).replace(tzinfo=timezone.utc).timestamp()
|
|
except ValueError:
|
|
continue
|
|
return 0.0
|
|
|
|
|
|
def _read_usv_track(nav_log_path: Path, max_pts: int = 2000) -> list[dict]:
|
|
"""Read navigation_log.csv → [{t_ms, lat, lon, heading, source}] downsampled."""
|
|
pts: list[dict] = []
|
|
lat_map: dict[float, float] = {}
|
|
lon_map: dict[float, float] = {}
|
|
heading_map: dict[float, float] = {}
|
|
with open(nav_log_path, newline="", encoding="utf-8") as f:
|
|
reader = csv.reader(f)
|
|
next(reader, None) # skip header
|
|
for row in reader:
|
|
if len(row) < 3:
|
|
continue
|
|
ts_str, field, val = row[0], row[1], row[2]
|
|
if field not in ("Lat", "Lon", "Heading"):
|
|
continue
|
|
t = _ts_nav(ts_str)
|
|
try:
|
|
v = float(val)
|
|
except ValueError:
|
|
continue
|
|
if field == "Lat":
|
|
lat_map[t] = v
|
|
elif field == "Lon":
|
|
lon_map[t] = v
|
|
else:
|
|
heading_map[t] = v
|
|
# Join on lat timestamps (master)
|
|
source = nav_log_path.parent.name
|
|
for t, lat in sorted(lat_map.items()):
|
|
lon = lon_map.get(t)
|
|
if lon is None:
|
|
# nearest lon within 1s
|
|
near = min(lon_map.keys(), key=lambda x: abs(x - t), default=None)
|
|
if near is None or abs(near - t) > 1.0:
|
|
continue
|
|
lon = lon_map[near]
|
|
pts.append({
|
|
"t_ms": int(t * 1000),
|
|
"lat": lat,
|
|
"lon": lon,
|
|
"heading": heading_map.get(t),
|
|
"source": source,
|
|
})
|
|
# Simple stride downsampling
|
|
if len(pts) > max_pts:
|
|
step = len(pts) // max_pts
|
|
pts = pts[::step]
|
|
return pts
|
|
|
|
|
|
# Per-sortie cache of merged USV tracks. Entries are never invalidated: a track is
# served from memory until the process restarts, even if the raw logs change.
_track_cache: dict[str, list[dict]] = {}


def _find_nav_logs(raw_dir: Path) -> list[Path]:
    """Return every *_navigation_log.csv under raw_dir (recursively), sorted by path."""
    if not raw_dir.exists():
        return []
    return sorted(raw_dir.rglob("*_navigation_log.csv"))


@app.get("/sorties/{sortie_id:path}/usv_track")
async def get_usv_track(sortie_id: str):
    """Return the USV GPS track [{t_ms, lat, lon, heading, source}] for map polylines.

    Points from every navigation log under <sortie>/raw are merged and sorted by t_ms.
    The result is cached per sortie id.

    Raises:
        HTTPException: 400 if sortie_id would escape OUTPUT_DIR,
            404 if no navigation log exists.
    """
    if sortie_id in _track_cache:
        return JSONResponse(_track_cache[sortie_id])

    rel = Path(sortie_id)
    # sortie_id comes from the URL (`:path` allows '/'): an anchored value would
    # replace OUTPUT_DIR in the join, and '..' would climb out of it.
    if rel.anchor or ".." in rel.parts:
        raise HTTPException(400, "Invalid sortie id")

    # Recursive directory walk and CSV parsing are blocking; run them in worker threads.
    nav_logs = await asyncio.to_thread(_find_nav_logs, OUTPUT_DIR / rel / "raw")
    if not nav_logs:
        raise HTTPException(404, "No navigation_log.csv found — run pipeline first")

    def _merge() -> list[dict]:
        merged = [pt for log in nav_logs for pt in _read_usv_track(log)]
        merged.sort(key=lambda p: p["t_ms"])
        return merged

    pts = await asyncio.to_thread(_merge)
    _track_cache[sortie_id] = pts
    return JSONResponse(pts)
|