Update demo GPU usage and fix some bugs

This commit is contained in:
LinZhuoChen
2026-04-17 17:56:17 +08:00
parent 6242dc5218
commit f307fdba68
2 changed files with 41 additions and 253 deletions

44
demo.py
View File

@@ -247,7 +247,7 @@ def main():
# Streaming options
parser.add_argument("--enable_3d_rope", action="store_true", default=True)
parser.add_argument("--max_frame_num", type=int, default=1024)
parser.add_argument("--num_scale_frames", type=int, default=8)
parser.add_argument("--num_scale_frames", type=int, default=4)
parser.add_argument(
"--keyframe_interval",
type=int,
@@ -258,6 +258,13 @@ def main():
parser.add_argument("--kv_cache_scale_frames", type=int, default=8)
parser.add_argument("--use_sdpa", action="store_true", default=False,
help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
parser.add_argument(
"--offload_to_cpu",
action=argparse.BooleanOptionalAction,
default=True,
help="Offload per-frame predictions to CPU during inference to cut GPU peak memory. "
"Use --no-offload_to_cpu to keep outputs on GPU.",
)
# Windowed options
parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)")
@@ -306,10 +313,24 @@ def main():
model = load_model(args, device)
print(f"Total load time: {time.time() - t0:.1f}s")
# Keep model in its loaded dtype — autocast handles bf16/fp16 for the ops
# that benefit from it and keeps LayerNorm / reductions in fp32.
if torch.cuda.is_available():
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
dtype = torch.float32
images = images.to(device)
num_frames = images.shape[0]
print(f"Input: {num_frames} frames, shape {tuple(images.shape)}")
print(f"Mode: {args.mode}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(
f"GPU mem after load: "
f"alloc={torch.cuda.memory_allocated()/1e9:.2f} GB, "
f"reserved={torch.cuda.memory_reserved()/1e9:.2f} GB"
)
if args.mode != "streaming" and args.keyframe_interval != 1:
print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.")
@@ -321,16 +342,18 @@ def main():
)
# ── Inference ────────────────────────────────────────────────────────────
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
print(f"Running {args.mode} inference (dtype={dtype})...")
t0 = time.time()
output_device = torch.device("cpu") if args.offload_to_cpu else None
with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
if args.mode == "streaming":
predictions = model.inference_streaming(
images,
num_scale_frames=args.num_scale_frames,
keyframe_interval=args.keyframe_interval,
output_device=output_device,
)
else: # windowed
predictions = model.inference_windowed(
@@ -338,12 +361,27 @@ def main():
window_size=args.window_size,
overlap_size=args.overlap_size,
num_scale_frames=args.num_scale_frames,
output_device=output_device,
)
print(f"Inference done in {time.time() - t0:.1f}s")
if torch.cuda.is_available():
print(
f"GPU peak during inference: "
f"{torch.cuda.max_memory_allocated()/1e9:.2f} GB "
f"(reserved peak {torch.cuda.max_memory_reserved()/1e9:.2f} GB)"
)
# ── Post-process ─────────────────────────────────────────────────────────
predictions, images_cpu = postprocess(predictions, images)
if args.offload_to_cpu:
del images
if torch.cuda.is_available():
torch.cuda.empty_cache()
images_for_post = predictions["images"] # already CPU
else:
images_for_post = images
predictions, images_cpu = postprocess(predictions, images_for_post)
# ── Visualize ────────────────────────────────────────────────────────────
try:

View File

