"""LingBot-MAP demo: streaming 3D reconstruction from images or video.

Usage:
    # Streaming inference (frame-by-frame with KV cache)
    python examples/demo.py --model_path /path/to/checkpoint.pt \
        --image_folder /path/to/images/

    # Streaming inference with keyframe KV caching
    python examples/demo.py --model_path /path/to/checkpoint.pt \
        --image_folder /path/to/images/ --mode streaming --keyframe_interval 6

    # Windowed inference (for very long sequences, >500 frames)
    python examples/demo.py --model_path /path/to/checkpoint.pt \
        --video_path video.mp4 --fps 10 --mode windowed --window_size 64

    # From video with custom FPS sampling
    python examples/demo.py --model_path /path/to/checkpoint.pt \
        --video_path video.mp4 --fps 10
"""
import argparse
import glob
import os
import time

# Must be set before `import torch` / any CUDA init. Reduces the reserved-vs-allocated
# memory gap by letting the caching allocator grow segments on demand instead of
# pre-reserving fixed-size blocks.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import cv2
import numpy as np
import torch
from tqdm.auto import tqdm

from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
from lingbot_map.utils.geometry import closed_form_inverse_se3_general
from lingbot_map.utils.load_fn import load_and_preprocess_images


# =============================================================================
# Image loading
# =============================================================================

def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png",
                first_k=None, stride=1, image_size=518, patch_size=14, num_workers=8):
    """Load images from a folder or a video and preprocess them into a tensor.

    Args:
        image_folder: Directory of image files (used when ``video_path`` is None).
        video_path: Optional video file; frames are extracted at ``fps`` and cached
            next to the video in ``<name>_frames/``.
        fps: Target sampling rate when extracting video frames.
        image_ext: Comma-separated list of extensions to glob for in ``image_folder``.
        first_k: If set (> 0), keep only the first K paths (applied before ``stride``).
        stride: Keep every ``stride``-th path.
        image_size: Target size forwarded to ``load_and_preprocess_images``.
        patch_size: Patch size forwarded to ``load_and_preprocess_images``.
        num_workers: Unused; kept for backward compatibility with existing callers.

    Returns:
        (images, paths, resolved_image_folder): preprocessed tensor, file paths,
        and the folder containing the source images (for sky mask caching etc.).

    Raises:
        RuntimeError: if the video file cannot be opened.
        FileNotFoundError: if no input image paths were found after filtering.
    """
    if video_path is not None:
        video_name = os.path.splitext(os.path.basename(video_path))[0]
        out_dir = os.path.join(os.path.dirname(video_path), f"{video_name}_frames")
        os.makedirs(out_dir, exist_ok=True)
        cap = cv2.VideoCapture(video_path)
        # Without this check a bad path silently yields zero frames.
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open video: {video_path}")
        src_fps = cap.get(cv2.CAP_PROP_FPS) or 30  # 0.0 (unknown FPS) falls back to 30
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        interval = max(1, round(src_fps / fps))
        idx, saved = 0, []
        pbar = tqdm(total=total_frames, desc="Extracting frames", unit="frame")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % interval == 0:
                path = os.path.join(out_dir, f"{len(saved):06d}.jpg")
                cv2.imwrite(path, frame)
                saved.append(path)
            idx += 1
            pbar.update(1)
        pbar.close()
        cap.release()
        paths = saved
        resolved_folder = out_dir
        print(f"Extracted {len(paths)} frames from video ({total_frames} total, interval={interval})")
    else:
        exts = image_ext.split(",")
        paths = []
        for ext in exts:
            paths.extend(glob.glob(os.path.join(image_folder, f"*{ext}")))
        paths = sorted(paths)
        resolved_folder = image_folder

    if first_k is not None and first_k > 0:
        paths = paths[:first_k]
    if stride > 1:
        paths = paths[::stride]

    # Fail loudly here rather than letting downstream preprocessing crash on [].
    if not paths:
        raise FileNotFoundError(
            f"No input images found (image_folder={image_folder!r}, video_path={video_path!r})"
        )

    print(f"Loading {len(paths)} images...")
    images = load_and_preprocess_images(
        paths,
        mode="crop",
        image_size=image_size,
        patch_size=patch_size,
    )
    h, w = images.shape[-2:]
    print(f"Preprocessed images to {w}x{h} using canonical crop mode")
    return images, paths, resolved_folder
