remove sky mask
This commit is contained in:
@@ -90,6 +90,8 @@ class PointCloudViewer:
|
||||
use_point_map: bool = False,
|
||||
mask_sky: bool = False,
|
||||
image_folder: Optional[str] = None,
|
||||
sky_mask_dir: Optional[str] = None,
|
||||
sky_mask_visualization_dir: Optional[str] = None,
|
||||
depth_stride: int = 1,
|
||||
):
|
||||
self.model = model
|
||||
@@ -107,6 +109,8 @@ class PointCloudViewer:
|
||||
if pred_dict is not None:
|
||||
pc_list, color_list, conf_list, cam_dict = self._process_pred_dict(
|
||||
pred_dict, use_point_map, mask_sky, image_folder,
|
||||
sky_mask_dir=sky_mask_dir,
|
||||
sky_mask_visualization_dir=sky_mask_visualization_dir,
|
||||
depth_stride=depth_stride,
|
||||
)
|
||||
else:
|
||||
@@ -138,6 +142,8 @@ class PointCloudViewer:
|
||||
use_point_map: bool,
|
||||
mask_sky: bool,
|
||||
image_folder: Optional[str],
|
||||
sky_mask_dir: Optional[str] = None,
|
||||
sky_mask_visualization_dir: Optional[str] = None,
|
||||
depth_stride: int = 1,
|
||||
) -> Tuple[List, List, List, Dict]:
|
||||
"""Process prediction dictionary to extract visualization data.
|
||||
@@ -147,6 +153,8 @@ class PointCloudViewer:
|
||||
use_point_map: Use point map instead of depth-based projection.
|
||||
mask_sky: Apply sky segmentation to filter sky points.
|
||||
image_folder: Path to images for sky segmentation.
|
||||
sky_mask_dir: Directory for cached sky masks.
|
||||
sky_mask_visualization_dir: Directory for sky mask visualization images.
|
||||
depth_stride: Only project depth to point cloud every N frames.
|
||||
Frames not projected will have empty point clouds but still
|
||||
show camera frustums and images. 1 = every frame (default).
|
||||
@@ -169,7 +177,11 @@ class PointCloudViewer:
|
||||
|
||||
# Apply sky segmentation if enabled
|
||||
if mask_sky:
|
||||
conf = apply_sky_segmentation(conf, image_folder=image_folder, images=images)
|
||||
conf = apply_sky_segmentation(
|
||||
conf, image_folder=image_folder, images=images,
|
||||
sky_mask_dir=sky_mask_dir,
|
||||
sky_mask_visualization_dir=sky_mask_visualization_dir,
|
||||
)
|
||||
|
||||
# Convert images from (S, 3, H, W) to (S, H, W, 3)
|
||||
colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3)
|
||||
@@ -404,7 +416,8 @@ class PointCloudViewer:
|
||||
"Show Camera", initial_value=self.show_camera
|
||||
)
|
||||
self.vis_threshold_slider = self.server.gui.add_slider(
|
||||
"Visibility Threshold", min=0.1, max=30.0, step=0.1, initial_value=self.vis_threshold
|
||||
"Visibility Threshold", min=1.0, max=5.0, step=0.01,
|
||||
initial_value=self.vis_threshold,
|
||||
)
|
||||
self.camera_downsample_slider = self.server.gui.add_slider(
|
||||
"Camera Downsample Factor", min=1, max=50, step=1, initial_value=1
|
||||
@@ -412,11 +425,6 @@ class PointCloudViewer:
|
||||
|
||||
# Point cloud filtering controls
|
||||
with self.server.gui.add_folder("Point Cloud Filtering"):
|
||||
self.conf_percentile_slider = self.server.gui.add_slider(
|
||||
"Confidence Percentile (%)",
|
||||
min=0, max=95, step=1, initial_value=0,
|
||||
hint="Remove the lowest N% of points by confidence. 0 = disabled.",
|
||||
)
|
||||
self.bbox_clip_slider = self.server.gui.add_slider(
|
||||
"Bounding Box Keep (%)",
|
||||
min=50.0, max=100.0, step=0.5, initial_value=100.0,
|
||||
@@ -1346,20 +1354,6 @@ class PointCloudViewer:
|
||||
if len(pred_pts) == 0:
|
||||
return pred_pts, color
|
||||
|
||||
# Confidence percentile filter
|
||||
if conf is not None and hasattr(self, 'conf_percentile_slider'):
|
||||
pct = self.conf_percentile_slider.value
|
||||
if pct > 0:
|
||||
conf_remaining = conf_flat[mask] if conf is not None else None
|
||||
if conf_remaining is not None and len(conf_remaining) > 0:
|
||||
threshold = np.percentile(conf_remaining, pct)
|
||||
pct_mask = conf_remaining >= threshold
|
||||
pred_pts = pred_pts[pct_mask]
|
||||
color = color[pct_mask]
|
||||
|
||||
if len(pred_pts) == 0:
|
||||
return pred_pts, color
|
||||
|
||||
# Bounding box clip: remove points far from the scene center
|
||||
if hasattr(self, 'bbox_clip_slider'):
|
||||
clip_pct = self.bbox_clip_slider.value
|
||||
@@ -1450,7 +1444,7 @@ class PointCloudViewer:
|
||||
name=f"/frames/{step}/pred_pts",
|
||||
points=pred_pts,
|
||||
colors=color,
|
||||
point_size=0.005,
|
||||
point_size=self.psize_slider.value,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1509,8 +1503,8 @@ class PointCloudViewer:
|
||||
)
|
||||
gui_next_frame = self.server.gui.add_button("Next Step", disabled=False)
|
||||
gui_prev_frame = self.server.gui.add_button("Prev Step", disabled=False)
|
||||
gui_playing = self.server.gui.add_checkbox("Playing", False)
|
||||
gui_framerate = self.server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=1)
|
||||
gui_playing = self.server.gui.add_checkbox("Playing", True)
|
||||
gui_framerate = self.server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=20)
|
||||
gui_framerate_options = self.server.gui.add_button_group("FPS options", ("10", "20", "30", "60"))
|
||||
|
||||
@gui_next_frame.on_click
|
||||
|
||||
@@ -32,27 +32,16 @@ def _get_cache_version_path(sky_mask_dir: str) -> str:
|
||||
return os.path.join(sky_mask_dir, ".skyseg_cache_version")
|
||||
|
||||
|
||||
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool:
|
||||
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> None:
|
||||
"""Ensure the sky mask cache directory exists and write the version stamp."""
|
||||
if sky_mask_dir is None:
|
||||
return False
|
||||
|
||||
return
|
||||
os.makedirs(sky_mask_dir, exist_ok=True)
|
||||
version_path = _get_cache_version_path(sky_mask_dir)
|
||||
refresh_cache = True
|
||||
if os.path.exists(version_path):
|
||||
with open(version_path, "r", encoding="utf-8") as f:
|
||||
refresh_cache = f.read().strip() != _SKYSEG_CACHE_VERSION
|
||||
|
||||
if refresh_cache:
|
||||
print(
|
||||
f"Sky mask cache at {sky_mask_dir} uses an older format; "
|
||||
"regenerating masks with ImageNet-normalized skyseg input"
|
||||
)
|
||||
if not os.path.exists(version_path):
|
||||
with open(version_path, "w", encoding="utf-8") as f:
|
||||
f.write(_SKYSEG_CACHE_VERSION)
|
||||
|
||||
return refresh_cache
|
||||
|
||||
|
||||
def run_skyseg(
|
||||
onnx_session,
|
||||
@@ -279,7 +268,7 @@ def load_or_create_sky_masks(
|
||||
|
||||
if sky_mask_dir is None and image_folder is not None:
|
||||
sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
|
||||
refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
|
||||
_prepare_sky_mask_cache(sky_mask_dir)
|
||||
|
||||
print("Generating sky masks from image array...")
|
||||
for i in tqdm(range(num_images)):
|
||||
@@ -288,17 +277,12 @@ def load_or_create_sky_masks(
|
||||
image_name = _get_mask_filename(image_paths, i)
|
||||
mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None
|
||||
|
||||
if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath):
|
||||
if mask_filepath is not None and os.path.exists(mask_filepath):
|
||||
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
|
||||
if sky_mask is None:
|
||||
print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
|
||||
sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
|
||||
cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
|
||||
elif sky_mask.shape[:2] != (image_h, image_w):
|
||||
print(
|
||||
f"Cached sky mask shape {sky_mask.shape[:2]} does not match resized image "
|
||||
f"shape {(image_h, image_w)} for {image_name}; regenerating it"
|
||||
)
|
||||
if sky_mask is not None and sky_mask.shape[:2] == (image_h, image_w):
|
||||
# Reuse cached mask
|
||||
pass
|
||||
else:
|
||||
sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
|
||||
cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
|
||||
else:
|
||||
@@ -338,14 +322,14 @@ def load_or_create_sky_masks(
|
||||
if image_folder is None:
|
||||
image_folder = os.path.dirname(image_paths[0])
|
||||
sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
|
||||
refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
|
||||
_prepare_sky_mask_cache(sky_mask_dir)
|
||||
|
||||
print("Generating sky masks from image files...")
|
||||
for image_path in tqdm(image_paths):
|
||||
image_name = os.path.basename(image_path)
|
||||
mask_filepath = os.path.join(sky_mask_dir, image_name)
|
||||
|
||||
if not refresh_cache and os.path.exists(mask_filepath):
|
||||
if os.path.exists(mask_filepath):
|
||||
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
|
||||
if sky_mask is None:
|
||||
print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
|
||||
|
||||
Reference in New Issue
Block a user