"""LingBot-MAP demo: streaming 3D reconstruction from images or video.
Usage:
# Streaming inference (frame-by-frame with KV cache)
python examples/demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/
# Streaming inference with keyframe KV caching
python examples/demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/ --mode streaming --keyframe_interval 6
# Windowed inference (for very long sequences, >500 frames)
python examples/demo.py --model_path /path/to/checkpoint.pt \
--video_path video.mp4 --fps 10 --mode windowed --window_size 64
# From video with custom FPS sampling
python examples/demo.py --model_path /path/to/checkpoint.pt \
--video_path video.mp4 --fps 10
"""
import argparse
import glob
import os
import time
# Must be set before `import torch` / any CUDA init. Reduces the reserved-vs-allocated
# memory gap by letting the caching allocator grow segments on demand instead of
# pre-reserving fixed-size blocks.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import cv2
import numpy as np
import torch
from tqdm.auto import tqdm
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
from lingbot_map.utils.geometry import closed_form_inverse_se3_general
from lingbot_map.utils.load_fn import load_and_preprocess_images

# =============================================================================
# Image loading
# =============================================================================

def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png",
                first_k=None, stride=1, image_size=518, patch_size=14, num_workers=8):
    """Load images from folder or video and preprocess into a tensor.

    Returns:
        (images, paths, resolved_image_folder): preprocessed tensor, file paths,
        and the folder containing the source images (for sky mask caching etc.).
    """
    if video_path is not None:
        video_name = os.path.splitext(os.path.basename(video_path))[0]
        out_dir = os.path.join(os.path.dirname(video_path), f"{video_name}_frames")
        os.makedirs(out_dir, exist_ok=True)
        cap = cv2.VideoCapture(video_path)
        src_fps = cap.get(cv2.CAP_PROP_FPS) or 30
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        interval = max(1, round(src_fps / fps))
        idx, saved = 0, []
        pbar = tqdm(total=total_frames, desc="Extracting frames", unit="frame")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % interval == 0:
                path = os.path.join(out_dir, f"{len(saved):06d}.jpg")
                cv2.imwrite(path, frame)
                saved.append(path)
            idx += 1
            pbar.update(1)
        pbar.close()
        cap.release()
        paths = saved
        resolved_folder = out_dir
        print(f"Extracted {len(paths)} frames from video ({total_frames} total, interval={interval})")
    else:
        exts = image_ext.split(",")
        paths = []
        for ext in exts:
            paths.extend(glob.glob(os.path.join(image_folder, f"*{ext}")))
        paths = sorted(paths)
        resolved_folder = image_folder

    if first_k is not None and first_k > 0:
        paths = paths[:first_k]
    if stride > 1:
        paths = paths[::stride]

    print(f"Loading {len(paths)} images...")
    images = load_and_preprocess_images(
        paths,
        mode="crop",
        image_size=image_size,
        patch_size=patch_size,
    )
    h, w = images.shape[-2:]
    print(f"Preprocessed images to {w}x{h} using canonical crop mode")
    return images, paths, resolved_folder
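
# Worked example (illustrative): sampling a 30 FPS video with --fps 10 keeps
# every 3rd frame, since interval = max(1, round(30 / 10)) = 3. The returned
# `images` tensor is [S, 3, H, W], the same layout _warm_streaming slices from
# below.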

# =============================================================================
# Model loading
# =============================================================================

def load_model(args, device):
    """Load GCTStream model from checkpoint."""
    if getattr(args, "mode", "streaming") == "windowed":
        from lingbot_map.models.gct_stream_window import GCTStream
    else:
        from lingbot_map.models.gct_stream import GCTStream

    print("Building model...")
    model = GCTStream(
        img_size=args.image_size,
        patch_size=args.patch_size,
        enable_3d_rope=args.enable_3d_rope,
        max_frame_num=args.max_frame_num,
        kv_cache_sliding_window=args.kv_cache_sliding_window,
        kv_cache_scale_frames=args.num_scale_frames,
        kv_cache_cross_frame_special=True,
        kv_cache_include_scale_frames=True,
        use_sdpa=args.use_sdpa,
        camera_num_iterations=args.camera_num_iterations,
    )
    if args.model_path:
        print(f"Loading checkpoint: {args.model_path}")
        ckpt = torch.load(args.model_path, map_location=device, weights_only=False)
        state_dict = ckpt.get("model", ckpt)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print(f"  Missing keys: {len(missing)}")
        if unexpected:
            print(f"  Unexpected keys: {len(unexpected)}")
        print("  Checkpoint loaded.")
    return model.to(device).eval()

# =============================================================================
# torch.compile (opt-in via --compile)
# =============================================================================

def compile_model(model):
    """Compile hot, fixed-shape modules with mode="reduce-overhead".

    Mirrors the targets in gct_profile.py:compile_model. Unlike the profile script,
    `model.point_head` is **kept** — the demo needs world_points for visualization.
    """
    agg = model.aggregator
    for i, b in enumerate(agg.frame_blocks):
        agg.frame_blocks[i] = torch.compile(b, mode="reduce-overhead")
    for i, b in enumerate(agg.patch_embed.blocks):
        agg.patch_embed.blocks[i] = torch.compile(b, mode="reduce-overhead")
    for b in agg.global_blocks:
        if hasattr(b, 'attn_pre'):
            b.attn_pre = torch.compile(b.attn_pre, mode="reduce-overhead")
        if hasattr(b, 'ffn_residual'):
            b.ffn_residual = torch.compile(b.ffn_residual, mode="reduce-overhead")
        b.attn.proj = torch.compile(b.attn.proj, mode="reduce-overhead")


def _warm_streaming(model, images, scale_frames, warm_stream_n, dtype, passes=1):
    """Drive `clean_kv_cache → Phase 1 → N streaming forwards` `passes` times.

    Warmup inputs are sliced from the already-preprocessed `images` tensor, so their
    spatial shape matches what real inference will feed — this is what makes the
    captured CUDA graphs reusable (reduce-overhead mode keys on shape).
    """
    # images: [S, 3, H, W] on device already; slice and add batch dim.
    warm_scale = images[:scale_frames].unsqueeze(0).to(dtype)
    warm_stream = images[scale_frames:scale_frames + warm_stream_n].unsqueeze(0).to(dtype)
    for _ in range(passes):
        model.clean_kv_cache()
        torch.compiler.cudagraph_mark_step_begin()
        with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
            model.forward(
                warm_scale,
                num_frame_for_scale=scale_frames,
                num_frame_per_block=scale_frames,
                causal_inference=True,
            )
        for i in range(warm_stream_n):
            torch.compiler.cudagraph_mark_step_begin()
            with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
                model.forward(
                    warm_stream[:, i:i + 1],
                    num_frame_for_scale=scale_frames,
                    num_frame_per_block=1,
                    causal_inference=True,
                )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Wipe warmup KV so real inference_streaming starts clean (it also calls
    # clean_kv_cache internally, but this is defensive + makes intent obvious).
    model.clean_kv_cache()

# =============================================================================
# Post-processing
# =============================================================================

_BATCHED_NDIMS = {
    "pose_enc": 3,
    "depth": 5,
    "depth_conf": 4,
    "world_points": 5,
    "world_points_conf": 4,
    "extrinsic": 4,
    "intrinsic": 4,
    "chunk_scales": 2,
    "chunk_transforms": 4,
    "images": 5,
}


def _squeeze_single_batch(key, value):
    """Drop the leading batch dimension for single-sequence demo outputs."""
    batched_ndim = _BATCHED_NDIMS.get(key)
    if batched_ndim is None or not hasattr(value, "ndim"):
        return value
    if value.ndim == batched_ndim and value.shape[0] == 1:
        return value[0]
    return value
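
# Illustrative example of the squeeze (shapes follow _BATCHED_NDIMS: "depth" is
# batched at ndim 5, i.e. [1, S, H, W, 1] for a single sequence):
#
#   >>> depth = torch.zeros(1, 12, 518, 378, 1)
#   >>> _squeeze_single_batch("depth", depth).shape
#   torch.Size([12, 518, 378, 1])
#
# Anything with an unknown key or a non-matching ndim is passed through as-is.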


def postprocess(predictions, images):
    """Convert pose encoding to extrinsics (c2w) and move to CPU."""
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])

    # Convert w2c to c2w
    extrinsic_4x4 = torch.zeros((*extrinsic.shape[:-2], 4, 4), device=extrinsic.device, dtype=extrinsic.dtype)
    extrinsic_4x4[..., :3, :4] = extrinsic
    extrinsic_4x4[..., 3, 3] = 1.0
    extrinsic_4x4 = closed_form_inverse_se3_general(extrinsic_4x4)
    extrinsic = extrinsic_4x4[..., :3, :4]

    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic
    predictions.pop("pose_enc_list", None)
    predictions.pop("images", None)

    print("Moving results to CPU...")
    for k in list(predictions.keys()):
        if isinstance(predictions[k], torch.Tensor):
            predictions[k] = _squeeze_single_batch(
                k, predictions[k].to("cpu", non_blocking=True)
            )
    images_cpu = images.to("cpu", non_blocking=True)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return predictions, images_cpu
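
# Note on the w2c → c2w step above: for a rigid transform with rotation R and
# translation t, the inverse has the closed form
#     [R | t]^-1 = [R^T | -R^T t]
# (assuming closed_form_inverse_se3_general implements this standard identity,
# which avoids a general 4x4 matrix inverse). A minimal torch sketch for a
# single 4x4 matrix E:
#
#   R, t = E[:3, :3], E[:3, 3]
#   E_inv = torch.eye(4)
#   E_inv[:3, :3] = R.T
#   E_inv[:3, 3] = -R.T @ t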


def _save_ply(predictions: dict, out_path: str, conf_threshold: float = 1.5) -> None:
    """Save world_points above conf_threshold as a PLY file."""
    import open3d as o3d

    pts = predictions.get("world_points")
    conf = predictions.get("world_points_conf")
    if pts is None:
        print("WARNING: no world_points in predictions — PLY not saved")
        return
    if hasattr(pts, "numpy"):
        pts = pts.numpy()
    pts_flat = pts.reshape(-1, 3)
    if conf is not None:
        if hasattr(conf, "numpy"):
            conf = conf.numpy()
        mask = conf.reshape(-1) > conf_threshold
        pts_flat = pts_flat[mask]
    os.makedirs(os.path.dirname(os.path.abspath(out_path)) or ".", exist_ok=True)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts_flat.astype(np.float64))
    o3d.io.write_point_cloud(out_path, pcd)
    print(f"PLY saved: {out_path} ({len(pcd.points):,} pts)")


def _save_poses(predictions: dict, n_frames: int, source_fps: float,
                out_path: str, save_fps: float = 2.0) -> None:
    """Save camera extrinsics subsampled to save_fps as an NPZ file.

    Output NPZ keys:
        poses          (M, 3, 4) float32  c2w extrinsic matrices
        timestamps_ns  (M,)      int64    relative timestamps (ns from frame 0)
        frame_ids      (M,)      int64    original frame indices
    """
    ext = predictions.get("extrinsic")
    if ext is None:
        print("WARNING: no extrinsic in predictions — poses not saved")
        return
    if hasattr(ext, "numpy"):
        ext = ext.numpy()
    N = ext.shape[0]
    step = max(1, round(source_fps / save_fps)) if save_fps > 0 else 1
    idxs = np.arange(0, N, step, dtype=np.int64)
    poses_sub = ext[idxs].astype(np.float32)
    t_ns = (idxs * (1e9 / source_fps)).astype(np.int64)
    os.makedirs(os.path.dirname(os.path.abspath(out_path)) or ".", exist_ok=True)
    np.savez(out_path, poses=poses_sub, timestamps_ns=t_ns, frame_ids=idxs)
    print(f"Poses saved: {out_path} ({len(idxs)} poses @ {save_fps} fps)")


def prepare_for_visualization(predictions, images=None):
    """Convert predictions to the unbatched NumPy format used by vis code."""
    vis_predictions = {}
    for k, v in predictions.items():
        if isinstance(v, torch.Tensor):
            v = _squeeze_single_batch(k, v.detach().cpu())
            vis_predictions[k] = v.numpy()
        elif isinstance(v, np.ndarray):
            vis_predictions[k] = _squeeze_single_batch(k, v)
        else:
            vis_predictions[k] = v

    if images is None:
        images = predictions.get("images")
    if isinstance(images, torch.Tensor):
        images = _squeeze_single_batch("images", images.detach().cpu()).numpy()
    elif isinstance(images, np.ndarray):
        images = _squeeze_single_batch("images", images)
    if images is not None:
        vis_predictions["images"] = images
    return vis_predictions

# =============================================================================
# Main
# =============================================================================

def main():
    parser = argparse.ArgumentParser(description="LingBot-MAP: Streaming 3D Reconstruction Demo")

    # Input
    parser.add_argument("--image_folder", type=str, default=None)
    parser.add_argument("--video_path", type=str, default=None)
    parser.add_argument("--fps", type=int, default=10)
    parser.add_argument("--first_k", type=int, default=None)
    parser.add_argument("--stride", type=int, default=1)

    # Model
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--image_size", type=int, default=518)
    parser.add_argument("--patch_size", type=int, default=14)

    # Inference mode
    parser.add_argument("--mode", type=str, default="streaming", choices=["streaming", "windowed"],
                        help="streaming: frame-by-frame with KV cache; windowed: overlapping windows for long sequences")

    # Streaming options
    parser.add_argument("--enable_3d_rope", action="store_true", default=True)
    parser.add_argument("--max_frame_num", type=int, default=1024)
    parser.add_argument("--num_scale_frames", type=int, default=8)
    parser.add_argument(
        "--keyframe_interval",
        type=int,
        default=None,
        help="Streaming only. Every N-th frame after the scale frames is kept as a keyframe. 1 = every frame. "
             "If unset, auto-selected: 1 when num_frames <= 320, else ceil(num_frames / 320).",
    )
    parser.add_argument("--kv_cache_sliding_window", type=int, default=64)
    parser.add_argument("--camera_num_iterations", type=int, default=4,
                        help="Camera head iterative-refinement steps. Default 4; set 1 for faster inference "
                             "(skips 3 refinement passes at a small accuracy cost).")
    parser.add_argument("--use_sdpa", action="store_true", default=False,
                        help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
    parser.add_argument("--compile", action="store_true", default=False,
                        help="torch.compile hot modules (reduce-overhead) with a CUDA-graph warmup. "
                             "Streaming mode only; ~5 FPS faster at 518x378. Adds ~30-60 s warmup time.")
    parser.add_argument(
        "--offload_to_cpu",
        action=argparse.BooleanOptionalAction,
        help="Offload per-frame predictions to CPU during inference to cut GPU peak memory. "
             "Use --no-offload_to_cpu to keep outputs on GPU.",
    )

    # Windowed options
    parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)")
    parser.add_argument("--overlap_size", type=int, default=16, help="Overlap between windows")

    # Visualization
    parser.add_argument("--port", type=int, default=8080)
    parser.add_argument("--conf_threshold", type=float, default=1.5)
    parser.add_argument("--downsample_factor", type=int, default=10)
    parser.add_argument("--point_size", type=float, default=0.00001)
    parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")
    parser.add_argument("--sky_mask_dir", type=str, default=None,
                        help="Directory for cached sky masks (default: <image_folder>_sky_masks/)")
    parser.add_argument("--sky_mask_visualization_dir", type=str, default=None,
                        help="Save sky mask visualizations (original | mask | overlay) to this directory")
    parser.add_argument("--export_preprocessed", type=str, default=None,
                        help="Export stride-sampled, resized/cropped images to this folder")

    # Output export
    parser.add_argument("--save_ply", type=str, default=None,
                        help="Save point cloud to this PLY file path")
    parser.add_argument("--save_poses", type=str, default=None,
                        help="Save camera extrinsics to this NPZ file path")
    parser.add_argument("--save_poses_fps", type=float, default=2.0,
                        help="Subsampling FPS for saved poses (default 2.0)")
    args = parser.parse_args()
    assert args.image_folder or args.video_path, \
        "Provide --image_folder or --video_path"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ── Load images & model ──────────────────────────────────────────────────
    t0 = time.time()
    images, paths, resolved_image_folder = load_images(
        image_folder=args.image_folder, video_path=args.video_path,
        fps=args.fps, first_k=args.first_k, stride=args.stride,
        image_size=args.image_size, patch_size=args.patch_size,
    )

    # Export preprocessed images if requested
    if args.export_preprocessed:
        os.makedirs(args.export_preprocessed, exist_ok=True)
        print(f"Exporting {images.shape[0]} preprocessed images to {args.export_preprocessed}...")
        for i in range(images.shape[0]):
            img = (images[i].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
            cv2.imwrite(
                os.path.join(args.export_preprocessed, f"{i:06d}.png"),
                cv2.cvtColor(img, cv2.COLOR_RGB2BGR),
            )
        print(f"Exported to {args.export_preprocessed}")

    model = load_model(args, device)
    print(f"Total load time: {time.time() - t0:.1f}s")

    # Pick inference dtype; autocast still runs the ops that need fp32 (e.g. LayerNorm) in fp32.
    if torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    else:
        dtype = torch.float32

    # Cast the aggregator (DINOv2-style trunk) to the inference dtype to remove the
    # redundant fp32 master weight copy + autocast bf16 weight cache (~2-3 GB saved,
    # no measurable quality change). gct_base._predict_* upcasts inputs to fp32 and
    # runs each head under `autocast(enabled=False)`, so the camera/depth/point heads
    # keep fp32 weights automatically.
    if dtype != torch.float32 and getattr(model, "aggregator", None) is not None:
        print(f"Casting aggregator to {dtype} (heads kept in fp32)")
        model.aggregator = model.aggregator.to(dtype=dtype)

    images = images.to(device)
    num_frames = images.shape[0]
    print(f"Input: {num_frames} frames, shape {tuple(images.shape)}")
    print(f"Mode: {args.mode}")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(
            f"GPU mem after load: "
            f"alloc={torch.cuda.memory_allocated()/1e9:.2f} GB, "
            f"reserved={torch.cuda.memory_reserved()/1e9:.2f} GB"
        )

    if args.keyframe_interval is None:
        if args.mode == "streaming" and num_frames > 320:
            args.keyframe_interval = (num_frames + 319) // 320
            print(
                f"Auto-selected --keyframe_interval={args.keyframe_interval} "
                f"(num_frames={num_frames} > 320)."
            )
        else:
            args.keyframe_interval = 1
    if args.mode != "streaming" and args.keyframe_interval != 1:
        print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.")
        args.keyframe_interval = 1
    elif args.mode == "streaming" and args.keyframe_interval > 1:
        print(
            f"Keyframe streaming enabled: interval={args.keyframe_interval} "
            f"(after the first {args.num_scale_frames} scale frames)."
        )
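
    # Worked example (illustrative): a 10-minute clip sampled at --fps 10 gives
    # num_frames = 6000, so auto-selection picks
    #     keyframe_interval = (6000 + 319) // 320 = 19  (== ceil(6000 / 320)),
    # i.e. after the scale frames, roughly every 19th frame is kept as a
    # keyframe in the KV cache.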

    # ── Optional: torch.compile + CUDA-graph warmup (streaming only) ────────
    if args.compile:
        if args.mode != "streaming":
            print(
                f"--compile only applies to --mode streaming (got {args.mode!r}); "
                "skipping compile."
            )
        else:
            scale_for_warm = min(args.num_scale_frames, num_frames)
            warm_stream_n = min(10, max(1, num_frames - scale_for_warm))
            print(f"Warmup eager (scale + {warm_stream_n} streaming)...")
            t_warm = time.time()
            _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=1)
            print(f"  eager warmup: {time.time() - t_warm:.1f}s")
            print("Compiling hot modules...")
            compile_model(model)
            # 3 passes under compile: the 1st captures CUDA graphs, the 2nd/3rd
            # replay so the caching allocator / graph-address map converge on the
            # state the real inference will see. See gct_profile.py:302-306 for
            # the rationale.
            print("Warmup compiled (3x dress rehearsal)...")
            t_warm = time.time()
            _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=3)
            print(f"  compiled warmup: {time.time() - t_warm:.1f}s")

    # ── Inference ────────────────────────────────────────────────────────────
    print(f"Running {args.mode} inference (dtype={dtype})...")
    t0 = time.time()
    output_device = torch.device("cpu") if args.offload_to_cpu else None
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
        if args.mode == "streaming":
            predictions = model.inference_streaming(
                images,
                num_scale_frames=args.num_scale_frames,
                keyframe_interval=args.keyframe_interval,
                output_device=output_device,
            )
        else:  # windowed
            predictions = model.inference_windowed(
                images,
                window_size=args.window_size,
                overlap_size=args.overlap_size,
                num_scale_frames=args.num_scale_frames,
                output_device=output_device,
            )
    print(f"Inference done in {time.time() - t0:.1f}s")
    if torch.cuda.is_available():
        print(
            f"GPU peak during inference: "
            f"{torch.cuda.max_memory_allocated()/1e9:.2f} GB "
            f"(reserved peak {torch.cuda.max_memory_reserved()/1e9:.2f} GB)"
        )

    # ── Post-process ─────────────────────────────────────────────────────────
    if args.offload_to_cpu:
        del images
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        images_for_post = predictions["images"]  # already on CPU
    else:
        images_for_post = images
    predictions, images_cpu = postprocess(predictions, images_for_post)

    # ── Export ───────────────────────────────────────────────────────────────
    source_fps = args.fps if args.video_path else 1.0
    if args.save_ply:
        _save_ply(predictions, args.save_ply, args.conf_threshold)
    if args.save_poses:
        _save_poses(predictions, len(paths), source_fps, args.save_poses, args.save_poses_fps)

    # ── Visualize ────────────────────────────────────────────────────────────
    try:
        from lingbot_map.vis import PointCloudViewer
        viewer = PointCloudViewer(
            pred_dict=prepare_for_visualization(predictions, images_cpu),
            port=args.port,
            vis_threshold=args.conf_threshold,
            downsample_factor=args.downsample_factor,
            point_size=args.point_size,
            mask_sky=args.mask_sky,
            image_folder=resolved_image_folder,
            sky_mask_dir=args.sky_mask_dir,
            sky_mask_visualization_dir=args.sky_mask_visualization_dir,
        )
        print(f"3D viewer at http://localhost:{args.port}")
        viewer.run()
    except ImportError:
        print("viser not installed. Install with: pip install lingbot-map[vis]")
        print(f"Predictions contain keys: {list(predictions.keys())}")


if __name__ == "__main__":
    main()