remove sky mask

This commit is contained in:
LinZhuoChen
2026-04-16 18:53:54 +08:00
parent c7e49e1cbe
commit 1317fbb7b3
648 changed files with 30 additions and 54 deletions

View File

@@ -90,6 +90,8 @@ class PointCloudViewer:
use_point_map: bool = False,
mask_sky: bool = False,
image_folder: Optional[str] = None,
sky_mask_dir: Optional[str] = None,
sky_mask_visualization_dir: Optional[str] = None,
depth_stride: int = 1,
):
self.model = model
@@ -107,6 +109,8 @@ class PointCloudViewer:
if pred_dict is not None:
pc_list, color_list, conf_list, cam_dict = self._process_pred_dict(
pred_dict, use_point_map, mask_sky, image_folder,
sky_mask_dir=sky_mask_dir,
sky_mask_visualization_dir=sky_mask_visualization_dir,
depth_stride=depth_stride,
)
else:
@@ -138,6 +142,8 @@ class PointCloudViewer:
use_point_map: bool,
mask_sky: bool,
image_folder: Optional[str],
sky_mask_dir: Optional[str] = None,
sky_mask_visualization_dir: Optional[str] = None,
depth_stride: int = 1,
) -> Tuple[List, List, List, Dict]:
"""Process prediction dictionary to extract visualization data.
@@ -147,6 +153,8 @@ class PointCloudViewer:
use_point_map: Use point map instead of depth-based projection.
mask_sky: Apply sky segmentation to filter sky points.
image_folder: Path to images for sky segmentation.
sky_mask_dir: Directory for cached sky masks.
sky_mask_visualization_dir: Directory for sky mask visualization images.
depth_stride: Only project depth to point cloud every N frames.
Frames not projected will have empty point clouds but still
show camera frustums and images. 1 = every frame (default).
@@ -169,7 +177,11 @@ class PointCloudViewer:
# Apply sky segmentation if enabled
if mask_sky:
conf = apply_sky_segmentation(conf, image_folder=image_folder, images=images)
conf = apply_sky_segmentation(
conf, image_folder=image_folder, images=images,
sky_mask_dir=sky_mask_dir,
sky_mask_visualization_dir=sky_mask_visualization_dir,
)
# Convert images from (S, 3, H, W) to (S, H, W, 3)
colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3)
@@ -404,7 +416,8 @@ class PointCloudViewer:
"Show Camera", initial_value=self.show_camera
)
self.vis_threshold_slider = self.server.gui.add_slider(
"Visibility Threshold", min=0.1, max=30.0, step=0.1, initial_value=self.vis_threshold
"Visibility Threshold", min=1.0, max=5.0, step=0.01,
initial_value=self.vis_threshold,
)
self.camera_downsample_slider = self.server.gui.add_slider(
"Camera Downsample Factor", min=1, max=50, step=1, initial_value=1
@@ -412,11 +425,6 @@ class PointCloudViewer:
# Point cloud filtering controls
with self.server.gui.add_folder("Point Cloud Filtering"):
self.conf_percentile_slider = self.server.gui.add_slider(
"Confidence Percentile (%)",
min=0, max=95, step=1, initial_value=0,
hint="Remove the lowest N% of points by confidence. 0 = disabled.",
)
self.bbox_clip_slider = self.server.gui.add_slider(
"Bounding Box Keep (%)",
min=50.0, max=100.0, step=0.5, initial_value=100.0,
@@ -1346,20 +1354,6 @@ class PointCloudViewer:
if len(pred_pts) == 0:
return pred_pts, color
# Confidence percentile filter
if conf is not None and hasattr(self, 'conf_percentile_slider'):
pct = self.conf_percentile_slider.value
if pct > 0:
conf_remaining = conf_flat[mask] if conf is not None else None
if conf_remaining is not None and len(conf_remaining) > 0:
threshold = np.percentile(conf_remaining, pct)
pct_mask = conf_remaining >= threshold
pred_pts = pred_pts[pct_mask]
color = color[pct_mask]
if len(pred_pts) == 0:
return pred_pts, color
# Bounding box clip: remove points far from the scene center
if hasattr(self, 'bbox_clip_slider'):
clip_pct = self.bbox_clip_slider.value
@@ -1450,7 +1444,7 @@ class PointCloudViewer:
name=f"/frames/{step}/pred_pts",
points=pred_pts,
colors=color,
point_size=0.005,
point_size=self.psize_slider.value,
)
)
@@ -1509,8 +1503,8 @@ class PointCloudViewer:
)
gui_next_frame = self.server.gui.add_button("Next Step", disabled=False)
gui_prev_frame = self.server.gui.add_button("Prev Step", disabled=False)
gui_playing = self.server.gui.add_checkbox("Playing", False)
gui_framerate = self.server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=1)
gui_playing = self.server.gui.add_checkbox("Playing", True)
gui_framerate = self.server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=20)
gui_framerate_options = self.server.gui.add_button_group("FPS options", ("10", "20", "30", "60"))
@gui_next_frame.on_click

View File

@@ -32,27 +32,16 @@ def _get_cache_version_path(sky_mask_dir: str) -> str:
return os.path.join(sky_mask_dir, ".skyseg_cache_version")
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool:
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> None:
"""Ensure the sky mask cache directory exists and write the version stamp."""
if sky_mask_dir is None:
return False
return
os.makedirs(sky_mask_dir, exist_ok=True)
version_path = _get_cache_version_path(sky_mask_dir)
refresh_cache = True
if os.path.exists(version_path):
with open(version_path, "r", encoding="utf-8") as f:
refresh_cache = f.read().strip() != _SKYSEG_CACHE_VERSION
if refresh_cache:
print(
f"Sky mask cache at {sky_mask_dir} uses an older format; "
"regenerating masks with ImageNet-normalized skyseg input"
)
if not os.path.exists(version_path):
with open(version_path, "w", encoding="utf-8") as f:
f.write(_SKYSEG_CACHE_VERSION)
return refresh_cache
def run_skyseg(
onnx_session,
@@ -279,7 +268,7 @@ def load_or_create_sky_masks(
if sky_mask_dir is None and image_folder is not None:
sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
_prepare_sky_mask_cache(sky_mask_dir)
print("Generating sky masks from image array...")
for i in tqdm(range(num_images)):
@@ -288,17 +277,12 @@ def load_or_create_sky_masks(
image_name = _get_mask_filename(image_paths, i)
mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None
if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath):
if mask_filepath is not None and os.path.exists(mask_filepath):
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
if sky_mask is None:
print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
elif sky_mask.shape[:2] != (image_h, image_w):
print(
f"Cached sky mask shape {sky_mask.shape[:2]} does not match resized image "
f"shape {(image_h, image_w)} for {image_name}; regenerating it"
)
if sky_mask is not None and sky_mask.shape[:2] == (image_h, image_w):
# Reuse cached mask
pass
else:
sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
else:
@@ -338,14 +322,14 @@ def load_or_create_sky_masks(
if image_folder is None:
image_folder = os.path.dirname(image_paths[0])
sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
_prepare_sky_mask_cache(sky_mask_dir)
print("Generating sky masks from image files...")
for image_path in tqdm(image_paths):
image_name = os.path.basename(image_path)
mask_filepath = os.path.join(sky_mask_dir, image_name)
if not refresh_cache and os.path.exists(mask_filepath):
if os.path.exists(mask_filepath):
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
if sky_mask is None:
print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")