@@ -115,10 +115,6 @@ class PointCloudViewer:
)
else:
self.original_images = []
self.tsdf_depth_maps = None
self.tsdf_extrinsics = None
self.tsdf_intrinsics = None
self.tsdf_images = None
self.pcs, self.all_steps = self.read_data(
pc_list, color_list, conf_list, edge_color_list
@@ -187,12 +183,6 @@ class PointCloudViewer:
colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3)
S = world_points.shape[0]
# Store raw data for TSDF fusion
self.tsdf_depth_maps = depth_map # (S, H, W, 1)
self.tsdf_extrinsics = extrinsics_cam # (S, 3, 4) camera-from-world
self.tsdf_intrinsics = intrinsics_cam # (S, 3, 3)
self.tsdf_images = images # (S, 3, H, W)
# Store original images for camera frustum display
self.original_images = []
for i in range(S):
@@ -423,88 +413,6 @@ class PointCloudViewer:
"Camera Downsample Factor", min=1, max=50, step=1, initial_value=1
)
# Point cloud filtering controls
with self.server.gui.add_folder("Point Cloud Filtering"):
self.bbox_clip_slider = self.server.gui.add_slider(
"Bounding Box Keep (%)",
min=50.0, max=100.0, step=0.5, initial_value=100.0,
hint="Keep the central N% of points per axis. 100 = no clipping.",
)
self.sor_checkbox = self.server.gui.add_checkbox(
"Statistical Outlier Removal",
initial_value=False,
hint="Remove isolated floating points based on KNN distance.",
)
self.sor_neighbors_slider = self.server.gui.add_slider(
"SOR Neighbors (K)",
min=5, max=50, step=1, initial_value=20, disabled=True,
hint="Number of nearest neighbors for outlier detection.",
)
self.sor_std_slider = self.server.gui.add_slider(
"SOR Std Ratio",
min=0.5, max=5.0, step=0.1, initial_value=2.0, disabled=True,
hint="Lower = more aggressive filtering. Points beyond mean + ratio*std are removed.",
)
self.filter_apply_button = self.server.gui.add_button(
"Apply Filters",
hint="Regenerate point clouds with current filter settings.",
)
@self.sor_checkbox.on_update
def _(_) -> None:
self.sor_neighbors_slider.disabled = not self.sor_checkbox.value
self.sor_std_slider.disabled = not self.sor_checkbox.value
@self.filter_apply_button.on_click
def _(_) -> None:
self._regenerate_point_clouds()
# TSDF Fusion controls
with self.server.gui.add_folder("TSDF Fusion"):
self.tsdf_voxel_size_slider = self.server.gui.add_slider(
"Voxel Size", min=0.001, max=0.1, step=0.001, initial_value=0.01,
hint="TSDF voxel size. Smaller = finer detail but slower.",
)
self.tsdf_sdf_trunc_slider = self.server.gui.add_slider(
"SDF Truncation", min=0.01, max=0.5, step=0.01, initial_value=0.04,
hint="Truncation distance. Typically 3-5x voxel size.",
)
self.tsdf_depth_scale_slider = self.server.gui.add_slider(
"Depth Scale", min=1.0, max=10000.0, step=1.0, initial_value=1.0,
hint="Depth scale factor. 1.0 if depth is in meters.",
)
self.tsdf_depth_trunc_slider = self.server.gui.add_slider(
"Depth Truncation", min=0.5, max=50.0, step=0.5, initial_value=5.0,
hint="Max depth value to integrate (meters).",
)
self.tsdf_run_button = self.server.gui.add_button(
"Run TSDF Fusion",
hint="Fuse all frames into a single point cloud via TSDF.",
)
self.tsdf_clear_button = self.server.gui.add_button(
"Clear TSDF Result",
hint="Remove the TSDF fused point cloud from the scene.",
)
self.tsdf_status = self.server.gui.add_text(
"Status", initial_value="Ready",
)
self._tsdf_handle = None
@self.tsdf_run_button.on_click
def _(_) -> None:
self._run_tsdf_fusion()
@self.tsdf_clear_button.on_click
def _(_) -> None:
if self._tsdf_handle is not None:
try:
self._tsdf_handle.remove()
except (KeyError, AttributeError):
pass
self._tsdf_handle = None
self.tsdf_status.value = "Cleared"
# Range visualization controls
with self.server.gui.add_folder("Frame Range Control"):
self.range_mode_checkbox = self.server.gui.add_checkbox("Range Mode", initial_value=False)
@@ -781,100 +689,6 @@ class PointCloudViewer:
if i % downsample_factor == 0:
self.add_camera(step)
def _run_tsdf_fusion(self):
    """Run TSDF fusion on all frames and display result as a point cloud.

    Reads the per-frame data cached on ``self`` (``tsdf_depth_maps``,
    ``tsdf_images``, ``tsdf_extrinsics``, ``tsdf_intrinsics``), integrates
    every frame into an Open3D ``ScalableTSDFVolume`` using the current GUI
    slider values, extracts a fused point cloud, and replaces any previous
    fused cloud in the scene. Status and errors are reported through
    ``self.tsdf_status``; returns None in all cases.
    """
    # Guard: the raw TSDF inputs are only populated when the viewer was
    # constructed with prediction data (see read_data), so bail out early.
    if not hasattr(self, 'tsdf_depth_maps') or self.tsdf_depth_maps is None:
        self.tsdf_status.value = "Error: no depth data (need pred_dict)"
        return
    # open3d is an optional dependency; fail soft with an install hint
    # rather than crashing the viewer.
    try:
        import open3d as o3d
    except ImportError:
        self.tsdf_status.value = "Error: pip install open3d"
        return
    self.tsdf_status.value = "Running TSDF fusion..."
    print("Starting TSDF fusion...")
    # Snapshot slider values once so dragging a slider mid-run cannot
    # change parameters between frames.
    voxel_size = self.tsdf_voxel_size_slider.value
    sdf_trunc = self.tsdf_sdf_trunc_slider.value
    depth_scale = self.tsdf_depth_scale_slider.value
    depth_trunc = self.tsdf_depth_trunc_slider.value
    volume = o3d.pipelines.integration.ScalableTSDFVolume(
        voxel_length=voxel_size,
        sdf_trunc=sdf_trunc,
        color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8,
    )
    S = self.tsdf_depth_maps.shape[0]
    H, W = self.tsdf_depth_maps.shape[1], self.tsdf_depth_maps.shape[2]
    for i in tqdm(range(S), desc="TSDF integrating"):
        # Depth: (H, W, 1) -> (H, W)
        depth = self.tsdf_depth_maps[i]
        if depth.ndim == 3:
            depth = depth[..., 0]
        # Color: (3, H, W) -> (H, W, 3), uint8 in [0, 255]
        color = self.tsdf_images[i].transpose(1, 2, 0)  # (H, W, 3)
        color = (np.clip(color, 0, 1) * 255).astype(np.uint8)
        # Camera extrinsic: (3, 4) -> (4, 4) camera-from-world
        extr_34 = self.tsdf_extrinsics[i]
        extr_44 = np.eye(4, dtype=np.float64)
        extr_44[:3, :] = extr_34
        intrinsic = o3d.camera.PinholeCameraIntrinsic(
            width=W, height=H,
            fx=float(self.tsdf_intrinsics[i, 0, 0]),
            fy=float(self.tsdf_intrinsics[i, 1, 1]),
            cx=float(self.tsdf_intrinsics[i, 0, 2]),
            cy=float(self.tsdf_intrinsics[i, 1, 2]),
        )
        # NOTE(review): depth is multiplied by depth_scale here, and
        # create_from_color_and_depth below divides by the same
        # depth_scale, so the two cancel — presumably intentional to keep
        # depth in meters regardless of the slider; confirm.
        depth_o3d = o3d.geometry.Image(
            (depth.astype(np.float32) * depth_scale).astype(np.float32)
        )
        color_o3d = o3d.geometry.Image(np.ascontiguousarray(color))
        rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
            color_o3d, depth_o3d,
            depth_scale=depth_scale,
            depth_trunc=depth_trunc,
            convert_rgb_to_intensity=False,
        )
        volume.integrate(rgbd, intrinsic, extr_44)
    print("Extracting point cloud from TSDF volume...")
    pcd = volume.extract_point_cloud()
    points = np.asarray(pcd.points, dtype=np.float32)
    colors = np.asarray(pcd.colors, dtype=np.float32)  # already 0-1
    if len(points) == 0:
        self.tsdf_status.value = "Error: empty result, try adjusting parameters"
        print("TSDF fusion produced 0 points.")
        return
    # Remove the previous fused cloud (handle may already be gone if the
    # client disconnected, hence the tolerant except).
    if self._tsdf_handle is not None:
        try:
            self._tsdf_handle.remove()
        except (KeyError, AttributeError):
            pass
    self._tsdf_handle = self.server.scene.add_point_cloud(
        name="/tsdf_fusion",
        points=points,
        colors=colors,
        point_size=self.psize_slider.value,
    )
    self.tsdf_status.value = f"Done: {len(points):,} points"
    print(f"TSDF fusion complete: {len(points):,} points")
