"""Live 3D reconstruction server.
|
|
|
|
Mobile browser opens "/", grants camera access via getUserMedia, and pushes
|
|
JPEG frames over WebSocket to this process. Each frame is fed into
|
|
lingbot-map in streaming mode (shared KV cache across frames) and the
|
|
resulting point cloud is pushed to a viser scene served on a separate
|
|
port.
|
|
|
|
Usage:
|
|
python server_live.py --model_path /path/to/lingbot-map.pt
|
|
|
|
Open:
|
|
http://<host>:8080 → mobile capture page
|
|
http://<host>:8081 → 3D reconstruction viewer (viser)
|
|
"""
|
|
import argparse
|
|
import asyncio
|
|
import io
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
|
|
|
import numpy as np
|
|
import torch
|
|
import viser
|
|
from aiohttp import web
|
|
from PIL import Image
|
|
|
|
# The demo.py in the upstream repo has load_model() — we reuse it.
|
|
# LINGBOT_MAP_DIR defaults to a sibling checkout of Robbyant/lingbot-map.
|
|
LINGBOT_MAP_DIR = os.environ.get(
|
|
"LINGBOT_MAP_DIR",
|
|
str(Path(__file__).resolve().parent.parent / "lingbot-map"),
|
|
)
|
|
sys.path.insert(0, LINGBOT_MAP_DIR)
|
|
from demo import load_model # noqa: E402
|
|
|
|
IMG_H, IMG_W = 294, 518 # matches demo canonical crop
|
|
SCALE_FRAMES = 4
|
|
MAX_PC_FRAMES = 60 # rolling window of point clouds kept in the scene
|
|
CONF_THRESHOLD = 2.0
|
|
|
|
|
|
def frame_bytes_to_tensor(data: bytes) -> torch.Tensor:
|
|
img = Image.open(io.BytesIO(data)).convert("RGB").resize((IMG_W, IMG_H))
|
|
arr = np.asarray(img, dtype=np.float32) / 255.0
|
|
return torch.from_numpy(arr).permute(2, 0, 1) # [C, H, W]
|
|
|
|
|
|
class LiveServer:
|
|
def __init__(self, model_path: str, web_port: int, viser_port: int):
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
args = argparse.Namespace(
|
|
model_path=model_path,
|
|
image_size=518,
|
|
patch_size=14,
|
|
enable_3d_rope=False,
|
|
use_sdpa=True,
|
|
mode="streaming",
|
|
max_frame_num=10000,
|
|
kv_cache_sliding_window=0,
|
|
num_scale_frames=SCALE_FRAMES,
|
|
camera_num_iterations=4,
|
|
)
|
|
print(f"Loading model from {model_path} ...")
|
|
self.model = load_model(args, self.device)
|
|
self.dtype = (
|
|
torch.bfloat16
|
|
if self.device.type == "cuda" and torch.cuda.get_device_capability()[0] >= 8
|
|
else torch.float32
|
|
)
|
|
if self.dtype != torch.float32 and getattr(self.model, "aggregator", None) is not None:
|
|
self.model.aggregator = self.model.aggregator.to(dtype=self.dtype)
|
|
|
|
self.model.clean_kv_cache()
|
|
|
|
self.viser = viser.ViserServer(host="0.0.0.0", port=viser_port)
|
|
self.viser.scene.world_axes.visible = True
|
|
|
|
self.scale_buffer: list[torch.Tensor] = []
|
|
self.started = False
|
|
self.frame_idx = 0
|
|
self.pc_handles: list = []
|
|
self.lock = asyncio.Lock()
|
|
|
|
self.web_port = web_port
|
|
self.viser_port = viser_port
|
|
|
|
self.app = web.Application(client_max_size=8 * 1024 * 1024)
|
|
static_dir = Path(__file__).resolve().parent / "static"
|
|
self.app.router.add_get("/", self._index)
|
|
self.app.router.add_static("/static/", static_dir)
|
|
self.app.router.add_get("/ws", self._ws)
|
|
|
|
async def _index(self, req: web.Request) -> web.Response:
|
|
return web.FileResponse(Path(__file__).resolve().parent / "static" / "index.html")
|
|
|
|
async def _ws(self, req: web.Request) -> web.WebSocketResponse:
|
|
ws = web.WebSocketResponse(max_msg_size=8 * 1024 * 1024)
|
|
await ws.prepare(req)
|
|
print(f"[ws] client connected from {req.remote}")
|
|
async for msg in ws:
|
|
if msg.type == web.WSMsgType.BINARY:
|
|
t0 = time.time()
|
|
frame = frame_bytes_to_tensor(msg.data)
|
|
await self._process_frame(frame)
|
|
dt = (time.time() - t0) * 1000
|
|
await ws.send_str(f'{{"frame":{self.frame_idx},"ms":{dt:.0f}}}')
|
|
elif msg.type == web.WSMsgType.ERROR:
|
|
print(f"[ws] error: {ws.exception()}")
|
|
print("[ws] client disconnected")
|
|
return ws
|
|
|
|
async def _process_frame(self, frame: torch.Tensor) -> None:
|
|
async with self.lock:
|
|
if not self.started:
|
|
self.scale_buffer.append(frame)
|
|
if len(self.scale_buffer) < SCALE_FRAMES:
|
|
return
|
|
scale = torch.stack(self.scale_buffer, dim=0).unsqueeze(0).to(self.device) # [1,S,3,H,W]
|
|
with torch.no_grad(), torch.amp.autocast("cuda", dtype=self.dtype):
|
|
out = self.model.forward(
|
|
scale,
|
|
num_frame_for_scale=SCALE_FRAMES,
|
|
num_frame_per_block=SCALE_FRAMES,
|
|
causal_inference=True,
|
|
)
|
|
self._push_points(out, imgs=scale[0])
|
|
self.scale_buffer.clear()
|
|
self.started = True
|
|
return
|
|
|
|
img = frame.unsqueeze(0).unsqueeze(0).to(self.device) # [1,1,3,H,W]
|
|
with torch.no_grad(), torch.amp.autocast("cuda", dtype=self.dtype):
|
|
out = self.model.forward(
|
|
img,
|
|
num_frame_for_scale=SCALE_FRAMES,
|
|
num_frame_per_block=1,
|
|
causal_inference=True,
|
|
)
|
|
self._push_points(out, imgs=img[0])
|
|
|
|
def _push_points(self, out: dict, imgs: torch.Tensor) -> None:
|
|
wp = out["world_points"][0].float().cpu().numpy() # [S,H,W,3]
|
|
wp_conf = out["world_points_conf"][0].float().cpu().numpy() # [S,H,W]
|
|
rgb = imgs.cpu().numpy() # [S,3,H,W]
|
|
rgb = np.transpose(rgb, (0, 2, 3, 1)) # [S,H,W,3]
|
|
|
|
for i in range(wp.shape[0]):
|
|
self.frame_idx += 1
|
|
mask = wp_conf[i] > CONF_THRESHOLD
|
|
pts = wp[i][mask]
|
|
cols = (np.clip(rgb[i][mask], 0, 1) * 255).astype(np.uint8)
|
|
if pts.shape[0] == 0:
|
|
continue
|
|
# Downsample to cap scene size
|
|
if pts.shape[0] > 15000:
|
|
idx = np.random.choice(pts.shape[0], 15000, replace=False)
|
|
pts = pts[idx]
|
|
cols = cols[idx]
|
|
h = self.viser.scene.add_point_cloud(
|
|
name=f"/pc/{self.frame_idx:06d}",
|
|
points=pts.astype(np.float32),
|
|
colors=cols,
|
|
point_size=0.005,
|
|
)
|
|
self.pc_handles.append(h)
|
|
while len(self.pc_handles) > MAX_PC_FRAMES:
|
|
old = self.pc_handles.pop(0)
|
|
try:
|
|
old.remove()
|
|
except Exception:
|
|
pass
|
|
|
|
async def run(self) -> None:
|
|
runner = web.AppRunner(self.app)
|
|
await runner.setup()
|
|
site = web.TCPSite(runner, "0.0.0.0", self.web_port)
|
|
await site.start()
|
|
print(f"HTTP : http://0.0.0.0:{self.web_port}/")
|
|
print(f"viser 3D: http://0.0.0.0:{self.viser_port}/")
|
|
print("Ctrl-C to stop.")
|
|
while True:
|
|
await asyncio.sleep(3600)
|
|
|
|
|
|
def main() -> None:
|
|
p = argparse.ArgumentParser()
|
|
p.add_argument("--model_path", required=True)
|
|
p.add_argument("--web_port", type=int, default=8080)
|
|
p.add_argument("--viser_port", type=int, default=8081)
|
|
args = p.parse_args()
|
|
|
|
server = LiveServer(args.model_path, args.web_port, args.viser_port)
|
|
try:
|
|
asyncio.run(server.run())
|
|
except KeyboardInterrupt:
|
|
print("\nstopped.")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|