# =============================================================================
# Model loading
# =============================================================================

def load_model(args, device):
    """Build a GCTStream model, optionally load a checkpoint, move to `device`, eval mode.

    Windowed mode uses a different model implementation module than streaming,
    so the import is deferred and selected at call time.
    """
    windowed = getattr(args, "mode", "streaming") == "windowed"
    if windowed:
        from lingbot_map.models.gct_stream_window import GCTStream
    else:
        from lingbot_map.models.gct_stream import GCTStream

    print("Building model...")
    model = GCTStream(
        img_size=args.image_size,
        patch_size=args.patch_size,
        enable_3d_rope=args.enable_3d_rope,
        max_frame_num=args.max_frame_num,
        kv_cache_sliding_window=args.kv_cache_sliding_window,
        kv_cache_scale_frames=args.num_scale_frames,
        kv_cache_cross_frame_special=True,
        kv_cache_include_scale_frames=True,
        use_sdpa=args.use_sdpa,
        camera_num_iterations=args.camera_num_iterations,
    )

    if args.model_path:
        print(f"Loading checkpoint: {args.model_path}")
        ckpt = torch.load(args.model_path, map_location=device, weights_only=False)
        # Checkpoints may be either {"model": state_dict, ...} or a bare state_dict.
        state_dict = ckpt.get("model", ckpt)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        for label, keys in (("Missing", missing), ("Unexpected", unexpected)):
            if keys:
                print(f" {label} keys: {len(keys)}")
        print(" Checkpoint loaded.")

    model = model.to(device)
    return model.eval()
# =============================================================================
# torch.compile (opt-in via --compile)
# =============================================================================

def compile_model(model):
    """Wrap the hot, fixed-shape submodules in torch.compile (reduce-overhead mode).

    Matches the compile targets in gct_profile.py:compile_model, except that
    `model.point_head` is left uncompiled — the demo needs world_points for
    visualization.
    """
    def _reduce_overhead(module):
        # Single helper so every target gets the identical compile mode.
        return torch.compile(module, mode="reduce-overhead")

    agg = model.aggregator
    for idx in range(len(agg.frame_blocks)):
        agg.frame_blocks[idx] = _reduce_overhead(agg.frame_blocks[idx])
    for idx in range(len(agg.patch_embed.blocks)):
        agg.patch_embed.blocks[idx] = _reduce_overhead(agg.patch_embed.blocks[idx])
    for blk in agg.global_blocks:
        if hasattr(blk, 'attn_pre'):
            blk.attn_pre = _reduce_overhead(blk.attn_pre)
        if hasattr(blk, 'ffn_residual'):
            blk.ffn_residual = _reduce_overhead(blk.ffn_residual)
        blk.attn.proj = _reduce_overhead(blk.attn.proj)
