"""LingBot-MAP demo: streaming 3D reconstruction from images or video. Usage: # Streaming inference (frame-by-frame with KV cache) python examples/demo.py --model_path /path/to/checkpoint.pt \ --image_folder /path/to/images/ # Streaming inference with keyframe KV caching python examples/demo.py --model_path /path/to/checkpoint.pt \ --image_folder /path/to/images/ --mode streaming --keyframe_interval 6 # Windowed inference (for very long sequences, >500 frames) python examples/demo.py --model_path /path/to/checkpoint.pt \ --video_path video.mp4 --fps 10 --mode windowed --window_size 64 # From video with custom FPS sampling python examples/demo.py --model_path /path/to/checkpoint.pt \ --video_path video.mp4 --fps 10 """ import argparse import glob import os import time # Must be set before `import torch` / any CUDA init. Reduces the reserved-vs-allocated # memory gap by letting the caching allocator grow segments on demand instead of # pre-reserving fixed-size blocks. os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") import cv2 import numpy as np import torch from tqdm.auto import tqdm from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri from lingbot_map.utils.geometry import closed_form_inverse_se3_general from lingbot_map.utils.load_fn import load_and_preprocess_images # ============================================================================= # Image loading # ============================================================================= def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png", first_k=None, stride=1, image_size=518, patch_size=14, num_workers=8): """Load images from folder or video and preprocess into a tensor. Returns: (images, paths, resolved_image_folder): preprocessed tensor, file paths, and the folder containing the source images (for sky mask caching etc.). """ if video_path is not None: video_name = os.path.splitext(os.path.basename(video_path))[0] out_dir = os.path.join(os.path.dirname(video_path), f"{video_name}_frames") os.makedirs(out_dir, exist_ok=True) cap = cv2.VideoCapture(video_path) src_fps = cap.get(cv2.CAP_PROP_FPS) or 30 total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) interval = max(1, round(src_fps / fps)) idx, saved = 0, [] pbar = tqdm(total=total_frames, desc="Extracting frames", unit="frame") while True: ret, frame = cap.read() if not ret: break if idx % interval == 0: path = os.path.join(out_dir, f"{len(saved):06d}.jpg") cv2.imwrite(path, frame) saved.append(path) idx += 1 pbar.update(1) pbar.close() cap.release() paths = saved resolved_folder = out_dir print(f"Extracted {len(paths)} frames from video ({total_frames} total, interval={interval})") else: exts = image_ext.split(",") paths = [] for ext in exts: paths.extend(glob.glob(os.path.join(image_folder, f"*{ext}"))) paths = sorted(paths) resolved_folder = image_folder if first_k is not None and first_k > 0: paths = paths[:first_k] if stride > 1: paths = paths[::stride] print(f"Loading {len(paths)} images...") images = load_and_preprocess_images( paths, mode="crop", image_size=image_size, patch_size=patch_size, ) h, w = images.shape[-2:] print(f"Preprocessed images to {w}x{h} using canonical crop mode") return images, paths, resolved_folder # ============================================================================= # Model loading # ============================================================================= def load_model(args, device): """Load GCTStream model from checkpoint.""" if getattr(args, "mode", "streaming") == "windowed": from lingbot_map.models.gct_stream_window import GCTStream else: from lingbot_map.models.gct_stream import GCTStream print("Building model...") model = GCTStream( img_size=args.image_size, patch_size=args.patch_size, enable_3d_rope=args.enable_3d_rope, max_frame_num=args.max_frame_num, kv_cache_sliding_window=args.kv_cache_sliding_window, kv_cache_scale_frames=args.num_scale_frames, kv_cache_cross_frame_special=True, kv_cache_include_scale_frames=True, use_sdpa=args.use_sdpa, camera_num_iterations=args.camera_num_iterations, ) if args.model_path: print(f"Loading checkpoint: {args.model_path}") ckpt = torch.load(args.model_path, map_location=device, weights_only=False) state_dict = ckpt.get("model", ckpt) missing, unexpected = model.load_state_dict(state_dict, strict=False) if missing: print(f" Missing keys: {len(missing)}") if unexpected: print(f" Unexpected keys: {len(unexpected)}") print(" Checkpoint loaded.") return model.to(device).eval() # ============================================================================= # torch.compile (opt-in via --compile) # ============================================================================= def compile_model(model): """Compile hot, fixed-shape modules with mode="reduce-overhead". Mirrors the targets in gct_profile.py:compile_model. Unlike the profile script, `model.point_head` is **kept** — the demo needs world_points for visualization. """ agg = model.aggregator for i, b in enumerate(agg.frame_blocks): agg.frame_blocks[i] = torch.compile(b, mode="reduce-overhead") for i, b in enumerate(agg.patch_embed.blocks): agg.patch_embed.blocks[i] = torch.compile(b, mode="reduce-overhead") for b in agg.global_blocks: if hasattr(b, 'attn_pre'): b.attn_pre = torch.compile(b.attn_pre, mode="reduce-overhead") if hasattr(b, 'ffn_residual'): b.ffn_residual = torch.compile(b.ffn_residual, mode="reduce-overhead") b.attn.proj = torch.compile(b.attn.proj, mode="reduce-overhead") def _warm_streaming(model, images, scale_frames, warm_stream_n, dtype, passes=1): """Drive `clean_kv_cache → Phase 1 → N streaming forwards` `passes` times. Warmup inputs are sliced from the already-preprocessed `images` tensor, so their spatial shape matches what real inference will feed — this is what makes the captured CUDA graphs reusable (reduce-overhead mode keys on shape). """ # images: [S, 3, H, W] on device already; slice and add batch dim. warm_scale = images[:scale_frames].unsqueeze(0).to(dtype) warm_stream = images[scale_frames:scale_frames + warm_stream_n].unsqueeze(0).to(dtype) for _ in range(passes): model.clean_kv_cache() torch.compiler.cudagraph_mark_step_begin() with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype): model.forward( warm_scale, num_frame_for_scale=scale_frames, num_frame_per_block=scale_frames, causal_inference=True, ) for i in range(warm_stream_n): torch.compiler.cudagraph_mark_step_begin() with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype): model.forward( warm_stream[:, i:i + 1], num_frame_for_scale=scale_frames, num_frame_per_block=1, causal_inference=True, ) if torch.cuda.is_available(): torch.cuda.synchronize() # Wipe warmup KV so real inference_streaming starts clean (it also calls # clean_kv_cache internally, but this is defensive + makes intent obvious). model.clean_kv_cache() # ============================================================================= # Post-processing # ============================================================================= _BATCHED_NDIMS = { "pose_enc": 3, "depth": 5, "depth_conf": 4, "world_points": 5, "world_points_conf": 4, "extrinsic": 4, "intrinsic": 4, "chunk_scales": 2, "chunk_transforms": 4, "images": 5, } def _squeeze_single_batch(key, value): """Drop the leading batch dimension for single-sequence demo outputs.""" batched_ndim = _BATCHED_NDIMS.get(key) if batched_ndim is None or not hasattr(value, "ndim"): return value if value.ndim == batched_ndim and value.shape[0] == 1: return value[0] return value def postprocess(predictions, images): """Convert pose encoding to extrinsics (c2w) and move to CPU.""" extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) # Convert w2c to c2w extrinsic_4x4 = torch.zeros((*extrinsic.shape[:-2], 4, 4), device=extrinsic.device, dtype=extrinsic.dtype) extrinsic_4x4[..., :3, :4] = extrinsic extrinsic_4x4[..., 3, 3] = 1.0 extrinsic_4x4 = closed_form_inverse_se3_general(extrinsic_4x4) extrinsic = extrinsic_4x4[..., :3, :4] predictions["extrinsic"] = extrinsic predictions["intrinsic"] = intrinsic predictions.pop("pose_enc_list", None) predictions.pop("images", None) print("Moving results to CPU...") for k in list(predictions.keys()): if isinstance(predictions[k], torch.Tensor): predictions[k] = _squeeze_single_batch( k, predictions[k].to("cpu", non_blocking=True) ) images_cpu = images.to("cpu", non_blocking=True) if torch.cuda.is_available(): torch.cuda.synchronize() return predictions, images_cpu def _save_ply(predictions: dict, out_path: str, conf_threshold: float = 1.5) -> None: """Save world_points above conf_threshold as PLY file.""" import open3d as o3d pts = predictions.get("world_points") conf = predictions.get("world_points_conf") if pts is None: print("WARNING: no world_points in predictions — PLY not saved") return if hasattr(pts, "numpy"): pts = pts.numpy() pts_flat = pts.reshape(-1, 3) if conf is not None: if hasattr(conf, "numpy"): conf = conf.numpy() mask = conf.reshape(-1) > conf_threshold pts_flat = pts_flat[mask] os.makedirs(os.path.dirname(os.path.abspath(out_path)) or ".", exist_ok=True) pcd = o3d.geometry.PointCloud() pcd.points = o3d.utility.Vector3dVector(pts_flat.astype(np.float64)) o3d.io.write_point_cloud(out_path, pcd) print(f"PLY saved: {out_path} ({len(pcd.points):,} pts)") def _save_poses(predictions: dict, n_frames: int, source_fps: float, out_path: str, save_fps: float = 2.0) -> None: """Save camera extrinsics subsampled to save_fps as NPZ. Output NPZ keys: poses (M, 3, 4) float32 c2w extrinsic matrices timestamps_ns (M,) int64 relative timestamps (ns from frame 0) frame_ids (M,) int64 original frame indices """ ext = predictions.get("extrinsic") if ext is None: print("WARNING: no extrinsic in predictions — poses not saved") return if hasattr(ext, "numpy"): ext = ext.numpy() N = ext.shape[0] step = max(1, round(source_fps / save_fps)) if save_fps > 0 else 1 idxs = np.arange(0, N, step, dtype=np.int64) poses_sub = ext[idxs].astype(np.float32) t_ns = (idxs * (1e9 / source_fps)).astype(np.int64) os.makedirs(os.path.dirname(os.path.abspath(out_path)) or ".", exist_ok=True) np.savez(out_path, poses=poses_sub, timestamps_ns=t_ns, frame_ids=idxs) print(f"Poses saved: {out_path} ({len(idxs)} poses @ {save_fps} fps)") def prepare_for_visualization(predictions, images=None): """Convert predictions to the unbatched NumPy format used by vis code.""" vis_predictions = {} for k, v in predictions.items(): if isinstance(v, torch.Tensor): v = _squeeze_single_batch(k, v.detach().cpu()) vis_predictions[k] = v.numpy() elif isinstance(v, np.ndarray): vis_predictions[k] = _squeeze_single_batch(k, v) else: vis_predictions[k] = v if images is None: images = predictions.get("images") if isinstance(images, torch.Tensor): images = images.detach().cpu() if isinstance(images, np.ndarray): images = _squeeze_single_batch("images", images) elif isinstance(images, torch.Tensor): images = _squeeze_single_batch("images", images).numpy() if isinstance(images, torch.Tensor): images = images.numpy() if images is not None: vis_predictions["images"] = images return vis_predictions # ============================================================================= # Main # ============================================================================= def main(): parser = argparse.ArgumentParser(description="LingBot-MAP: Streaming 3D Reconstruction Demo") # Input parser.add_argument("--image_folder", type=str, default=None) parser.add_argument("--video_path", type=str, default=None) parser.add_argument("--fps", type=int, default=10) parser.add_argument("--first_k", type=int, default=None) parser.add_argument("--stride", type=int, default=1) # Model parser.add_argument("--model_path", type=str, required=True) parser.add_argument("--image_size", type=int, default=518) parser.add_argument("--patch_size", type=int, default=14) # Inference mode parser.add_argument("--mode", type=str, default="streaming", choices=["streaming", "windowed"], help="streaming: frame-by-frame with KV cache; windowed: overlapping windows for long sequences") # Streaming options parser.add_argument("--enable_3d_rope", action="store_true", default=True) parser.add_argument("--max_frame_num", type=int, default=1024) parser.add_argument("--num_scale_frames", type=int, default=8) parser.add_argument( "--keyframe_interval", type=int, default=None, help="Streaming only. Every N-th frame after scale frames is kept as a keyframe. 1 = every frame. " "If unset, auto-selected: 1 when num_frames <= 320, else ceil(num_frames / 320).", ) parser.add_argument("--kv_cache_sliding_window", type=int, default=64) parser.add_argument("--camera_num_iterations", type=int, default=4, help="Camera head iterative-refinement steps. Default 4; set 1 for faster inference " "(skips 3 refinement passes at a small accuracy cost).") parser.add_argument("--use_sdpa", action="store_true", default=False, help="Use SDPA backend (no flashinfer needed). Default: FlashInfer") parser.add_argument("--compile", action="store_true", default=False, help="torch.compile hot modules (reduce-overhead) with a CUDA-graph warmup. " "Streaming mode only; ~5 FPS faster at 518x378. Adds ~30-60 s warmup time.") parser.add_argument( "--offload_to_cpu", action=argparse.BooleanOptionalAction, help="Offload per-frame predictions to CPU during inference to cut GPU peak memory. " "Use --no-offload_to_cpu to keep outputs on GPU.", ) # Windowed options parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)") parser.add_argument("--overlap_size", type=int, default=16, help="Overlap between windows") # Visualization parser.add_argument("--port", type=int, default=8080) parser.add_argument("--conf_threshold", type=float, default=1.5) parser.add_argument("--downsample_factor", type=int, default=10) parser.add_argument("--point_size", type=float, default=0.00001) parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points") parser.add_argument("--sky_mask_dir", type=str, default=None, help="Directory for cached sky masks (default: _sky_masks/)") parser.add_argument("--sky_mask_visualization_dir", type=str, default=None, help="Save sky mask visualizations (original | mask | overlay) to this directory") parser.add_argument("--export_preprocessed", type=str, default=None, help="Export stride-sampled, resized/cropped images to this folder") # Output export parser.add_argument("--save_ply", type=str, default=None, help="Save point cloud to this PLY file path") parser.add_argument("--save_poses", type=str, default=None, help="Save camera extrinsics to this NPZ file path") parser.add_argument("--save_poses_fps", type=float, default=2.0, help="Subsampling FPS for saved poses (default 2.0)") args = parser.parse_args() assert args.image_folder or args.video_path, \ "Provide --image_folder or --video_path" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ── Load images & model ────────────────────────────────────────────────── t0 = time.time() images, paths, resolved_image_folder = load_images( image_folder=args.image_folder, video_path=args.video_path, fps=args.fps, first_k=args.first_k, stride=args.stride, image_size=args.image_size, patch_size=args.patch_size, ) # Export preprocessed images if requested if args.export_preprocessed: os.makedirs(args.export_preprocessed, exist_ok=True) print(f"Exporting {images.shape[0]} preprocessed images to {args.export_preprocessed}...") for i in range(images.shape[0]): img = (images[i].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8) cv2.imwrite( os.path.join(args.export_preprocessed, f"{i:06d}.png"), cv2.cvtColor(img, cv2.COLOR_RGB2BGR), ) print(f"Exported to {args.export_preprocessed}") model = load_model(args, device) print(f"Total load time: {time.time() - t0:.1f}s") # Pick inference dtype; autocast still runs for the ops that need fp32 (e.g. LayerNorm). if torch.cuda.is_available(): dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 else: dtype = torch.float32 # Cast the aggregator (DINOv2-style trunk) to the inference dtype to remove the # redundant fp32 master weight copy + autocast bf16 weight cache (~2-3 GB saved, # no measurable quality change). gct_base._predict_* upcasts inputs to fp32 and # runs each head under `autocast(enabled=False)`, so camera/depth/point heads # keep fp32 weights automatically. if dtype != torch.float32 and getattr(model, "aggregator", None) is not None: print(f"Casting aggregator to {dtype} (heads kept in fp32)") model.aggregator = model.aggregator.to(dtype=dtype) images = images.to(device) num_frames = images.shape[0] print(f"Input: {num_frames} frames, shape {tuple(images.shape)}") print(f"Mode: {args.mode}") if torch.cuda.is_available(): torch.cuda.empty_cache() print( f"GPU mem after load: " f"alloc={torch.cuda.memory_allocated()/1e9:.2f} GB, " f"reserved={torch.cuda.memory_reserved()/1e9:.2f} GB" ) if args.keyframe_interval is None: if args.mode == "streaming" and num_frames > 320: args.keyframe_interval = (num_frames + 319) // 320 print( f"Auto-selected --keyframe_interval={args.keyframe_interval} " f"(num_frames={num_frames} > 320)." ) else: args.keyframe_interval = 1 if args.mode != "streaming" and args.keyframe_interval != 1: print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.") args.keyframe_interval = 1 elif args.mode == "streaming" and args.keyframe_interval > 1: print( f"Keyframe streaming enabled: interval={args.keyframe_interval} " f"(after the first {args.num_scale_frames} scale frames)." ) # ── Optional: torch.compile + CUDA-graph warmup (streaming only) ──────── if args.compile: if args.mode != "streaming": print( f"--compile only applies to --mode streaming (got {args.mode!r}); " "skipping compile." ) else: scale_for_warm = min(args.num_scale_frames, num_frames) warm_stream_n = min(10, max(1, num_frames - scale_for_warm)) print(f"Warmup eager (scale + {warm_stream_n} streaming)...") t_warm = time.time() _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=1) print(f" eager warmup: {time.time() - t_warm:.1f}s") print("Compiling hot modules...") compile_model(model) # 3 passes under compile: 1st captures CUDA graphs, 2nd/3rd replay so # the caching allocator / graph-address map converge on the state the # real inference will see. See gct_profile.py:302-306 for rationale. print("Warmup compiled (3x dress rehearsal)...") t_warm = time.time() _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=3) print(f" compiled warmup: {time.time() - t_warm:.1f}s") # ── Inference ──────────────────────────────────────────────────────────── print(f"Running {args.mode} inference (dtype={dtype})...") t0 = time.time() output_device = torch.device("cpu") if args.offload_to_cpu else None with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype): if args.mode == "streaming": predictions = model.inference_streaming( images, num_scale_frames=args.num_scale_frames, keyframe_interval=args.keyframe_interval, output_device=output_device, ) else: # windowed predictions = model.inference_windowed( images, window_size=args.window_size, overlap_size=args.overlap_size, num_scale_frames=args.num_scale_frames, output_device=output_device, ) print(f"Inference done in {time.time() - t0:.1f}s") if torch.cuda.is_available(): print( f"GPU peak during inference: " f"{torch.cuda.max_memory_allocated()/1e9:.2f} GB " f"(reserved peak {torch.cuda.max_memory_reserved()/1e9:.2f} GB)" ) # ── Post-process ───────────────────────────────────────────────────────── if args.offload_to_cpu: del images if torch.cuda.is_available(): torch.cuda.empty_cache() images_for_post = predictions["images"] # already CPU else: images_for_post = images predictions, images_cpu = postprocess(predictions, images_for_post) # ── Export ─────────────────────────────────────────────────────────────── source_fps = args.fps if args.video_path else 1.0 if args.save_ply: _save_ply(predictions, args.save_ply, args.conf_threshold) if args.save_poses: _save_poses(predictions, len(paths), source_fps, args.save_poses, args.save_poses_fps) # ── Visualize ──────────────────────────────────────────────────────────── try: from lingbot_map.vis import PointCloudViewer viewer = PointCloudViewer( pred_dict=prepare_for_visualization(predictions, images_cpu), port=args.port, vis_threshold=args.conf_threshold, downsample_factor=args.downsample_factor, point_size=args.point_size, mask_sky=args.mask_sky, image_folder=resolved_image_folder, sky_mask_dir=args.sky_mask_dir, sky_mask_visualization_dir=args.sky_mask_visualization_dir, ) print(f"3D viewer at http://localhost:{args.port}") viewer.run() except ImportError: print("viser not installed. Install with: pip install lingbot-map[vis]") print(f"Predictions contain keys: {list(predictions.keys())}") if __name__ == "__main__": main()