Merge branch 'fix/05-inference-yaml-params' into feature/auto-pipeline
This commit is contained in:
@@ -32,11 +32,24 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
from orchestrator.db import init_db, get_conn, upsert_job, record_metric, now_iso
|
from orchestrator.db import init_db, get_conn, upsert_job, record_metric, now_iso
|
||||||
|
|
||||||
PIPELINE_BASE = Path(os.environ.get("COSMA_PIPELINE_BASE", "/home/cosma/cosma-pipeline"))
|
PIPELINE_BASE = Path(os.environ.get("COSMA_PIPELINE_BASE", "/home/cosma/cosma-pipeline"))
|
||||||
|
|
||||||
|
def _load_inference_cfg() -> dict:
|
||||||
|
"""Load inference params from thresholds.yaml, with sane defaults."""
|
||||||
|
cfg_path = Path(__file__).parent.parent / "config" / "thresholds.yaml"
|
||||||
|
try:
|
||||||
|
data = yaml.safe_load(cfg_path.read_text())
|
||||||
|
return data.get("inference", {})
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
_INF_CFG = _load_inference_cfg()
|
||||||
|
|
||||||
WORKERS = {
|
WORKERS = {
|
||||||
".84": {
|
".84": {
|
||||||
"host": "192.168.0.84",
|
"host": "192.168.0.84",
|
||||||
@@ -146,19 +159,37 @@ def run_inference(frames_dir: Path, worker_key: str, mission_name: str,
|
|||||||
return metrics
|
return metrics
|
||||||
print(f" [05] rsync done")
|
print(f" [05] rsync done")
|
||||||
|
|
||||||
# Step 2: build demo.py command
|
# Step 2: build demo.py command -- params from thresholds.yaml[inference]
|
||||||
checkpoint = f"{w['ai_dir']}/checkpoints/lingbot-map/lingbot-map.pt"
|
checkpoint = f"{w['ai_dir']}/checkpoints/lingbot-map/lingbot-map.pt"
|
||||||
|
inf_mode = _INF_CFG.get("mode", "streaming")
|
||||||
|
conf_thr = _INF_CFG.get("ply_conf_threshold", 1.5)
|
||||||
|
kf_interval = _INF_CFG.get("keyframe_interval", 1)
|
||||||
|
max_frames = _INF_CFG.get("max_frame_num", 1024)
|
||||||
|
if inf_mode == "windowed":
|
||||||
|
window_size = _INF_CFG.get("window_size", 64)
|
||||||
|
overlap_size = _INF_CFG.get("overlap_size", 16)
|
||||||
|
mode_flags = (
|
||||||
|
f"--mode windowed "
|
||||||
|
f"--window_size {window_size} "
|
||||||
|
f"--overlap_size {overlap_size} "
|
||||||
|
)
|
||||||
|
else: # streaming (default, validated GX049839_v2 146M pts)
|
||||||
|
mode_flags = (
|
||||||
|
f"--mode streaming "
|
||||||
|
f"--keyframe_interval {kf_interval} "
|
||||||
|
f"--max_frame_num {max_frames} "
|
||||||
|
)
|
||||||
demo_cmd = (
|
demo_cmd = (
|
||||||
f"cd {w['ai_dir']} && "
|
f"cd {w['ai_dir']} && "
|
||||||
f"{w['venv']} demo.py "
|
f"{w['venv']} demo.py "
|
||||||
f"--model_path {checkpoint} "
|
f"--model_path {checkpoint} "
|
||||||
f"--image_folder {worker_frames} "
|
f"--image_folder {worker_frames} "
|
||||||
f"--mode windowed "
|
f"{mode_flags}"
|
||||||
f"--window_size 64 "
|
f"--ply_conf_threshold {conf_thr} "
|
||||||
f"--overlap_size 16 "
|
|
||||||
f"--save_ply {ply_remote} "
|
f"--save_ply {ply_remote} "
|
||||||
f"--save_poses {npz_remote} "
|
f"--save_poses {npz_remote} "
|
||||||
f"--use_sdpa "
|
f"--use_sdpa "
|
||||||
|
f"--offload_to_cpu "
|
||||||
f"2>&1"
|
f"2>&1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user