"""Live 3D reconstruction server. Mobile browser opens "/", grants camera access via getUserMedia, and pushes JPEG frames over WebSocket to this process. Each frame is fed into lingbot-map in streaming mode (shared KV cache across frames) and the resulting point cloud is pushed to a viser scene served on a separate port. Usage: python server_live.py --model_path /path/to/lingbot-map.pt Open: http://:8080 → mobile capture page http://:8081 → 3D reconstruction viewer (viser) """ import argparse import asyncio import io import os import sys import time from pathlib import Path from typing import Optional os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") import numpy as np import torch import viser from aiohttp import web from PIL import Image # The demo.py in the upstream repo has load_model() — we reuse it. # LINGBOT_MAP_DIR defaults to a sibling checkout of Robbyant/lingbot-map. LINGBOT_MAP_DIR = os.environ.get( "LINGBOT_MAP_DIR", str(Path(__file__).resolve().parent.parent / "lingbot-map"), ) sys.path.insert(0, LINGBOT_MAP_DIR) from demo import load_model # noqa: E402 IMG_H, IMG_W = 294, 518 # matches demo canonical crop SCALE_FRAMES = 4 MAX_PC_FRAMES = 60 # rolling window of point clouds kept in the scene CONF_THRESHOLD = 2.0 def frame_bytes_to_tensor(data: bytes) -> torch.Tensor: img = Image.open(io.BytesIO(data)).convert("RGB").resize((IMG_W, IMG_H)) arr = np.asarray(img, dtype=np.float32) / 255.0 return torch.from_numpy(arr).permute(2, 0, 1) # [C, H, W] class LiveServer: def __init__(self, model_path: str, web_port: int, viser_port: int): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args = argparse.Namespace( model_path=model_path, image_size=518, patch_size=14, enable_3d_rope=False, use_sdpa=True, mode="streaming", max_frame_num=10000, kv_cache_sliding_window=0, num_scale_frames=SCALE_FRAMES, camera_num_iterations=4, ) print(f"Loading model from {model_path} ...") self.model = load_model(args, self.device) self.dtype = ( torch.bfloat16 if self.device.type == "cuda" and torch.cuda.get_device_capability()[0] >= 8 else torch.float32 ) if self.dtype != torch.float32 and getattr(self.model, "aggregator", None) is not None: self.model.aggregator = self.model.aggregator.to(dtype=self.dtype) self.model.clean_kv_cache() self.viser = viser.ViserServer(host="0.0.0.0", port=viser_port) self.viser.scene.world_axes.visible = True self.scale_buffer: list[torch.Tensor] = [] self.started = False self.frame_idx = 0 self.pc_handles: list = [] self.lock = asyncio.Lock() self.web_port = web_port self.viser_port = viser_port self.app = web.Application(client_max_size=8 * 1024 * 1024) static_dir = Path(__file__).resolve().parent / "static" self.app.router.add_get("/", self._index) self.app.router.add_static("/static/", static_dir) self.app.router.add_get("/ws", self._ws) async def _index(self, req: web.Request) -> web.Response: return web.FileResponse(Path(__file__).resolve().parent / "static" / "index.html") async def _ws(self, req: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse(max_msg_size=8 * 1024 * 1024) await ws.prepare(req) print(f"[ws] client connected from {req.remote}") async for msg in ws: if msg.type == web.WSMsgType.BINARY: t0 = time.time() frame = frame_bytes_to_tensor(msg.data) await self._process_frame(frame) dt = (time.time() - t0) * 1000 await ws.send_str(f'{{"frame":{self.frame_idx},"ms":{dt:.0f}}}') elif msg.type == web.WSMsgType.ERROR: print(f"[ws] error: {ws.exception()}") print("[ws] client disconnected") 
return ws async def _process_frame(self, frame: torch.Tensor) -> None: async with self.lock: if not self.started: self.scale_buffer.append(frame) if len(self.scale_buffer) < SCALE_FRAMES: return scale = torch.stack(self.scale_buffer, dim=0).unsqueeze(0).to(self.device) # [1,S,3,H,W] with torch.no_grad(), torch.amp.autocast("cuda", dtype=self.dtype): out = self.model.forward( scale, num_frame_for_scale=SCALE_FRAMES, num_frame_per_block=SCALE_FRAMES, causal_inference=True, ) self._push_points(out, imgs=scale[0]) self.scale_buffer.clear() self.started = True return img = frame.unsqueeze(0).unsqueeze(0).to(self.device) # [1,1,3,H,W] with torch.no_grad(), torch.amp.autocast("cuda", dtype=self.dtype): out = self.model.forward( img, num_frame_for_scale=SCALE_FRAMES, num_frame_per_block=1, causal_inference=True, ) self._push_points(out, imgs=img[0]) def _push_points(self, out: dict, imgs: torch.Tensor) -> None: wp = out["world_points"][0].float().cpu().numpy() # [S,H,W,3] wp_conf = out["world_points_conf"][0].float().cpu().numpy() # [S,H,W] rgb = imgs.cpu().numpy() # [S,3,H,W] rgb = np.transpose(rgb, (0, 2, 3, 1)) # [S,H,W,3] for i in range(wp.shape[0]): self.frame_idx += 1 mask = wp_conf[i] > CONF_THRESHOLD pts = wp[i][mask] cols = (np.clip(rgb[i][mask], 0, 1) * 255).astype(np.uint8) if pts.shape[0] == 0: continue # Downsample to cap scene size if pts.shape[0] > 15000: idx = np.random.choice(pts.shape[0], 15000, replace=False) pts = pts[idx] cols = cols[idx] h = self.viser.scene.add_point_cloud( name=f"/pc/{self.frame_idx:06d}", points=pts.astype(np.float32), colors=cols, point_size=0.005, ) self.pc_handles.append(h) while len(self.pc_handles) > MAX_PC_FRAMES: old = self.pc_handles.pop(0) try: old.remove() except Exception: pass async def run(self) -> None: runner = web.AppRunner(self.app) await runner.setup() site = web.TCPSite(runner, "0.0.0.0", self.web_port) await site.start() print(f"HTTP : http://0.0.0.0:{self.web_port}/") print(f"viser 3D: http://0.0.0.0:{self.viser_port}/") print("Ctrl-C to stop.") while True: await asyncio.sleep(3600) def main() -> None: p = argparse.ArgumentParser() p.add_argument("--model_path", required=True) p.add_argument("--web_port", type=int, default=8080) p.add_argument("--viser_port", type=int, default=8081) args = p.parse_args() server = LiveServer(args.model_path, args.web_port, args.viser_port) try: asyncio.run(server.run()) except KeyboardInterrupt: print("\nstopped.") if __name__ == "__main__": main()
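
# ---------------------------------------------------------------------------
# Minimal smoke-test client: a sketch, kept commented out so it never runs as
# part of the server. It assumes the server is reachable at localhost:8080 and
# that a JPEG named "test.jpg" exists next to this script; both are
# illustrative placeholders, not part of the project. It pushes one binary
# frame to /ws, exactly as the capture page does, and prints the JSON status
# string returned by _ws() above.
#
#   import asyncio
#   import aiohttp
#
#   async def send_one_frame() -> None:
#       async with aiohttp.ClientSession() as session:
#           async with session.ws_connect("http://localhost:8080/ws") as ws:
#               with open("test.jpg", "rb") as f:
#                   await ws.send_bytes(f.read())  # one JPEG frame as raw bytes
#               print(await ws.receive_str())      # e.g. {"frame":1,"ms":250}
#
#   asyncio.run(send_one_frame())
# ---------------------------------------------------------------------------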