def _export_glb(self):
"""Export current filtered point clouds and cameras as a GLB file."""
try:
@@ -1354,27 +1168,6 @@ class PointCloudViewer:
if len(pred_pts) == 0:
return pred_pts, color
# Bounding box clip: remove points far from the scene center
if hasattr(self, 'bbox_clip_slider'):
clip_pct = self.bbox_clip_slider.value
if clip_pct < 100.0:
lo = np.percentile(pred_pts, (100.0 - clip_pct) / 2, axis=0)
hi = np.percentile(pred_pts, 100.0 - (100.0 - clip_pct) / 2, axis=0)
bbox_mask = np.all((pred_pts >= lo) & (pred_pts <= hi), axis=1)
pred_pts = pred_pts[bbox_mask]
color = color[bbox_mask]
if len(pred_pts) == 0:
return pred_pts, color
# Statistical Outlier Removal (SOR)
if hasattr(self, 'sor_checkbox') and self.sor_checkbox.value and len(pred_pts) > 0:
pred_pts, color = self._statistical_outlier_removal(
pred_pts, color,
nb_neighbors=int(self.sor_neighbors_slider.value),
std_ratio=self.sor_std_slider.value,
)
# Downsample
if downsample_factor > 1 and len(pred_pts) > 0:
indices = np.arange(0, len(pred_pts), downsample_factor)
@@ -1383,49 +1176,6 @@ class PointCloudViewer:
return pred_pts, color
@staticmethod
def _statistical_outlier_removal(
points: np.ndarray,
colors: np.ndarray,
nb_neighbors: int = 20,
std_ratio: float = 2.0,
) -> Tuple[np.ndarray, np.ndarray]:
"""Remove statistical outliers based on mean distance to k-nearest neighbors.
Args:
points: (N, 3) point positions.
colors: (N, 3) point colors.
nb_neighbors: Number of nearest neighbors to consider.
std_ratio: Standard deviation multiplier for the distance threshold.
Returns:
Filtered (points, colors) tuple.
"""
if len(points) <= nb_neighbors:
return points, colors
try:
from scipy.spatial import cKDTree
except ImportError:
# Fallback: skip SOR if scipy not available
return points, colors
# Subsample for KD-tree if too many points (speed)
max_pts_for_tree = 200_000
if len(points) > max_pts_for_tree:
subsample_idx = np.random.choice(len(points), max_pts_for_tree, replace=False)
tree = cKDTree(points[subsample_idx])
else:
tree = cKDTree(points)
dists, _ = tree.query(points, k=nb_neighbors + 1) # +1 because first is self
mean_dists = dists[:, 1:].mean(axis=1) # exclude self
threshold = mean_dists.mean() + std_ratio * mean_dists.std()
inlier_mask = mean_dists < threshold
return points[inlier_mask], colors[inlier_mask]
def add_pc(self, step):
"""Add point cloud for a frame."""
pc = self.pcs[step]["pc"]