def _warm_streaming(model, images, scale_frames, warm_stream_n, dtype, passes=1):
    """Drive `clean_kv_cache → Phase 1 → N streaming forwards` `passes` times.

    Warmup inputs are sliced from the already-preprocessed `images` tensor, so their
    spatial shape matches what real inference will feed — this is what makes the
    captured CUDA graphs reusable (reduce-overhead mode keys on shape).

    Args:
        model: streaming model exposing `clean_kv_cache()` and `forward(...)`.
        images: preprocessed frames, [S, 3, H, W], already on the target device.
        scale_frames: number of leading frames for the scale phase (Phase 1).
        warm_stream_n: number of single-frame streaming forwards per pass.
        dtype: autocast/compute dtype used for the warmup forwards.
        passes: full warmup passes to run (1 eager, 3 for the compiled rehearsal).
    """
    # images: [S, 3, H, W] on device already; slice and add batch dim.
    warm_scale = images[:scale_frames].unsqueeze(0).to(dtype)
    warm_stream = images[scale_frames:scale_frames + warm_stream_n].unsqueeze(0).to(dtype)

    for _ in range(passes):
        model.clean_kv_cache()
        # Mark an iteration boundary so CUDA-graph outputs from the previous
        # step may be safely overwritten (required with reduce-overhead mode).
        torch.compiler.cudagraph_mark_step_begin()
        with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
            # Phase 1: all scale frames in a single block forward.
            model.forward(
                warm_scale,
                num_frame_for_scale=scale_frames,
                num_frame_per_block=scale_frames,
                causal_inference=True,
            )
        for i in range(warm_stream_n):
            torch.compiler.cudagraph_mark_step_begin()
            with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
                # Streaming phase: one frame per forward, reusing the KV cache.
                model.forward(
                    warm_stream[:, i:i + 1],
                    num_frame_for_scale=scale_frames,
                    num_frame_per_block=1,
                    causal_inference=True,
                )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Wipe warmup KV so real inference_streaming starts clean (it also calls
    # clean_kv_cache internally, but this is defensive + makes intent obvious).
    model.clean_kv_cache()
# =============================================================================
|
|
# Post-processing
|
|
# =============================================================================
|
|
|
|
_BATCHED_NDIMS = {
|
|
"pose_enc": 3,
|
|
"depth": 5,
|
|
"depth_conf": 4,
|
|
"world_points": 5,
|
|
"world_points_conf": 4,
|
|
"extrinsic": 4,
|
|
"intrinsic": 4,
|
|
"chunk_scales": 2,
|
|
"chunk_transforms": 4,
|
|
"images": 5,
|
|
}
|
|
|
|
|
|
def _squeeze_single_batch(key, value):
|
|
"""Drop the leading batch dimension for single-sequence demo outputs."""
|
|
batched_ndim = _BATCHED_NDIMS.get(key)
|
|
if batched_ndim is None or not hasattr(value, "ndim"):
|
|
return value
|
|
if value.ndim == batched_ndim and value.shape[0] == 1:
|
|
return value[0]
|
|
return value
|
|
|
|
|
|
def postprocess(predictions, images):
    """Convert pose encoding to extrinsics (c2w) and move results to CPU.

    Mutates `predictions` in place (adds "extrinsic"/"intrinsic", drops
    "pose_enc_list"/"images", CPU-ifies tensors) and returns it together with
    a CPU copy of `images`.
    """
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])

    # The pose head emits w2c; pad to homogeneous 4x4, invert, take back 3x4 → c2w.
    homo = torch.zeros((*extrinsic.shape[:-2], 4, 4), device=extrinsic.device, dtype=extrinsic.dtype)
    homo[..., :3, :4] = extrinsic
    homo[..., 3, 3] = 1.0
    extrinsic = closed_form_inverse_se3_general(homo)[..., :3, :4]

    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic
    for drop_key in ("pose_enc_list", "images"):
        predictions.pop(drop_key, None)

    print("Moving results to CPU...")
    for key, val in list(predictions.items()):
        if isinstance(val, torch.Tensor):
            predictions[key] = _squeeze_single_batch(key, val.to("cpu", non_blocking=True))
    images_cpu = images.to("cpu", non_blocking=True)
    if torch.cuda.is_available():
        # non_blocking copies must complete before callers read the CPU tensors.
        torch.cuda.synchronize()

    return predictions, images_cpu
def _save_ply(predictions: dict, out_path: str, conf_threshold: float = 1.5) -> None:
    """Save world_points above conf_threshold as PLY file."""
    import open3d as o3d

    pts = predictions.get("world_points")
    if pts is None:
        print("WARNING: no world_points in predictions — PLY not saved")
        return
    conf = predictions.get("world_points_conf")

    # Accept either CPU torch tensors or numpy arrays.
    pts = pts.numpy() if hasattr(pts, "numpy") else pts
    flat = pts.reshape(-1, 3)
    if conf is not None:
        conf = conf.numpy() if hasattr(conf, "numpy") else conf
        flat = flat[conf.reshape(-1) > conf_threshold]

    os.makedirs(os.path.dirname(os.path.abspath(out_path)) or ".", exist_ok=True)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(flat.astype(np.float64))
    o3d.io.write_point_cloud(out_path, pcd)
    print(f"PLY saved: {out_path} ({len(pcd.points):,} pts)")
