From f307fdba68dd8dbc9857654c25e4a1ee991bcf69 Mon Sep 17 00:00:00 2001 From: LinZhuoChen Date: Fri, 17 Apr 2026 17:56:17 +0800 Subject: [PATCH] update demp gpu usage and fix some bug --- demo.py | 44 ++++- lingbot_map/vis/point_cloud_viewer.py | 250 -------------------------- 2 files changed, 41 insertions(+), 253 deletions(-) diff --git a/demo.py b/demo.py index 6beda8e..c55d3f5 100644 --- a/demo.py +++ b/demo.py @@ -247,7 +247,7 @@ def main(): # Streaming options parser.add_argument("--enable_3d_rope", action="store_true", default=True) parser.add_argument("--max_frame_num", type=int, default=1024) - parser.add_argument("--num_scale_frames", type=int, default=8) + parser.add_argument("--num_scale_frames", type=int, default=4) parser.add_argument( "--keyframe_interval", type=int, @@ -258,6 +258,13 @@ def main(): parser.add_argument("--kv_cache_scale_frames", type=int, default=8) parser.add_argument("--use_sdpa", action="store_true", default=False, help="Use SDPA backend (no flashinfer needed). Default: FlashInfer") + parser.add_argument( + "--offload_to_cpu", + action=argparse.BooleanOptionalAction, + default=True, + help="Offload per-frame predictions to CPU during inference to cut GPU peak memory. " + "Use --no-offload_to_cpu to keep outputs on GPU.", + ) # Windowed options parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)") @@ -306,10 +313,24 @@ def main(): model = load_model(args, device) print(f"Total load time: {time.time() - t0:.1f}s") + # Keep model in its loaded dtype — autocast handles bf16/fp16 for the ops + # that benefit from it and keeps LayerNorm / reductions in fp32. + if torch.cuda.is_available(): + dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + else: + dtype = torch.float32 + images = images.to(device) num_frames = images.shape[0] print(f"Input: {num_frames} frames, shape {tuple(images.shape)}") print(f"Mode: {args.mode}") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print( + f"GPU mem after load: " + f"alloc={torch.cuda.memory_allocated()/1e9:.2f} GB, " + f"reserved={torch.cuda.memory_reserved()/1e9:.2f} GB" + ) if args.mode != "streaming" and args.keyframe_interval != 1: print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.") @@ -321,16 +342,18 @@ def main(): ) # ── Inference ──────────────────────────────────────────────────────────── - dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 print(f"Running {args.mode} inference (dtype={dtype})...") t0 = time.time() + output_device = torch.device("cpu") if args.offload_to_cpu else None + with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype): if args.mode == "streaming": predictions = model.inference_streaming( images, num_scale_frames=args.num_scale_frames, keyframe_interval=args.keyframe_interval, + output_device=output_device, ) else: # windowed predictions = model.inference_windowed( @@ -338,12 +361,27 @@ def main(): window_size=args.window_size, overlap_size=args.overlap_size, num_scale_frames=args.num_scale_frames, + output_device=output_device, ) print(f"Inference done in {time.time() - t0:.1f}s") + if torch.cuda.is_available(): + print( + f"GPU peak during inference: " + f"{torch.cuda.max_memory_allocated()/1e9:.2f} GB " + f"(reserved peak {torch.cuda.max_memory_reserved()/1e9:.2f} GB)" + ) # ── Post-process ───────────────────────────────────────────────────────── - predictions, images_cpu = postprocess(predictions, images) + if args.offload_to_cpu: + del images + if torch.cuda.is_available(): + torch.cuda.empty_cache() + images_for_post = predictions["images"] # already CPU + else: + images_for_post = images + + predictions, images_cpu = postprocess(predictions, images_for_post) # ── Visualize ──────────────────────────────────────────────────────────── try: diff --git a/lingbot_map/vis/point_cloud_viewer.py b/lingbot_map/vis/point_cloud_viewer.py index 64e2031..57f5083 100644 --- a/lingbot_map/vis/point_cloud_viewer.py +++ b/lingbot_map/vis/point_cloud_viewer.py @@ -115,10 +115,6 @@ class PointCloudViewer: ) else: self.original_images = [] - self.tsdf_depth_maps = None - self.tsdf_extrinsics = None - self.tsdf_intrinsics = None - self.tsdf_images = None self.pcs, self.all_steps = self.read_data( pc_list, color_list, conf_list, edge_color_list @@ -187,12 +183,6 @@ class PointCloudViewer: colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3) S = world_points.shape[0] - # Store raw data for TSDF fusion - self.tsdf_depth_maps = depth_map # (S, H, W, 1) - self.tsdf_extrinsics = extrinsics_cam # (S, 3, 4) camera-from-world - self.tsdf_intrinsics = intrinsics_cam # (S, 3, 3) - self.tsdf_images = images # (S, 3, H, W) - # Store original images for camera frustum display self.original_images = [] for i in range(S): @@ -423,88 +413,6 @@ class PointCloudViewer: "Camera Downsample Factor", min=1, max=50, step=1, initial_value=1 ) - # Point cloud filtering controls - with self.server.gui.add_folder("Point Cloud Filtering"): - self.bbox_clip_slider = self.server.gui.add_slider( - "Bounding Box Keep (%)", - min=50.0, max=100.0, step=0.5, initial_value=100.0, - hint="Keep the central N% of points per axis. 100 = no clipping.", - ) - self.sor_checkbox = self.server.gui.add_checkbox( - "Statistical Outlier Removal", - initial_value=False, - hint="Remove isolated floating points based on KNN distance.", - ) - self.sor_neighbors_slider = self.server.gui.add_slider( - "SOR Neighbors (K)", - min=5, max=50, step=1, initial_value=20, disabled=True, - hint="Number of nearest neighbors for outlier detection.", - ) - self.sor_std_slider = self.server.gui.add_slider( - "SOR Std Ratio", - min=0.5, max=5.0, step=0.1, initial_value=2.0, disabled=True, - hint="Lower = more aggressive filtering. Points beyond mean + ratio*std are removed.", - ) - self.filter_apply_button = self.server.gui.add_button( - "Apply Filters", - hint="Regenerate point clouds with current filter settings.", - ) - - @self.sor_checkbox.on_update - def _(_) -> None: - self.sor_neighbors_slider.disabled = not self.sor_checkbox.value - self.sor_std_slider.disabled = not self.sor_checkbox.value - - @self.filter_apply_button.on_click - def _(_) -> None: - self._regenerate_point_clouds() - - # TSDF Fusion controls - with self.server.gui.add_folder("TSDF Fusion"): - self.tsdf_voxel_size_slider = self.server.gui.add_slider( - "Voxel Size", min=0.001, max=0.1, step=0.001, initial_value=0.01, - hint="TSDF voxel size. Smaller = finer detail but slower.", - ) - self.tsdf_sdf_trunc_slider = self.server.gui.add_slider( - "SDF Truncation", min=0.01, max=0.5, step=0.01, initial_value=0.04, - hint="Truncation distance. Typically 3-5x voxel size.", - ) - self.tsdf_depth_scale_slider = self.server.gui.add_slider( - "Depth Scale", min=1.0, max=10000.0, step=1.0, initial_value=1.0, - hint="Depth scale factor. 1.0 if depth is in meters.", - ) - self.tsdf_depth_trunc_slider = self.server.gui.add_slider( - "Depth Truncation", min=0.5, max=50.0, step=0.5, initial_value=5.0, - hint="Max depth value to integrate (meters).", - ) - self.tsdf_run_button = self.server.gui.add_button( - "Run TSDF Fusion", - hint="Fuse all frames into a single point cloud via TSDF.", - ) - self.tsdf_clear_button = self.server.gui.add_button( - "Clear TSDF Result", - hint="Remove the TSDF fused point cloud from the scene.", - ) - self.tsdf_status = self.server.gui.add_text( - "Status", initial_value="Ready", - ) - - self._tsdf_handle = None - - @self.tsdf_run_button.on_click - def _(_) -> None: - self._run_tsdf_fusion() - - @self.tsdf_clear_button.on_click - def _(_) -> None: - if self._tsdf_handle is not None: - try: - self._tsdf_handle.remove() - except (KeyError, AttributeError): - pass - self._tsdf_handle = None - self.tsdf_status.value = "Cleared" - # Range visualization controls with self.server.gui.add_folder("Frame Range Control"): self.range_mode_checkbox = self.server.gui.add_checkbox("Range Mode", initial_value=False) @@ -781,100 +689,6 @@ class PointCloudViewer: if i % downsample_factor == 0: self.add_camera(step) - def _run_tsdf_fusion(self): - """Run TSDF fusion on all frames and display result as a point cloud.""" - if not hasattr(self, 'tsdf_depth_maps') or self.tsdf_depth_maps is None: - self.tsdf_status.value = "Error: no depth data (need pred_dict)" - return - - try: - import open3d as o3d - except ImportError: - self.tsdf_status.value = "Error: pip install open3d" - return - - self.tsdf_status.value = "Running TSDF fusion..." - print("Starting TSDF fusion...") - - voxel_size = self.tsdf_voxel_size_slider.value - sdf_trunc = self.tsdf_sdf_trunc_slider.value - depth_scale = self.tsdf_depth_scale_slider.value - depth_trunc = self.tsdf_depth_trunc_slider.value - - volume = o3d.pipelines.integration.ScalableTSDFVolume( - voxel_length=voxel_size, - sdf_trunc=sdf_trunc, - color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8, - ) - - S = self.tsdf_depth_maps.shape[0] - H, W = self.tsdf_depth_maps.shape[1], self.tsdf_depth_maps.shape[2] - - for i in tqdm(range(S), desc="TSDF integrating"): - # Depth: (H, W, 1) -> (H, W) - depth = self.tsdf_depth_maps[i] - if depth.ndim == 3: - depth = depth[..., 0] - - # Color: (3, H, W) -> (H, W, 3), uint8 - color = self.tsdf_images[i].transpose(1, 2, 0) # (H, W, 3) - color = (np.clip(color, 0, 1) * 255).astype(np.uint8) - - # Camera extrinsic: (3, 4) -> (4, 4) camera-from-world - extr_34 = self.tsdf_extrinsics[i] - extr_44 = np.eye(4, dtype=np.float64) - extr_44[:3, :] = extr_34 - - intrinsic = o3d.camera.PinholeCameraIntrinsic( - width=W, height=H, - fx=float(self.tsdf_intrinsics[i, 0, 0]), - fy=float(self.tsdf_intrinsics[i, 1, 1]), - cx=float(self.tsdf_intrinsics[i, 0, 2]), - cy=float(self.tsdf_intrinsics[i, 1, 2]), - ) - - depth_o3d = o3d.geometry.Image( - (depth.astype(np.float32) * depth_scale).astype(np.float32) - ) - color_o3d = o3d.geometry.Image(np.ascontiguousarray(color)) - - rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( - color_o3d, depth_o3d, - depth_scale=depth_scale, - depth_trunc=depth_trunc, - convert_rgb_to_intensity=False, - ) - - volume.integrate(rgbd, intrinsic, extr_44) - - print("Extracting point cloud from TSDF volume...") - pcd = volume.extract_point_cloud() - - points = np.asarray(pcd.points, dtype=np.float32) - colors = np.asarray(pcd.colors, dtype=np.float32) # already 0-1 - - if len(points) == 0: - self.tsdf_status.value = "Error: empty result, try adjusting parameters" - print("TSDF fusion produced 0 points.") - return - - # Remove previous TSDF result - if self._tsdf_handle is not None: - try: - self._tsdf_handle.remove() - except (KeyError, AttributeError): - pass - - self._tsdf_handle = self.server.scene.add_point_cloud( - name="/tsdf_fusion", - points=points, - colors=colors, - point_size=self.psize_slider.value, - ) - - self.tsdf_status.value = f"Done: {len(points):,} points" - print(f"TSDF fusion complete: {len(points):,} points") - def _export_glb(self): """Export current filtered point clouds and cameras as a GLB file.""" try: @@ -1354,27 +1168,6 @@ class PointCloudViewer: if len(pred_pts) == 0: return pred_pts, color - # Bounding box clip: remove points far from the scene center - if hasattr(self, 'bbox_clip_slider'): - clip_pct = self.bbox_clip_slider.value - if clip_pct < 100.0: - lo = np.percentile(pred_pts, (100.0 - clip_pct) / 2, axis=0) - hi = np.percentile(pred_pts, 100.0 - (100.0 - clip_pct) / 2, axis=0) - bbox_mask = np.all((pred_pts >= lo) & (pred_pts <= hi), axis=1) - pred_pts = pred_pts[bbox_mask] - color = color[bbox_mask] - - if len(pred_pts) == 0: - return pred_pts, color - - # Statistical Outlier Removal (SOR) - if hasattr(self, 'sor_checkbox') and self.sor_checkbox.value and len(pred_pts) > 0: - pred_pts, color = self._statistical_outlier_removal( - pred_pts, color, - nb_neighbors=int(self.sor_neighbors_slider.value), - std_ratio=self.sor_std_slider.value, - ) - # Downsample if downsample_factor > 1 and len(pred_pts) > 0: indices = np.arange(0, len(pred_pts), downsample_factor) @@ -1383,49 +1176,6 @@ class PointCloudViewer: return pred_pts, color - @staticmethod - def _statistical_outlier_removal( - points: np.ndarray, - colors: np.ndarray, - nb_neighbors: int = 20, - std_ratio: float = 2.0, - ) -> Tuple[np.ndarray, np.ndarray]: - """Remove statistical outliers based on mean distance to k-nearest neighbors. - - Args: - points: (N, 3) point positions. - colors: (N, 3) point colors. - nb_neighbors: Number of nearest neighbors to consider. - std_ratio: Standard deviation multiplier for the distance threshold. - - Returns: - Filtered (points, colors) tuple. - """ - if len(points) <= nb_neighbors: - return points, colors - - try: - from scipy.spatial import cKDTree - except ImportError: - # Fallback: skip SOR if scipy not available - return points, colors - - # Subsample for KD-tree if too many points (speed) - max_pts_for_tree = 200_000 - if len(points) > max_pts_for_tree: - subsample_idx = np.random.choice(len(points), max_pts_for_tree, replace=False) - tree = cKDTree(points[subsample_idx]) - else: - tree = cKDTree(points) - - dists, _ = tree.query(points, k=nb_neighbors + 1) # +1 because first is self - mean_dists = dists[:, 1:].mean(axis=1) # exclude self - - threshold = mean_dists.mean() + std_ratio * mean_dists.std() - inlier_mask = mean_dists < threshold - - return points[inlier_mask], colors[inlier_mask] - def add_pc(self, step): """Add point cloud for a frame.""" pc = self.pcs[step]["pc"]