Update demo GPU usage and fix some bugs
This commit is contained in:
44
demo.py
44
demo.py
@@ -247,7 +247,7 @@ def main():
|
|||||||
# Streaming options
|
# Streaming options
|
||||||
parser.add_argument("--enable_3d_rope", action="store_true", default=True)
|
parser.add_argument("--enable_3d_rope", action="store_true", default=True)
|
||||||
parser.add_argument("--max_frame_num", type=int, default=1024)
|
parser.add_argument("--max_frame_num", type=int, default=1024)
|
||||||
parser.add_argument("--num_scale_frames", type=int, default=8)
|
parser.add_argument("--num_scale_frames", type=int, default=4)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--keyframe_interval",
|
"--keyframe_interval",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -258,6 +258,13 @@ def main():
|
|||||||
parser.add_argument("--kv_cache_scale_frames", type=int, default=8)
|
parser.add_argument("--kv_cache_scale_frames", type=int, default=8)
|
||||||
parser.add_argument("--use_sdpa", action="store_true", default=False,
|
parser.add_argument("--use_sdpa", action="store_true", default=False,
|
||||||
help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
|
help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
|
||||||
|
parser.add_argument(
|
||||||
|
"--offload_to_cpu",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
default=True,
|
||||||
|
help="Offload per-frame predictions to CPU during inference to cut GPU peak memory. "
|
||||||
|
"Use --no-offload_to_cpu to keep outputs on GPU.",
|
||||||
|
)
|
||||||
|
|
||||||
# Windowed options
|
# Windowed options
|
||||||
parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)")
|
parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)")
|
||||||
@@ -306,10 +313,24 @@ def main():
|
|||||||
model = load_model(args, device)
|
model = load_model(args, device)
|
||||||
print(f"Total load time: {time.time() - t0:.1f}s")
|
print(f"Total load time: {time.time() - t0:.1f}s")
|
||||||
|
|
||||||
|
# Keep model in its loaded dtype — autocast handles bf16/fp16 for the ops
|
||||||
|
# that benefit from it and keeps LayerNorm / reductions in fp32.
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
images = images.to(device)
|
images = images.to(device)
|
||||||
num_frames = images.shape[0]
|
num_frames = images.shape[0]
|
||||||
print(f"Input: {num_frames} frames, shape {tuple(images.shape)}")
|
print(f"Input: {num_frames} frames, shape {tuple(images.shape)}")
|
||||||
print(f"Mode: {args.mode}")
|
print(f"Mode: {args.mode}")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
print(
|
||||||
|
f"GPU mem after load: "
|
||||||
|
f"alloc={torch.cuda.memory_allocated()/1e9:.2f} GB, "
|
||||||
|
f"reserved={torch.cuda.memory_reserved()/1e9:.2f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
if args.mode != "streaming" and args.keyframe_interval != 1:
|
if args.mode != "streaming" and args.keyframe_interval != 1:
|
||||||
print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.")
|
print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.")
|
||||||
@@ -321,16 +342,18 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ── Inference ────────────────────────────────────────────────────────────
|
# ── Inference ────────────────────────────────────────────────────────────
|
||||||
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
|
|
||||||
print(f"Running {args.mode} inference (dtype={dtype})...")
|
print(f"Running {args.mode} inference (dtype={dtype})...")
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
|
output_device = torch.device("cpu") if args.offload_to_cpu else None
|
||||||
|
|
||||||
with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
|
with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
|
||||||
if args.mode == "streaming":
|
if args.mode == "streaming":
|
||||||
predictions = model.inference_streaming(
|
predictions = model.inference_streaming(
|
||||||
images,
|
images,
|
||||||
num_scale_frames=args.num_scale_frames,
|
num_scale_frames=args.num_scale_frames,
|
||||||
keyframe_interval=args.keyframe_interval,
|
keyframe_interval=args.keyframe_interval,
|
||||||
|
output_device=output_device,
|
||||||
)
|
)
|
||||||
else: # windowed
|
else: # windowed
|
||||||
predictions = model.inference_windowed(
|
predictions = model.inference_windowed(
|
||||||
@@ -338,12 +361,27 @@ def main():
|
|||||||
window_size=args.window_size,
|
window_size=args.window_size,
|
||||||
overlap_size=args.overlap_size,
|
overlap_size=args.overlap_size,
|
||||||
num_scale_frames=args.num_scale_frames,
|
num_scale_frames=args.num_scale_frames,
|
||||||
|
output_device=output_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Inference done in {time.time() - t0:.1f}s")
|
print(f"Inference done in {time.time() - t0:.1f}s")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(
|
||||||
|
f"GPU peak during inference: "
|
||||||
|
f"{torch.cuda.max_memory_allocated()/1e9:.2f} GB "
|
||||||
|
f"(reserved peak {torch.cuda.max_memory_reserved()/1e9:.2f} GB)"
|
||||||
|
)
|
||||||
|
|
||||||
# ── Post-process ─────────────────────────────────────────────────────────
|
# ── Post-process ─────────────────────────────────────────────────────────
|
||||||
predictions, images_cpu = postprocess(predictions, images)
|
if args.offload_to_cpu:
|
||||||
|
del images
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
images_for_post = predictions["images"] # already CPU
|
||||||
|
else:
|
||||||
|
images_for_post = images
|
||||||
|
|
||||||
|
predictions, images_cpu = postprocess(predictions, images_for_post)
|
||||||
|
|
||||||
# ── Visualize ────────────────────────────────────────────────────────────
|
# ── Visualize ────────────────────────────────────────────────────────────
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -115,10 +115,6 @@ class PointCloudViewer:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.original_images = []
|
self.original_images = []
|
||||||
self.tsdf_depth_maps = None
|
|
||||||
self.tsdf_extrinsics = None
|
|
||||||
self.tsdf_intrinsics = None
|
|
||||||
self.tsdf_images = None
|
|
||||||
|
|
||||||
self.pcs, self.all_steps = self.read_data(
|
self.pcs, self.all_steps = self.read_data(
|
||||||
pc_list, color_list, conf_list, edge_color_list
|
pc_list, color_list, conf_list, edge_color_list
|
||||||
@@ -187,12 +183,6 @@ class PointCloudViewer:
|
|||||||
colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3)
|
colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3)
|
||||||
S = world_points.shape[0]
|
S = world_points.shape[0]
|
||||||
|
|
||||||
# Store raw data for TSDF fusion
|
|
||||||
self.tsdf_depth_maps = depth_map # (S, H, W, 1)
|
|
||||||
self.tsdf_extrinsics = extrinsics_cam # (S, 3, 4) camera-from-world
|
|
||||||
self.tsdf_intrinsics = intrinsics_cam # (S, 3, 3)
|
|
||||||
self.tsdf_images = images # (S, 3, H, W)
|
|
||||||
|
|
||||||
# Store original images for camera frustum display
|
# Store original images for camera frustum display
|
||||||
self.original_images = []
|
self.original_images = []
|
||||||
for i in range(S):
|
for i in range(S):
|
||||||
@@ -423,88 +413,6 @@ class PointCloudViewer:
|
|||||||
"Camera Downsample Factor", min=1, max=50, step=1, initial_value=1
|
"Camera Downsample Factor", min=1, max=50, step=1, initial_value=1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Point cloud filtering controls
|
|
||||||
with self.server.gui.add_folder("Point Cloud Filtering"):
|
|
||||||
self.bbox_clip_slider = self.server.gui.add_slider(
|
|
||||||
"Bounding Box Keep (%)",
|
|
||||||
min=50.0, max=100.0, step=0.5, initial_value=100.0,
|
|
||||||
hint="Keep the central N% of points per axis. 100 = no clipping.",
|
|
||||||
)
|
|
||||||
self.sor_checkbox = self.server.gui.add_checkbox(
|
|
||||||
"Statistical Outlier Removal",
|
|
||||||
initial_value=False,
|
|
||||||
hint="Remove isolated floating points based on KNN distance.",
|
|
||||||
)
|
|
||||||
self.sor_neighbors_slider = self.server.gui.add_slider(
|
|
||||||
"SOR Neighbors (K)",
|
|
||||||
min=5, max=50, step=1, initial_value=20, disabled=True,
|
|
||||||
hint="Number of nearest neighbors for outlier detection.",
|
|
||||||
)
|
|
||||||
self.sor_std_slider = self.server.gui.add_slider(
|
|
||||||
"SOR Std Ratio",
|
|
||||||
min=0.5, max=5.0, step=0.1, initial_value=2.0, disabled=True,
|
|
||||||
hint="Lower = more aggressive filtering. Points beyond mean + ratio*std are removed.",
|
|
||||||
)
|
|
||||||
self.filter_apply_button = self.server.gui.add_button(
|
|
||||||
"Apply Filters",
|
|
||||||
hint="Regenerate point clouds with current filter settings.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@self.sor_checkbox.on_update
|
|
||||||
def _(_) -> None:
|
|
||||||
self.sor_neighbors_slider.disabled = not self.sor_checkbox.value
|
|
||||||
self.sor_std_slider.disabled = not self.sor_checkbox.value
|
|
||||||
|
|
||||||
@self.filter_apply_button.on_click
|
|
||||||
def _(_) -> None:
|
|
||||||
self._regenerate_point_clouds()
|
|
||||||
|
|
||||||
# TSDF Fusion controls
|
|
||||||
with self.server.gui.add_folder("TSDF Fusion"):
|
|
||||||
self.tsdf_voxel_size_slider = self.server.gui.add_slider(
|
|
||||||
"Voxel Size", min=0.001, max=0.1, step=0.001, initial_value=0.01,
|
|
||||||
hint="TSDF voxel size. Smaller = finer detail but slower.",
|
|
||||||
)
|
|
||||||
self.tsdf_sdf_trunc_slider = self.server.gui.add_slider(
|
|
||||||
"SDF Truncation", min=0.01, max=0.5, step=0.01, initial_value=0.04,
|
|
||||||
hint="Truncation distance. Typically 3-5x voxel size.",
|
|
||||||
)
|
|
||||||
self.tsdf_depth_scale_slider = self.server.gui.add_slider(
|
|
||||||
"Depth Scale", min=1.0, max=10000.0, step=1.0, initial_value=1.0,
|
|
||||||
hint="Depth scale factor. 1.0 if depth is in meters.",
|
|
||||||
)
|
|
||||||
self.tsdf_depth_trunc_slider = self.server.gui.add_slider(
|
|
||||||
"Depth Truncation", min=0.5, max=50.0, step=0.5, initial_value=5.0,
|
|
||||||
hint="Max depth value to integrate (meters).",
|
|
||||||
)
|
|
||||||
self.tsdf_run_button = self.server.gui.add_button(
|
|
||||||
"Run TSDF Fusion",
|
|
||||||
hint="Fuse all frames into a single point cloud via TSDF.",
|
|
||||||
)
|
|
||||||
self.tsdf_clear_button = self.server.gui.add_button(
|
|
||||||
"Clear TSDF Result",
|
|
||||||
hint="Remove the TSDF fused point cloud from the scene.",
|
|
||||||
)
|
|
||||||
self.tsdf_status = self.server.gui.add_text(
|
|
||||||
"Status", initial_value="Ready",
|
|
||||||
)
|
|
||||||
|
|
||||||
self._tsdf_handle = None
|
|
||||||
|
|
||||||
@self.tsdf_run_button.on_click
|
|
||||||
def _(_) -> None:
|
|
||||||
self._run_tsdf_fusion()
|
|
||||||
|
|
||||||
@self.tsdf_clear_button.on_click
|
|
||||||
def _(_) -> None:
|
|
||||||
if self._tsdf_handle is not None:
|
|
||||||
try:
|
|
||||||
self._tsdf_handle.remove()
|
|
||||||
except (KeyError, AttributeError):
|
|
||||||
pass
|
|
||||||
self._tsdf_handle = None
|
|
||||||
self.tsdf_status.value = "Cleared"
|
|
||||||
|
|
||||||
# Range visualization controls
|
# Range visualization controls
|
||||||
with self.server.gui.add_folder("Frame Range Control"):
|
with self.server.gui.add_folder("Frame Range Control"):
|
||||||
self.range_mode_checkbox = self.server.gui.add_checkbox("Range Mode", initial_value=False)
|
self.range_mode_checkbox = self.server.gui.add_checkbox("Range Mode", initial_value=False)
|
||||||
@@ -781,100 +689,6 @@ class PointCloudViewer:
|
|||||||
if i % downsample_factor == 0:
|
if i % downsample_factor == 0:
|
||||||
self.add_camera(step)
|
self.add_camera(step)
|
||||||
|
|
||||||
def _run_tsdf_fusion(self):
|
|
||||||
"""Run TSDF fusion on all frames and display result as a point cloud."""
|
|
||||||
if not hasattr(self, 'tsdf_depth_maps') or self.tsdf_depth_maps is None:
|
|
||||||
self.tsdf_status.value = "Error: no depth data (need pred_dict)"
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
import open3d as o3d
|
|
||||||
except ImportError:
|
|
||||||
self.tsdf_status.value = "Error: pip install open3d"
|
|
||||||
return
|
|
||||||
|
|
||||||
self.tsdf_status.value = "Running TSDF fusion..."
|
|
||||||
print("Starting TSDF fusion...")
|
|
||||||
|
|
||||||
voxel_size = self.tsdf_voxel_size_slider.value
|
|
||||||
sdf_trunc = self.tsdf_sdf_trunc_slider.value
|
|
||||||
depth_scale = self.tsdf_depth_scale_slider.value
|
|
||||||
depth_trunc = self.tsdf_depth_trunc_slider.value
|
|
||||||
|
|
||||||
volume = o3d.pipelines.integration.ScalableTSDFVolume(
|
|
||||||
voxel_length=voxel_size,
|
|
||||||
sdf_trunc=sdf_trunc,
|
|
||||||
color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8,
|
|
||||||
)
|
|
||||||
|
|
||||||
S = self.tsdf_depth_maps.shape[0]
|
|
||||||
H, W = self.tsdf_depth_maps.shape[1], self.tsdf_depth_maps.shape[2]
|
|
||||||
|
|
||||||
for i in tqdm(range(S), desc="TSDF integrating"):
|
|
||||||
# Depth: (H, W, 1) -> (H, W)
|
|
||||||
depth = self.tsdf_depth_maps[i]
|
|
||||||
if depth.ndim == 3:
|
|
||||||
depth = depth[..., 0]
|
|
||||||
|
|
||||||
# Color: (3, H, W) -> (H, W, 3), uint8
|
|
||||||
color = self.tsdf_images[i].transpose(1, 2, 0) # (H, W, 3)
|
|
||||||
color = (np.clip(color, 0, 1) * 255).astype(np.uint8)
|
|
||||||
|
|
||||||
# Camera extrinsic: (3, 4) -> (4, 4) camera-from-world
|
|
||||||
extr_34 = self.tsdf_extrinsics[i]
|
|
||||||
extr_44 = np.eye(4, dtype=np.float64)
|
|
||||||
extr_44[:3, :] = extr_34
|
|
||||||
|
|
||||||
intrinsic = o3d.camera.PinholeCameraIntrinsic(
|
|
||||||
width=W, height=H,
|
|
||||||
fx=float(self.tsdf_intrinsics[i, 0, 0]),
|
|
||||||
fy=float(self.tsdf_intrinsics[i, 1, 1]),
|
|
||||||
cx=float(self.tsdf_intrinsics[i, 0, 2]),
|
|
||||||
cy=float(self.tsdf_intrinsics[i, 1, 2]),
|
|
||||||
)
|
|
||||||
|
|
||||||
depth_o3d = o3d.geometry.Image(
|
|
||||||
(depth.astype(np.float32) * depth_scale).astype(np.float32)
|
|
||||||
)
|
|
||||||
color_o3d = o3d.geometry.Image(np.ascontiguousarray(color))
|
|
||||||
|
|
||||||
rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
|
|
||||||
color_o3d, depth_o3d,
|
|
||||||
depth_scale=depth_scale,
|
|
||||||
depth_trunc=depth_trunc,
|
|
||||||
convert_rgb_to_intensity=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
volume.integrate(rgbd, intrinsic, extr_44)
|
|
||||||
|
|
||||||
print("Extracting point cloud from TSDF volume...")
|
|
||||||
pcd = volume.extract_point_cloud()
|
|
||||||
|
|
||||||
points = np.asarray(pcd.points, dtype=np.float32)
|
|
||||||
colors = np.asarray(pcd.colors, dtype=np.float32) # already 0-1
|
|
||||||
|
|
||||||
if len(points) == 0:
|
|
||||||
self.tsdf_status.value = "Error: empty result, try adjusting parameters"
|
|
||||||
print("TSDF fusion produced 0 points.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Remove previous TSDF result
|
|
||||||
if self._tsdf_handle is not None:
|
|
||||||
try:
|
|
||||||
self._tsdf_handle.remove()
|
|
||||||
except (KeyError, AttributeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
self._tsdf_handle = self.server.scene.add_point_cloud(
|
|
||||||
name="/tsdf_fusion",
|
|
||||||
points=points,
|
|
||||||
colors=colors,
|
|
||||||
point_size=self.psize_slider.value,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tsdf_status.value = f"Done: {len(points):,} points"
|
|
||||||
print(f"TSDF fusion complete: {len(points):,} points")
|
|
||||||
|
|
||||||
def _export_glb(self):
|
def _export_glb(self):
|
||||||
"""Export current filtered point clouds and cameras as a GLB file."""
|
"""Export current filtered point clouds and cameras as a GLB file."""
|
||||||
try:
|
try:
|
||||||
@@ -1354,27 +1168,6 @@ class PointCloudViewer:
|
|||||||
if len(pred_pts) == 0:
|
if len(pred_pts) == 0:
|
||||||
return pred_pts, color
|
return pred_pts, color
|
||||||
|
|
||||||
# Bounding box clip: remove points far from the scene center
|
|
||||||
if hasattr(self, 'bbox_clip_slider'):
|
|
||||||
clip_pct = self.bbox_clip_slider.value
|
|
||||||
if clip_pct < 100.0:
|
|
||||||
lo = np.percentile(pred_pts, (100.0 - clip_pct) / 2, axis=0)
|
|
||||||
hi = np.percentile(pred_pts, 100.0 - (100.0 - clip_pct) / 2, axis=0)
|
|
||||||
bbox_mask = np.all((pred_pts >= lo) & (pred_pts <= hi), axis=1)
|
|
||||||
pred_pts = pred_pts[bbox_mask]
|
|
||||||
color = color[bbox_mask]
|
|
||||||
|
|
||||||
if len(pred_pts) == 0:
|
|
||||||
return pred_pts, color
|
|
||||||
|
|
||||||
# Statistical Outlier Removal (SOR)
|
|
||||||
if hasattr(self, 'sor_checkbox') and self.sor_checkbox.value and len(pred_pts) > 0:
|
|
||||||
pred_pts, color = self._statistical_outlier_removal(
|
|
||||||
pred_pts, color,
|
|
||||||
nb_neighbors=int(self.sor_neighbors_slider.value),
|
|
||||||
std_ratio=self.sor_std_slider.value,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Downsample
|
# Downsample
|
||||||
if downsample_factor > 1 and len(pred_pts) > 0:
|
if downsample_factor > 1 and len(pred_pts) > 0:
|
||||||
indices = np.arange(0, len(pred_pts), downsample_factor)
|
indices = np.arange(0, len(pred_pts), downsample_factor)
|
||||||
@@ -1383,49 +1176,6 @@ class PointCloudViewer:
|
|||||||
|
|
||||||
return pred_pts, color
|
return pred_pts, color
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _statistical_outlier_removal(
|
|
||||||
points: np.ndarray,
|
|
||||||
colors: np.ndarray,
|
|
||||||
nb_neighbors: int = 20,
|
|
||||||
std_ratio: float = 2.0,
|
|
||||||
) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""Remove statistical outliers based on mean distance to k-nearest neighbors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
points: (N, 3) point positions.
|
|
||||||
colors: (N, 3) point colors.
|
|
||||||
nb_neighbors: Number of nearest neighbors to consider.
|
|
||||||
std_ratio: Standard deviation multiplier for the distance threshold.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Filtered (points, colors) tuple.
|
|
||||||
"""
|
|
||||||
if len(points) <= nb_neighbors:
|
|
||||||
return points, colors
|
|
||||||
|
|
||||||
try:
|
|
||||||
from scipy.spatial import cKDTree
|
|
||||||
except ImportError:
|
|
||||||
# Fallback: skip SOR if scipy not available
|
|
||||||
return points, colors
|
|
||||||
|
|
||||||
# Subsample for KD-tree if too many points (speed)
|
|
||||||
max_pts_for_tree = 200_000
|
|
||||||
if len(points) > max_pts_for_tree:
|
|
||||||
subsample_idx = np.random.choice(len(points), max_pts_for_tree, replace=False)
|
|
||||||
tree = cKDTree(points[subsample_idx])
|
|
||||||
else:
|
|
||||||
tree = cKDTree(points)
|
|
||||||
|
|
||||||
dists, _ = tree.query(points, k=nb_neighbors + 1) # +1 because first is self
|
|
||||||
mean_dists = dists[:, 1:].mean(axis=1) # exclude self
|
|
||||||
|
|
||||||
threshold = mean_dists.mean() + std_ratio * mean_dists.std()
|
|
||||||
inlier_mask = mean_dists < threshold
|
|
||||||
|
|
||||||
return points[inlier_mask], colors[inlier_mask]
|
|
||||||
|
|
||||||
def add_pc(self, step):
|
def add_pc(self, step):
|
||||||
"""Add point cloud for a frame."""
|
"""Add point cloud for a frame."""
|
||||||
pc = self.pcs[step]["pc"]
|
pc = self.pcs[step]["pc"]
|
||||||
|
|||||||
Reference in New Issue
Block a user