initial prototype: aiohttp + WebSocket + viser live reconstruction
This commit is contained in:
4
.gitignore
vendored
Normal file
4
.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
.venv/
|
||||
checkpoints/
|
||||
66
README.md
Normal file
66
README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# live-reconstruction
|
||||
|
||||
Live 3D reconstruction from a mobile phone camera using
|
||||
[lingbot-map](https://github.com/Robbyant/lingbot-map) in streaming mode.
|
||||
|
||||
```
|
||||
Mobile browser ── getUserMedia ──> JPEG frames
|
||||
── WebSocket ──────> aiohttp server (Python, this repo)
|
||||
├── lingbot-map streaming inference (KV cache, bfloat16)
|
||||
└── viser scene update (rolling window of point clouds)
|
||||
|
||||
Desktop/tablet ───────────────> viser page (http://host:8081) = interactive 3D viewer
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
1. Checkout this repo next to a working `lingbot-map` checkout:
|
||||
|
||||
```
|
||||
~/ai-video/lingbot-map/ # clone of Robbyant/lingbot-map (with venv + pip install -e .)
|
||||
~/ai-video/live-reconstruction/ # this repo
|
||||
```
|
||||
|
||||
2. Download the model weights (once):
|
||||
|
||||
```bash
|
||||
cd ~/ai-video/lingbot-map
|
||||
.venv/bin/python -c "from huggingface_hub import snapshot_download; \
|
||||
snapshot_download('robbyant/lingbot-map', local_dir='./checkpoints/lingbot-map')"
|
||||
```
|
||||
|
||||
3. Install extra deps in the lingbot-map venv:
|
||||
|
||||
```bash
|
||||
~/ai-video/lingbot-map/.venv/bin/pip install aiohttp pillow
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
cd ~/ai-video/live-reconstruction
|
||||
~/ai-video/lingbot-map/.venv/bin/python server_live.py \
|
||||
--model_path ~/ai-video/lingbot-map/checkpoints/lingbot-map/lingbot-map.pt
|
||||
```
|
||||
|
||||
Then:
|
||||
|
||||
- Open `http://<host>:8080/` on your phone → tap **Start camera**.
|
||||
- Open `http://<host>:8081/` on a desktop browser → interactive viser 3D viewer.
|
||||
|
||||
## Constraints
|
||||
|
||||
- Needs a CUDA GPU; tested on RTX 3060 12 GB.
|
||||
- Peak VRAM ~10 GB with bfloat16 + SDPA fallback (FlashInfer not installed).
|
||||
- Throughput on 3060: ~2 frames/s. The mobile page throttles to 2 FPS by default.
|
||||
- `getUserMedia` requires a secure context (HTTPS, or localhost) — so over plain HTTP this only works on LAN / VPN for now.
|
||||
- Free up GPU memory before launching: stop `ollama`, ComfyUI, fish-speech etc.
|
||||
|
||||
## Env
|
||||
|
||||
- `LINGBOT_MAP_DIR` — override path to the upstream lingbot-map checkout
|
||||
(default: `../lingbot-map`).
|
||||
|
||||
## License
|
||||
|
||||
Code in this repo: MIT. Upstream model code: see Robbyant/lingbot-map.
|
||||
207
server_live.py
Normal file
207
server_live.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Live 3D reconstruction server.
|
||||
|
||||
Mobile browser opens "/", grants camera access via getUserMedia, and pushes
|
||||
JPEG frames over WebSocket to this process. Each frame is fed into
|
||||
lingbot-map in streaming mode (shared KV cache across frames) and the
|
||||
resulting point cloud is pushed to a viser scene served on a separate
|
||||
port.
|
||||
|
||||
Usage:
|
||||
python server_live.py --model_path /path/to/lingbot-map.pt
|
||||
|
||||
Open:
|
||||
http://<host>:8080 → mobile capture page
|
||||
http://<host>:8081 → 3D reconstruction viewer (viser)
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import viser
|
||||
from aiohttp import web
|
||||
from PIL import Image
|
||||
|
||||
# The demo.py in the upstream repo has load_model() — we reuse it.
# LINGBOT_MAP_DIR defaults to a sibling checkout of Robbyant/lingbot-map.
LINGBOT_MAP_DIR = os.environ.get(
    "LINGBOT_MAP_DIR",
    str(Path(__file__).resolve().parent.parent / "lingbot-map"),
)
# Make the upstream repo importable; must happen before `from demo import ...`,
# hence the deliberate mid-file import (noqa: E402).
sys.path.insert(0, LINGBOT_MAP_DIR)
from demo import load_model  # noqa: E402

IMG_H, IMG_W = 294, 518  # matches demo canonical crop
SCALE_FRAMES = 4  # frames buffered for the initial scale-estimation block
MAX_PC_FRAMES = 60  # rolling window of point clouds kept in the scene
CONF_THRESHOLD = 2.0  # world_points_conf cutoff — points below are dropped
|
||||
|
||||
|
||||
def frame_bytes_to_tensor(data: bytes) -> torch.Tensor:
    """Decode an encoded image buffer into a float CHW tensor in [0, 1].

    The frame is resized to the model's canonical resolution
    (IMG_W x IMG_H) before conversion.
    """
    image = Image.open(io.BytesIO(data)).convert("RGB")
    image = image.resize((IMG_W, IMG_H))
    pixels = np.asarray(image, dtype=np.float32) / 255.0  # [H, W, C]
    return torch.from_numpy(pixels).permute(2, 0, 1)  # [C, H, W]
|
||||
|
||||
|
||||
class LiveServer:
    """Live reconstruction pipeline: WebSocket frames in, viser points out.

    One instance owns:
      * the lingbot-map model, loaded once in "streaming" mode so the KV
        cache is shared across successive frames,
      * an aiohttp app on ``web_port`` serving the capture page and ``/ws``,
      * a viser server on ``viser_port`` exposing the interactive 3D scene.
    """

    def __init__(self, model_path: str, web_port: int, viser_port: int):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Namespace mimicking demo.py's CLI arguments; load_model() reads
        # these. mode="streaming" keeps the KV cache alive between calls.
        args = argparse.Namespace(
            model_path=model_path,
            image_size=518,
            patch_size=14,
            enable_3d_rope=False,
            use_sdpa=True,  # SDPA attention — no FlashInfer dependency
            mode="streaming",
            max_frame_num=10000,
            kv_cache_sliding_window=0,  # presumably 0 = unbounded — confirm upstream
            num_scale_frames=SCALE_FRAMES,
            camera_num_iterations=4,
        )
        print(f"Loading model from {model_path} ...")
        self.model = load_model(args, self.device)
        # bfloat16 requires compute capability >= 8 (Ampere+); else fp32.
        self.dtype = (
            torch.bfloat16
            if self.device.type == "cuda" and torch.cuda.get_device_capability()[0] >= 8
            else torch.float32
        )
        if self.dtype != torch.float32 and getattr(self.model, "aggregator", None) is not None:
            self.model.aggregator = self.model.aggregator.to(dtype=self.dtype)

        self.model.clean_kv_cache()  # start from an empty streaming state

        self.viser = viser.ViserServer(host="0.0.0.0", port=viser_port)
        self.viser.scene.world_axes.visible = True

        # Frames buffered until SCALE_FRAMES arrive (bootstrap scale pass).
        self.scale_buffer: list[torch.Tensor] = []
        self.started = False  # True once the bootstrap block has run
        self.frame_idx = 0  # total frames pushed into the viser scene
        self.pc_handles: list = []  # point-cloud handles, oldest first
        self.lock = asyncio.Lock()  # serializes model access across clients

        self.web_port = web_port
        self.viser_port = viser_port

        self.app = web.Application(client_max_size=8 * 1024 * 1024)  # 8 MiB cap
        static_dir = Path(__file__).resolve().parent / "static"
        self.app.router.add_get("/", self._index)
        self.app.router.add_static("/static/", static_dir)
        self.app.router.add_get("/ws", self._ws)

    async def _index(self, req: web.Request) -> web.Response:
        """Serve the mobile capture page."""
        return web.FileResponse(Path(__file__).resolve().parent / "static" / "index.html")

    async def _ws(self, req: web.Request) -> web.WebSocketResponse:
        """Receive JPEG frames over WebSocket and reply with timing info.

        Each BINARY message is one encoded frame. After inference we send a
        small JSON string back so the client can display round-trip latency.
        """
        ws = web.WebSocketResponse(max_msg_size=8 * 1024 * 1024)
        await ws.prepare(req)
        print(f"[ws] client connected from {req.remote}")
        async for msg in ws:
            if msg.type == web.WSMsgType.BINARY:
                t0 = time.time()
                frame = frame_bytes_to_tensor(msg.data)
                await self._process_frame(frame)
                dt = (time.time() - t0) * 1000
                await ws.send_str(f'{{"frame":{self.frame_idx},"ms":{dt:.0f}}}')
            elif msg.type == web.WSMsgType.ERROR:
                print(f"[ws] error: {ws.exception()}")
        print("[ws] client disconnected")
        return ws

    async def _process_frame(self, frame: torch.Tensor) -> None:
        """Feed one [3,H,W] frame through the model and update the scene.

        The first SCALE_FRAMES frames are accumulated and run as a single
        block (scale estimation); after that each frame is processed
        individually against the shared KV cache.

        NOTE(review): model.forward() runs synchronously on the event-loop
        thread, so the HTTP server stalls during inference — acceptable for
        a single-client prototype, but worth an executor later.
        """
        async with self.lock:
            if not self.started:
                # Bootstrap: wait until a full scale block is available.
                self.scale_buffer.append(frame)
                if len(self.scale_buffer) < SCALE_FRAMES:
                    return
                scale = torch.stack(self.scale_buffer, dim=0).unsqueeze(0).to(self.device)  # [1,S,3,H,W]
                with torch.no_grad(), torch.amp.autocast("cuda", dtype=self.dtype):
                    out = self.model.forward(
                        scale,
                        num_frame_for_scale=SCALE_FRAMES,
                        num_frame_per_block=SCALE_FRAMES,
                        causal_inference=True,
                    )
                self._push_points(out, imgs=scale[0])
                self.scale_buffer.clear()
                self.started = True
                return

            # Steady state: one-frame block against the shared KV cache.
            img = frame.unsqueeze(0).unsqueeze(0).to(self.device)  # [1,1,3,H,W]
            with torch.no_grad(), torch.amp.autocast("cuda", dtype=self.dtype):
                out = self.model.forward(
                    img,
                    num_frame_for_scale=SCALE_FRAMES,
                    num_frame_per_block=1,
                    causal_inference=True,
                )
            self._push_points(out, imgs=img[0])

    def _push_points(self, out: dict, imgs: torch.Tensor) -> None:
        """Convert model output into per-frame viser point clouds.

        Confidence-filters each frame's world points, caps each cloud at
        15k points, and prunes the scene down to MAX_PC_FRAMES clouds.
        """
        wp = out["world_points"][0].float().cpu().numpy()  # [S,H,W,3]
        wp_conf = out["world_points_conf"][0].float().cpu().numpy()  # [S,H,W]
        rgb = imgs.cpu().numpy()  # [S,3,H,W]
        rgb = np.transpose(rgb, (0, 2, 3, 1))  # [S,H,W,3]

        for i in range(wp.shape[0]):
            self.frame_idx += 1
            mask = wp_conf[i] > CONF_THRESHOLD  # keep confident points only
            pts = wp[i][mask]
            cols = (np.clip(rgb[i][mask], 0, 1) * 255).astype(np.uint8)
            if pts.shape[0] == 0:
                continue
            # Downsample to cap scene size
            if pts.shape[0] > 15000:
                idx = np.random.choice(pts.shape[0], 15000, replace=False)
                pts = pts[idx]
                cols = cols[idx]
            h = self.viser.scene.add_point_cloud(
                name=f"/pc/{self.frame_idx:06d}",
                points=pts.astype(np.float32),
                colors=cols,
                point_size=0.005,
            )
            self.pc_handles.append(h)
            # Rolling window: drop the oldest clouds beyond MAX_PC_FRAMES.
            while len(self.pc_handles) > MAX_PC_FRAMES:
                old = self.pc_handles.pop(0)
                try:
                    old.remove()
                except Exception:
                    # Best effort — handle may already be gone server-side.
                    pass

    async def run(self) -> None:
        """Start the aiohttp site and idle forever (Ctrl-C to stop)."""
        runner = web.AppRunner(self.app)
        await runner.setup()
        site = web.TCPSite(runner, "0.0.0.0", self.web_port)
        await site.start()
        print(f"HTTP : http://0.0.0.0:{self.web_port}/")
        print(f"viser 3D: http://0.0.0.0:{self.viser_port}/")
        print("Ctrl-C to stop.")
        while True:
            # Keep the loop alive; all real work happens in the handlers.
            await asyncio.sleep(3600)
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments, build the server, run forever."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    parser.add_argument("--web_port", type=int, default=8080)
    parser.add_argument("--viser_port", type=int, default=8081)
    opts = parser.parse_args()

    live = LiveServer(opts.model_path, opts.web_port, opts.viser_port)
    try:
        asyncio.run(live.run())
    except KeyboardInterrupt:
        print("\nstopped.")


if __name__ == "__main__":
    main()
|
||||
114
static/index.html
Normal file
114
static/index.html
Normal file
@@ -0,0 +1,114 @@
|
||||
<!doctype html>
<!-- Mobile capture page: grabs camera frames via getUserMedia, downscales
     them on a hidden canvas, and streams JPEG blobs over WebSocket to
     server_live.py. The 3D output is viewed in viser on a separate port. -->
<html lang="en"> <!-- fixed: was lang="fr" but all page text is English -->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>Live 3D Reconstruction — lingbot-map</title>
<style>
  html, body { margin:0; padding:0; background:#111; color:#eee; font-family:ui-monospace,monospace; }
  header { padding: 8px 12px; background:#1c1c1c; display:flex; justify-content:space-between; align-items:center; }
  header h1 { font-size: 14px; margin: 0; font-weight: 500; }
  #status { font-size: 12px; color:#9cf; }
  #video-wrap { position:relative; width:100%; background:#000; }
  video { width:100%; display:block; }
  canvas { display:none; }
  #controls { padding: 12px; display:flex; gap:8px; flex-wrap:wrap; }
  button { background:#2b5; color:#000; border:0; padding:10px 16px; font-size:14px; border-radius:4px; cursor:pointer; }
  button:disabled { background:#555; color:#888; cursor:not-allowed; }
  #stats { padding: 8px 12px; font-size: 12px; color:#aaa; }
  #link { padding: 8px 12px; font-size: 13px; }
  #link a { color:#9cf; }
</style>
</head>
<body>
<header>
  <h1>LIVE RECONSTRUCTION</h1>
  <span id="status">idle</span>
</header>
<div id="video-wrap">
  <video id="v" autoplay muted playsinline></video>
  <!-- hidden canvas, sized to the model's canonical crop (518x294) -->
  <canvas id="c" width="518" height="294"></canvas>
</div>
<div id="controls">
  <button id="start">Start camera</button>
  <button id="stop" disabled>Stop</button>
  <label style="display:flex;align-items:center;gap:6px;font-size:13px;">
    FPS <input id="fps" type="number" value="2" min="1" max="10" style="width:3em">
  </label>
</div>
<div id="stats">frames sent: <span id="sent">0</span> · last RTT: <span id="rtt">-</span> ms</div>
<div id="link">3D viewer → <a id="viser" target="_blank">open viser</a></div>
<script>
const video = document.getElementById("v");
const canvas = document.getElementById("c");
const ctx = canvas.getContext("2d");
const statusEl = document.getElementById("status");
const sentEl = document.getElementById("sent");
const rttEl = document.getElementById("rtt");
const btnStart = document.getElementById("start");
const btnStop = document.getElementById("stop");
const fpsInput = document.getElementById("fps");

// NOTE(review): viser port is hard-coded; must match --viser_port on the server.
document.getElementById("viser").href = `http://${location.hostname}:8081/`;

let ws = null;
let stream = null;
let timer = null;
let sent = 0;

function setStatus(s, ok=true){ statusEl.textContent = s; statusEl.style.color = ok? "#9cf" : "#f88"; }

btnStart.onclick = async () => {
  // Prefer the rear camera; 720p capture is downscaled before sending.
  try {
    stream = await navigator.mediaDevices.getUserMedia({
      video: { facingMode: { ideal: "environment" }, width: { ideal: 1280 }, height: { ideal: 720 } },
      audio: false,
    });
  } catch (e) {
    setStatus("camera denied: " + e.message, false);
    return;
  }
  video.srcObject = stream;
  await new Promise(r => video.onloadedmetadata = r);
  ws = new WebSocket(`ws://${location.host}/ws`);
  ws.binaryType = "arraybuffer";
  ws.onopen = () => { setStatus("connected"); startLoop(); btnStart.disabled = true; btnStop.disabled = false; };
  ws.onclose = () => { setStatus("disconnected", false); stopLoop(); btnStart.disabled = false; btnStop.disabled = true; };
  ws.onerror = () => setStatus("ws error", false);
  ws.onmessage = (ev) => {
    // Server replies {"frame": n, "ms": rtt} after each processed frame.
    try { const m = JSON.parse(ev.data); rttEl.textContent = m.ms; } catch (_){}
  };
};

btnStop.onclick = () => {
  if (ws) ws.close();
  if (stream) stream.getTracks().forEach(t => t.stop());
};

function startLoop(){
  // Throttle to the requested FPS (clamped 1-10; server does ~2 fps on a 3060).
  const fps = Math.max(1, Math.min(10, Number(fpsInput.value) || 2));
  const interval = Math.round(1000 / fps);
  timer = setInterval(() => {
    if (!ws || ws.readyState !== WebSocket.OPEN) return;
    if (video.videoWidth === 0) return; // camera not ready yet
    ctx.drawImage(video, 0, 0, canvas.width, canvas.height);
    canvas.toBlob(blob => {
      if (!blob) return;
      blob.arrayBuffer().then(buf => {
        if (ws && ws.readyState === WebSocket.OPEN) {
          ws.send(buf);
          sent += 1;
          sentEl.textContent = sent;
        }
      });
    }, "image/jpeg", 0.7);
  }, interval);
}

function stopLoop(){
  if (timer) clearInterval(timer);
  timer = null;
}
</script>
</body>
</html>
|
||||
Reference in New Issue
Block a user