def _save_poses(predictions: dict, n_frames: int, source_fps: float,
|
|
out_path: str, save_fps: float = 2.0) -> None:
|
|
"""Save camera extrinsics subsampled to save_fps as NPZ.
|
|
|
|
Output NPZ keys:
|
|
poses (M, 3, 4) float32 c2w extrinsic matrices
|
|
timestamps_ns (M,) int64 relative timestamps (ns from frame 0)
|
|
frame_ids (M,) int64 original frame indices
|
|
"""
|
|
ext = predictions.get("extrinsic")
|
|
if ext is None:
|
|
print("WARNING: no extrinsic in predictions — poses not saved")
|
|
return
|
|
if hasattr(ext, "numpy"):
|
|
ext = ext.numpy()
|
|
N = ext.shape[0]
|
|
step = max(1, round(source_fps / save_fps)) if save_fps > 0 else 1
|
|
idxs = np.arange(0, N, step, dtype=np.int64)
|
|
poses_sub = ext[idxs].astype(np.float32)
|
|
t_ns = (idxs * (1e9 / source_fps)).astype(np.int64)
|
|
os.makedirs(os.path.dirname(os.path.abspath(out_path)) or ".", exist_ok=True)
|
|
np.savez(out_path, poses=poses_sub, timestamps_ns=t_ns, frame_ids=idxs)
|
|
print(f"Poses saved: {out_path} ({len(idxs)} poses @ {save_fps} fps)")
|
|
|
|
|
|
def prepare_for_visualization(predictions, images=None):
    """Convert predictions to the unbatched NumPy format used by vis code.

    Args:
        predictions: dict of tensors/arrays; tensors are detached, moved to CPU,
            de-batched via `_squeeze_single_batch`, and converted to numpy.
            Non-array values pass through unchanged.
        images: optional image stack (tensor or ndarray). Falls back to
            `predictions.get("images")` when omitted.

    Returns:
        A new dict of numpy arrays (plus pass-through values), with "images"
        included when available.
    """
    vis_predictions = {}
    for k, v in predictions.items():
        if isinstance(v, torch.Tensor):
            v = _squeeze_single_batch(k, v.detach().cpu())
            vis_predictions[k] = v.numpy()
        elif isinstance(v, np.ndarray):
            vis_predictions[k] = _squeeze_single_batch(k, v)
        else:
            vis_predictions[k] = v

    if images is None:
        images = predictions.get("images")

    # Normalize images to an unbatched numpy array (or leave as None).
    # The original code had a trailing `isinstance(images, torch.Tensor)`
    # re-check that was unreachable — tensors are already converted here.
    if isinstance(images, torch.Tensor):
        images = _squeeze_single_batch("images", images.detach().cpu()).numpy()
    elif isinstance(images, np.ndarray):
        images = _squeeze_single_batch("images", images)

    if images is not None:
        vis_predictions["images"] = images

    return vis_predictions
# =============================================================================
# Main
# =============================================================================

