diff --git a/pipeline/stages/05_inference.py b/pipeline/stages/05_inference.py index 05b6369..ff6b8c0 100644 --- a/pipeline/stages/05_inference.py +++ b/pipeline/stages/05_inference.py @@ -195,9 +195,10 @@ def run_inference(frames_dir: Path, worker_key: str, mission_name: str, print(f" [05] Launching inference on {host}...") t0 = time.time() + inf_timeout = int(_INF_CFG.get("inference_timeout_s", 10800)) r = subprocess.run( ["ssh", "-o", "StrictHostKeyChecking=no", ssh_target, demo_cmd], - capture_output=True, text=True, timeout=7200, # 2h max + capture_output=True, text=True, timeout=inf_timeout, ) elapsed = time.time() - t0 metrics["inference_s"] = round(elapsed, 1) @@ -265,6 +266,19 @@ def process_frames_dir(frames_dir: Path, worker_key: str, mission_name: str) -> if not frames: continue print(f"\n[05] === {auv_id}/{seg_dir.name}: {len(frames)} frames ===") + # Guard: min frames required for model (RoPE/attention) + min_frames = int(_INF_CFG.get("min_frames_for_inference", 32)) + if len(frames) < min_frames: + print(f" [05] SKIP {auv_id}/{seg_dir.name}: {len(frames)} frames < {min_frames} min") + init_db() + with get_conn() as conn_mf: + mr = conn_mf.execute("SELECT id FROM missions WHERE name=?", (mission_name,)).fetchone() + if mr: + upsert_job(conn_mf, mr["id"], auv_id, seg_dir.name, "05_inference", + status="skipped", + error_msg=f"frames_too_few={len(frames)}<{min_frames}") + continue + m = run_inference(seg_dir, worker_key, mission_name, auv_id, seg_dir.name) all_metrics.append(m)