def main():
    """CLI entry point: parse args, load data and model, run inference, export, visualize."""
    parser = argparse.ArgumentParser(description="LingBot-MAP: Streaming 3D Reconstruction Demo")

    # Input
    parser.add_argument("--image_folder", type=str, default=None)
    parser.add_argument("--video_path", type=str, default=None)
    parser.add_argument("--fps", type=int, default=10)
    parser.add_argument("--first_k", type=int, default=None)
    parser.add_argument("--stride", type=int, default=1)

    # Model
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--image_size", type=int, default=518)
    parser.add_argument("--patch_size", type=int, default=14)

    # Inference mode
    parser.add_argument("--mode", type=str, default="streaming", choices=["streaming", "windowed"],
                        help="streaming: frame-by-frame with KV cache; windowed: overlapping windows for long sequences")

    # Streaming options
    # NOTE(review): store_true with default=True means this flag cannot be turned
    # off from the CLI — presumably intentional; confirm before changing.
    parser.add_argument("--enable_3d_rope", action="store_true", default=True)
    parser.add_argument("--max_frame_num", type=int, default=1024)
    parser.add_argument("--num_scale_frames", type=int, default=8)
    parser.add_argument(
        "--keyframe_interval",
        type=int,
        default=None,
        help="Streaming only. Every N-th frame after scale frames is kept as a keyframe. 1 = every frame. "
             "If unset, auto-selected: 1 when num_frames <= 320, else ceil(num_frames / 320).",
    )
    parser.add_argument("--kv_cache_sliding_window", type=int, default=64)
    parser.add_argument("--camera_num_iterations", type=int, default=4,
                        help="Camera head iterative-refinement steps. Default 4; set 1 for faster inference "
                             "(skips 3 refinement passes at a small accuracy cost).")
    parser.add_argument("--use_sdpa", action="store_true", default=False,
                        help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
    parser.add_argument("--compile", action="store_true", default=False,
                        help="torch.compile hot modules (reduce-overhead) with a CUDA-graph warmup. "
                             "Streaming mode only; ~5 FPS faster at 518x378. Adds ~30-60 s warmup time.")
    # BooleanOptionalAction with no default → args.offload_to_cpu is None (falsy)
    # unless --offload_to_cpu / --no-offload_to_cpu is passed explicitly.
    parser.add_argument(
        "--offload_to_cpu",
        action=argparse.BooleanOptionalAction,
        help="Offload per-frame predictions to CPU during inference to cut GPU peak memory. "
             "Use --no-offload_to_cpu to keep outputs on GPU.",
    )
    # Windowed options
    parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)")
    parser.add_argument("--overlap_size", type=int, default=16, help="Overlap between windows")

    # Visualization
    parser.add_argument("--port", type=int, default=8080)
    parser.add_argument("--conf_threshold", type=float, default=1.5)
    parser.add_argument("--downsample_factor", type=int, default=10)
    parser.add_argument("--point_size", type=float, default=0.00001)
    parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")
    parser.add_argument("--sky_mask_dir", type=str, default=None,
                        help="Directory for cached sky masks (default: <image_folder>_sky_masks/)")
    parser.add_argument("--sky_mask_visualization_dir", type=str, default=None,
                        help="Save sky mask visualizations (original | mask | overlay) to this directory")
    parser.add_argument("--export_preprocessed", type=str, default=None,
                        help="Export stride-sampled, resized/cropped images to this folder")

    # Output export
    parser.add_argument("--save_ply", type=str, default=None,
                        help="Save point cloud to this PLY file path")
    parser.add_argument("--save_poses", type=str, default=None,
                        help="Save camera extrinsics to this NPZ file path")
    parser.add_argument("--save_poses_fps", type=float, default=2.0,
                        help="Subsampling FPS for saved poses (default 2.0)")

    args = parser.parse_args()
    assert args.image_folder or args.video_path, \
        "Provide --image_folder or --video_path"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ── Load images & model ──────────────────────────────────────────────────
    t0 = time.time()
    images, paths, resolved_image_folder = load_images(
        image_folder=args.image_folder, video_path=args.video_path,
        fps=args.fps, first_k=args.first_k, stride=args.stride,
        image_size=args.image_size, patch_size=args.patch_size,
    )

    # Export preprocessed images if requested
    if args.export_preprocessed:
        os.makedirs(args.export_preprocessed, exist_ok=True)
        print(f"Exporting {images.shape[0]} preprocessed images to {args.export_preprocessed}...")
        for i in range(images.shape[0]):
            # CHW float [0,1] → HWC uint8; cv2 expects BGR on disk.
            img = (images[i].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
            cv2.imwrite(
                os.path.join(args.export_preprocessed, f"{i:06d}.png"),
                cv2.cvtColor(img, cv2.COLOR_RGB2BGR),
            )
        print(f"Exported to {args.export_preprocessed}")

    model = load_model(args, device)
    print(f"Total load time: {time.time() - t0:.1f}s")

    # Pick inference dtype; autocast still runs for the ops that need fp32 (e.g. LayerNorm).
    if torch.cuda.is_available():
        # Compute capability >= 8 (Ampere+) supports bf16; older GPUs get fp16.
        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    else:
        dtype = torch.float32

    # Cast the aggregator (DINOv2-style trunk) to the inference dtype to remove the
    # redundant fp32 master weight copy + autocast bf16 weight cache (~2-3 GB saved,
    # no measurable quality change). gct_base._predict_* upcasts inputs to fp32 and
    # runs each head under `autocast(enabled=False)`, so camera/depth/point heads
    # keep fp32 weights automatically.
    if dtype != torch.float32 and getattr(model, "aggregator", None) is not None:
        print(f"Casting aggregator to {dtype} (heads kept in fp32)")
        model.aggregator = model.aggregator.to(dtype=dtype)

    images = images.to(device)
    num_frames = images.shape[0]
    print(f"Input: {num_frames} frames, shape {tuple(images.shape)}")
    print(f"Mode: {args.mode}")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(
            f"GPU mem after load: "
            f"alloc={torch.cuda.memory_allocated()/1e9:.2f} GB, "
            f"reserved={torch.cuda.memory_reserved()/1e9:.2f} GB"
        )

    # Auto-select keyframe interval: keep at most ~320 keyframes in the KV cache.
    if args.keyframe_interval is None:
        if args.mode == "streaming" and num_frames > 320:
            args.keyframe_interval = (num_frames + 319) // 320
            print(
                f"Auto-selected --keyframe_interval={args.keyframe_interval} "
                f"(num_frames={num_frames} > 320)."
            )
        else:
            args.keyframe_interval = 1

    if args.mode != "streaming" and args.keyframe_interval != 1:
        print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.")
        args.keyframe_interval = 1
    elif args.mode == "streaming" and args.keyframe_interval > 1:
        print(
            f"Keyframe streaming enabled: interval={args.keyframe_interval} "
            f"(after the first {args.num_scale_frames} scale frames)."
        )

    # ── Optional: torch.compile + CUDA-graph warmup (streaming only) ────────
    if args.compile:
        if args.mode != "streaming":
            print(
                f"--compile only applies to --mode streaming (got {args.mode!r}); "
                "skipping compile."
            )
        else:
            # Warm with shapes matching real inference so captured graphs are reusable.
            scale_for_warm = min(args.num_scale_frames, num_frames)
            warm_stream_n = min(10, max(1, num_frames - scale_for_warm))
            print(f"Warmup eager (scale + {warm_stream_n} streaming)...")
            t_warm = time.time()
            _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=1)
            print(f" eager warmup: {time.time() - t_warm:.1f}s")

            print("Compiling hot modules...")
            compile_model(model)

            # 3 passes under compile: 1st captures CUDA graphs, 2nd/3rd replay so
            # the caching allocator / graph-address map converge on the state the
            # real inference will see. See gct_profile.py:302-306 for rationale.
            print("Warmup compiled (3x dress rehearsal)...")
            t_warm = time.time()
            _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=3)
            print(f" compiled warmup: {time.time() - t_warm:.1f}s")

    # ── Inference ────────────────────────────────────────────────────────────
    print(f"Running {args.mode} inference (dtype={dtype})...")
    t0 = time.time()

    output_device = torch.device("cpu") if args.offload_to_cpu else None

    with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
        if args.mode == "streaming":
            predictions = model.inference_streaming(
                images,
                num_scale_frames=args.num_scale_frames,
                keyframe_interval=args.keyframe_interval,
                output_device=output_device,
            )
        else:  # windowed
            predictions = model.inference_windowed(
                images,
                window_size=args.window_size,
                overlap_size=args.overlap_size,
                num_scale_frames=args.num_scale_frames,
                output_device=output_device,
            )

    print(f"Inference done in {time.time() - t0:.1f}s")
    if torch.cuda.is_available():
        print(
            f"GPU peak during inference: "
            f"{torch.cuda.max_memory_allocated()/1e9:.2f} GB "
            f"(reserved peak {torch.cuda.max_memory_reserved()/1e9:.2f} GB)"
        )

    # ── Post-process ─────────────────────────────────────────────────────────
    if args.offload_to_cpu:
        # Predictions already live on CPU; free the GPU image tensor before
        # post-processing to lower peak memory.
        del images
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        images_for_post = predictions["images"]  # already CPU
    else:
        images_for_post = images

    predictions, images_cpu = postprocess(predictions, images_for_post)

    # ── Export ───────────────────────────────────────────────────────────────
    # For image folders the capture rate is unknown; treat it as 1 frame/sec.
    source_fps = args.fps if args.video_path else 1.0
    if args.save_ply:
        _save_ply(predictions, args.save_ply, args.conf_threshold)
    if args.save_poses:
        _save_poses(predictions, len(paths), source_fps, args.save_poses, args.save_poses_fps)

    # ── Visualize ────────────────────────────────────────────────────────────
    try:
        from lingbot_map.vis import PointCloudViewer
        viewer = PointCloudViewer(
            pred_dict=prepare_for_visualization(predictions, images_cpu),
            port=args.port,
            vis_threshold=args.conf_threshold,
            downsample_factor=args.downsample_factor,
            point_size=args.point_size,
            mask_sky=args.mask_sky,
            image_folder=resolved_image_folder,
            sky_mask_dir=args.sky_mask_dir,
            sky_mask_visualization_dir=args.sky_mask_visualization_dir,
        )
        print(f"3D viewer at http://localhost:{args.port}")
        viewer.run()
    except ImportError:
        # Visualization is optional; still report what was computed.
        print("viser not installed. Install with: pip install lingbot-map[vis]")
        print(f"Predictions contain keys: {list(predictions.keys())}")
# Script entry point.
if __name__ == "__main__":
    main()