first commit

This commit is contained in:
LinZhuoChen
2026-04-16 09:51:30 +08:00
commit f9b3ae457a
44 changed files with 11994 additions and 0 deletions

View File

@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
GCT Visualization Module
This module provides visualization utilities for 3D reconstruction results:
- PointCloudViewer: Interactive point cloud viewer with camera visualization
- viser_wrapper: Quick visualization wrapper for predictions
- predictions_to_glb: Export predictions to GLB 3D format
- Colorization and utility functions
Usage:
from lingbot_map.vis import PointCloudViewer, viser_wrapper, predictions_to_glb
# Interactive visualization
viewer = PointCloudViewer(pred_dict=predictions, port=8080)
viewer.run()
# Quick visualization
viser_wrapper(predictions, port=8080)
# Export to GLB
scene = predictions_to_glb(predictions)
scene.export("output.glb")
"""
from lingbot_map.vis.point_cloud_viewer import PointCloudViewer
from lingbot_map.vis.viser_wrapper import viser_wrapper
from lingbot_map.vis.utils import CameraState, colorize, colorize_np, get_vertical_colorbar
from lingbot_map.vis.sky_segmentation import (
apply_sky_segmentation,
download_skyseg_model,
load_or_create_sky_masks,
segment_sky,
)
from lingbot_map.vis.glb_export import predictions_to_glb
__all__ = [
# Main viewer
"PointCloudViewer",
# Quick visualization
"viser_wrapper",
# GLB export
"predictions_to_glb",
# Utilities
"CameraState",
"colorize",
"colorize_np",
"get_vertical_colorbar",
# Sky segmentation
"apply_sky_segmentation",
"segment_sky",
"download_skyseg_model",
"load_or_create_sky_masks",
]

View File

@@ -0,0 +1,509 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
GLB 3D export utilities for GCT predictions.
"""
import os
import copy
from typing import Optional, Tuple
import numpy as np
import cv2
import matplotlib
from scipy.spatial.transform import Rotation
from lingbot_map.vis.sky_segmentation import (
_SKYSEG_INPUT_SIZE,
_SKYSEG_SOFT_THRESHOLD,
_mask_to_float,
_mask_to_uint8,
_result_map_to_non_sky_conf,
)
try:
import trimesh
except ImportError:
trimesh = None
print("trimesh not found. GLB export will not work.")
def predictions_to_glb(
    predictions: dict,
    conf_thres: float = 50.0,
    filter_by_frames: str = "all",
    mask_black_bg: bool = False,
    mask_white_bg: bool = False,
    show_cam: bool = True,
    mask_sky: bool = False,
    target_dir: Optional[str] = None,
    prediction_mode: str = "Predicted Pointmap",
) -> "trimesh.Scene":
    """
    Converts GCT predictions to a 3D scene represented as a GLB file.

    Args:
        predictions: Dictionary containing model predictions with keys:
            - world_points: 3D point coordinates (S, H, W, 3)
            - world_points_conf: Confidence scores (S, H, W)
            - images: Input images (S, H, W, 3) or (S, 3, H, W)
            - extrinsic: Camera extrinsic matrices (S, 3, 4)
        conf_thres: Percentage of low-confidence points to filter out
            (used as a percentile of the confidence distribution)
        filter_by_frames: Frame filter specification ("all" or frame index)
        mask_black_bg: Mask out black background pixels
        mask_white_bg: Mask out white background pixels
        show_cam: Include camera visualization
        mask_sky: Apply sky segmentation mask (only takes effect when
            target_dir is also provided)
        target_dir: Output directory for intermediate files
        prediction_mode: "Predicted Pointmap" or "Predicted Depthmap"

    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras

    Raises:
        ValueError: If input predictions structure is invalid
        ImportError: If trimesh is not available
    """
    if trimesh is None:
        raise ImportError("trimesh is required for GLB export. Install with: pip install trimesh")
    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")
    if conf_thres is None:
        # Mild default filter when the caller explicitly passes None.
        conf_thres = 10.0
    print("Building GLB scene")
    # Parse frame filter: "all"/"All" keeps everything; otherwise expect a
    # string whose leading "<index>:" selects a single frame.
    selected_frame_idx = None
    if filter_by_frames != "all" and filter_by_frames != "All":
        try:
            selected_frame_idx = int(filter_by_frames.split(":")[0])
        except (ValueError, IndexError):
            # Unparseable filter strings silently fall back to "all frames".
            pass
    # Select prediction source
    if "Pointmap" in prediction_mode:
        print("Using Pointmap Branch")
        if "world_points" in predictions:
            pred_world_points = predictions["world_points"]
            # Default to uniform confidence when none is provided.
            pred_world_points_conf = predictions.get(
                "world_points_conf", np.ones_like(pred_world_points[..., 0])
            )
        else:
            print("Warning: world_points not found, falling back to depth-based points")
            pred_world_points = predictions["world_points_from_depth"]
            pred_world_points_conf = predictions.get(
                "depth_conf", np.ones_like(pred_world_points[..., 0])
            )
    else:
        print("Using Depthmap and Camera Branch")
        pred_world_points = predictions["world_points_from_depth"]
        pred_world_points_conf = predictions.get(
            "depth_conf", np.ones_like(pred_world_points[..., 0])
        )
    images = predictions["images"]
    camera_matrices = predictions["extrinsic"]
    # Apply sky segmentation if enabled (zeroes confidence on sky pixels).
    if mask_sky and target_dir is not None:
        pred_world_points_conf = _apply_sky_mask(
            pred_world_points_conf, target_dir, images
        )
    # Apply frame filter; [None] keeps a leading frame axis of size 1 so the
    # code below can treat single-frame and multi-frame inputs uniformly.
    if selected_frame_idx is not None:
        pred_world_points = pred_world_points[selected_frame_idx][None]
        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        camera_matrices = camera_matrices[selected_frame_idx][None]
    # Prepare vertices and colors
    vertices_3d = pred_world_points.reshape(-1, 3)
    # Handle different image formats
    if images.ndim == 4 and images.shape[1] == 3:  # NCHW format
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:
        colors_rgb = images
    # NOTE(review): assumes image values are normalized to [0, 1] — confirm
    # with the model's preprocessing.
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
    # Apply confidence filtering: conf_thres is interpreted as a percentile
    # of the confidence distribution, not a raw threshold.
    conf = pred_world_points_conf.reshape(-1)
    conf_threshold = np.percentile(conf, conf_thres) if conf_thres > 0 else 0.0
    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
    # Apply background masking
    if mask_black_bg:
        # Keep pixels whose summed RGB is at least 16 (i.e. not near-black).
        black_bg_mask = colors_rgb.sum(axis=1) >= 16
        conf_mask = conf_mask & black_bg_mask
    if mask_white_bg:
        # Drop pixels where all three channels exceed 240 (near-white).
        white_bg_mask = ~(
            (colors_rgb[:, 0] > 240) &
            (colors_rgb[:, 1] > 240) &
            (colors_rgb[:, 2] > 240)
        )
        conf_mask = conf_mask & white_bg_mask
    vertices_3d = vertices_3d[conf_mask]
    colors_rgb = colors_rgb[conf_mask]
    # Handle empty point cloud: substitute a single dummy white vertex so the
    # downstream scene construction stays valid.
    if vertices_3d is None or np.asarray(vertices_3d).size == 0:
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1
    else:
        # Robust scene extent: diagonal of the 5th-95th percentile box.
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)
        scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
    # Build scene
    scene_3d = trimesh.Scene()
    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
    scene_3d.add_geometry(point_cloud_data)
    # Prepare camera matrices: promote (S, 3, 4) extrinsics to 4x4 homogeneous.
    num_cameras = len(camera_matrices)
    extrinsics_matrices = np.zeros((num_cameras, 4, 4))
    extrinsics_matrices[:, :3, :4] = camera_matrices
    extrinsics_matrices[:, 3, 3] = 1
    # Add cameras, one distinct colormap color per camera.
    if show_cam:
        for i in range(num_cameras):
            world_to_camera = extrinsics_matrices[i]
            camera_to_world = np.linalg.inv(world_to_camera)
            rgba_color = colormap(i / num_cameras)
            current_color = tuple(int(255 * x) for x in rgba_color[:3])
            integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
    # Align scene to the first camera's pose.
    scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
    print("GLB Scene built")
    return scene_3d
def _apply_sky_mask(
    conf: np.ndarray,
    target_dir: str,
    images: np.ndarray
) -> np.ndarray:
    """Apply sky segmentation mask to confidence scores.

    Args:
        conf: Confidence scores, expected shape (S, H, W).
        target_dir: Directory containing an "images" subfolder; generated
            masks are cached under "<target_dir>/sky_masks".
        images: Image array, used only as a shape fallback when conf has no
            .shape attribute.

    Returns:
        conf with sky pixels zeroed, or conf unchanged when onnxruntime or
        the images directory is unavailable.
    """
    try:
        import onnxruntime
    except ImportError:
        # Best-effort feature: degrade gracefully instead of failing export.
        print("Warning: onnxruntime not available, skipping sky masking")
        return conf
    target_dir_images = os.path.join(target_dir, "images")
    if not os.path.exists(target_dir_images):
        print(f"Warning: Images directory not found at {target_dir_images}")
        return conf
    image_list = sorted(os.listdir(target_dir_images))
    # NOTE(review): the fallback assumes images is (S, H, W, ...) — confirm
    # with callers that pass conf without a .shape attribute.
    S, H, W = conf.shape if hasattr(conf, "shape") else (len(images), images.shape[1], images.shape[2])
    skyseg_model_path = "skyseg.onnx"
    if not os.path.exists(skyseg_model_path):
        # Model weights are fetched once into the current working directory.
        print("Downloading skyseg.onnx...")
        download_file_from_url(
            "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
            skyseg_model_path
        )
    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_mask_list = []
    for i, image_name in enumerate(image_list[:S]):
        image_filepath = os.path.join(target_dir_images, image_name)
        mask_filepath = os.path.join(target_dir, "sky_masks", image_name)
        if os.path.exists(mask_filepath):
            # Reuse a cached mask (uint8 grayscale) when available.
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        else:
            sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)
        # Masks must match the confidence map resolution before stacking.
        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_LINEAR)
        sky_mask_list.append(_mask_to_float(sky_mask))
    sky_mask_array = np.array(sky_mask_list)
    # Binarize the soft masks; pixels at or below the threshold count as sky.
    # NOTE(review): if image_list yields fewer than S masks, the multiply
    # below will fail to broadcast against (S, H, W) — confirm the images
    # directory always covers all S frames.
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    return conf * sky_mask_binary
def integrate_camera_into_scene(
    scene: "trimesh.Scene",
    transform: np.ndarray,
    face_colors: Tuple[int, int, int],
    scene_scale: float,
    frustum_thickness: float = 1.0,
):
    """
    Integrates a camera mesh into the 3D scene.

    Args:
        scene: The 3D scene to add the camera model
        transform: Transformation matrix for camera positioning
            (camera-to-world, 4x4)
        face_colors: RGB color tuple for the camera
        scene_scale: Scale of the scene
        frustum_thickness: Multiplier for frustum edge thickness (>1 = thicker)
    """
    # Camera glyph size proportional to the overall scene extent.
    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1
    # Rotate the 4-sided cone 45 degrees about z so its edges line up like a
    # camera frustum, and shift it back so the apex sits at the camera origin.
    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    rot_45_degree[2, 3] = -cam_height
    opengl_transform = get_opengl_conversion_matrix()
    complete_transform = transform @ opengl_transform @ rot_45_degree
    # A 4-section cone serves as the frustum silhouette.
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
    # Build thicker frustum by stacking rotated copies
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
    shell_scales = [1.0, 0.95]
    shell_transforms = [np.eye(4), slight_rotation]
    # Add extra shells for thickness
    if frustum_thickness > 1.0:
        n_extra = max(1, int(frustum_thickness - 1))
        for k in range(1, n_extra + 1):
            # Progressively rotated and scaled copies; one pair per step,
            # rotated in opposite directions for symmetry.
            angle = 2.0 + k * 2.0
            scale = 1.0 + k * 0.02
            rot = np.eye(4)
            rot[:3, :3] = Rotation.from_euler("z", angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot)
            rot_neg = np.eye(4)
            rot_neg[:3, :3] = Rotation.from_euler("z", -angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot_neg)
    # Stack every shell's vertices, then move the combined set into world
    # space in a single pass.
    vertices_parts = []
    for s, t_mat in zip(shell_scales, shell_transforms):
        vertices_parts.append(
            transform_points(t_mat, s * camera_cone_shape.vertices)
        )
    vertices_combined = np.concatenate(vertices_parts)
    vertices_transformed = transform_points(complete_transform, vertices_combined)
    # Faces connect consecutive shells to give the edges visible thickness.
    mesh_faces = compute_camera_faces_multi(camera_cone_shape, len(shell_scales))
    camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)
def apply_scene_alignment(
    scene_3d: "trimesh.Scene",
    extrinsics_matrices: np.ndarray
) -> "trimesh.Scene":
    """
    Align the scene to the first camera's pose.

    The scene is transformed into the first camera's frame (OpenGL
    convention), then flipped 180 degrees about the y axis.

    Args:
        scene_3d: The 3D scene to be aligned
        extrinsics_matrices: Camera extrinsic matrices (world-to-camera, 4x4)

    Returns:
        Aligned 3D scene
    """
    first_cam_to_world = np.linalg.inv(extrinsics_matrices[0])
    flip_y = np.eye(4)
    flip_y[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()
    alignment = first_cam_to_world @ get_opengl_conversion_matrix() @ flip_y
    scene_3d.apply_transform(alignment)
    return scene_3d
def get_opengl_conversion_matrix() -> np.ndarray:
    """Return a 4x4 matrix converting to the OpenGL camera convention.

    The Y and Z axes are negated; X and the homogeneous row are unchanged.
    """
    return np.diag([1.0, -1.0, -1.0, 1.0])
def transform_points(
    transformation: np.ndarray,
    points: np.ndarray,
    dim: Optional[int] = None
) -> np.ndarray:
    """
    Apply a homogeneous 4x4 transformation to an array of points.

    Args:
        transformation: Transformation matrix (..., 4, 4)
        points: Points to be transformed (..., 3)
        dim: Number of output coordinates to keep (defaults to the input's
            last-axis size)

    Returns:
        Transformed points with the same leading shape as the input
    """
    pts = np.asarray(points)
    lead_shape = pts.shape[:-1]
    out_dim = dim or pts.shape[-1]
    # Split the transposed matrix into its linear part and translation row,
    # so the transform is a single matmul plus an add.
    mat_t = np.swapaxes(transformation, -1, -2)
    linear, translation = mat_t[..., :-1, :], mat_t[..., -1:, :]
    moved = pts @ linear + translation
    return moved[..., :out_dim].reshape(*lead_shape, out_dim)
def compute_camera_faces(cone_shape: "trimesh.Trimesh") -> np.ndarray:
    """Compute triangle faces connecting three stacked copies of a cone.

    Faces touching vertex 0 (the cone apex) are skipped; the remaining base
    faces are bridged to the copies offset by one and two vertex counts.
    Reversed windings are appended so faces render from both sides.
    """
    shell_size = len(cone_shape.vertices)
    triangles = []
    for tri in cone_shape.faces:
        if 0 in tri:
            # Skip faces that touch the apex vertex.
            continue
        a, b, c = tri
        a1, b1, c1 = tri + shell_size
        a2, b2, c2 = tri + 2 * shell_size
        triangles += [
            (a, b, b1),
            (a, a1, c),
            (c1, b, c),
            (a, b, b2),
            (a, a2, c),
            (c2, b, c),
        ]
    # Mirror every face so both windings exist.
    triangles += [(c, b, a) for a, b, c in triangles]
    return np.array(triangles)
def compute_camera_faces_multi(cone_shape: "trimesh.Trimesh", num_shells: int) -> np.ndarray:
    """Compute faces for a camera mesh built from several vertex shells.

    Each consecutive pair of shells is bridged with triangles to form the
    frustum edges; faces touching the apex (vertex 0) are skipped, and
    reversed windings are appended so faces render from both sides.
    """
    shell_size = len(cone_shape.vertices)
    triangles = []
    for shell in range(num_shells - 1):
        base = shell * shell_size
        nxt = (shell + 1) * shell_size
        for tri in cone_shape.faces:
            if 0 in tri:
                continue
            a, b, c = tri
            triangles += [
                (a + base, b + base, b + nxt),
                (a + base, a + nxt, c + base),
                (c + nxt, b + base, c + base),
            ]
    # Mirror every face so both windings exist.
    triangles += [(c, b, a) for a, b, c in triangles]
    return np.array(triangles)
def segment_sky(
    image_path: str,
    onnx_session,
    mask_filename: str
) -> np.ndarray:
    """
    Segments sky from an image using an ONNX model.

    Args:
        image_path: Path to input image
        onnx_session: ONNX runtime session with loaded model
        mask_filename: Path to save the output mask

    Returns:
        Continuous non-sky confidence map in [0, 1]

    Raises:
        ValueError: If the image cannot be read from image_path
    """
    image = cv2.imread(image_path)
    if image is None:
        # Fail fast with a clear message instead of crashing inside
        # run_skyseg (matches lingbot_map.vis.sky_segmentation.segment_sky).
        raise ValueError(f"Failed to read image: {image_path}")
    result_map = run_skyseg(onnx_session, _SKYSEG_INPUT_SIZE, image)
    # Resize the score map back to the source image resolution.
    result_map_original = cv2.resize(
        result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR
    )
    output_mask = _result_map_to_non_sky_conf(result_map_original)
    mask_dir = os.path.dirname(mask_filename)
    if mask_dir:
        # Guard: os.makedirs("") raises when mask_filename has no directory.
        os.makedirs(mask_dir, exist_ok=True)
    cv2.imwrite(mask_filename, _mask_to_uint8(output_mask))
    return output_mask
def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray
) -> np.ndarray:
    """
    Runs sky segmentation inference using ONNX model.

    Args:
        onnx_session: ONNX runtime session
        input_size: Target size for model input (width, height)
        image: Input image in BGR format

    Returns:
        Segmentation mask as uint8 in [0, 255] (higher = more sky-like)
    """
    # cv2.resize takes dsize=(W, H); no copy of the input is needed since
    # resize does not modify its source.
    resize_image = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    # ImageNet normalization.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    x = x.transpose(2, 0, 1)
    # NCHW layout: after resize the array is (H, W) = (input_size[1],
    # input_size[0]); the previous (input_size[0], input_size[1]) order was
    # only correct for square inputs.
    x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})
    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    # Guard against a constant output map (division by zero), matching
    # lingbot_map.vis.sky_segmentation.run_skyseg.
    denom = max(max_value - min_value, 1e-8)
    onnx_result = (onnx_result - min_value) / denom
    onnx_result *= 255
    return onnx_result.astype("uint8")
def download_file_from_url(url: str, filename: str):
    """Downloads a file from a URL, handling one manual redirect hop.

    Args:
        url: Source URL (a 302 redirect, e.g. from HuggingFace, is followed
            manually so the final content is streamed)
        filename: Local destination path
    """
    import requests
    try:
        response = requests.get(url, allow_redirects=False, stream=True)
        response.raise_for_status()
        if response.status_code == 302:
            # Follow the redirect manually and stream the real content.
            redirect_url = response.headers["Location"]
            response = requests.get(redirect_url, stream=True)
            response.raise_for_status()
        # A direct 200 response is also valid: previously it hit an
        # "unexpected status" branch and nothing was downloaded at all.
        with open(filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Downloaded {filename} successfully.")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,473 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Sky segmentation utilities for filtering sky points from point clouds.
"""
import glob
import os
from typing import Optional, Tuple
import numpy as np
import cv2
from tqdm.auto import tqdm
try:
import onnxruntime
except ImportError:
onnxruntime = None
print("onnxruntime not found. Sky segmentation may not work.")
_SKYSEG_INPUT_SIZE = (320, 320)
_SKYSEG_SOFT_THRESHOLD = 0.1
_SKYSEG_CACHE_VERSION = "imagenet_norm_softmap_inverted_v3"
def _get_cache_version_path(sky_mask_dir: str) -> str:
return os.path.join(sky_mask_dir, ".skyseg_cache_version")
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool:
    """Ensure the mask cache dir exists and is tagged with the current version.

    Returns True when cached masks must be regenerated (missing or stale
    version marker), False when the existing cache can be reused. A None
    directory disables caching and never forces a refresh.
    """
    if sky_mask_dir is None:
        return False
    os.makedirs(sky_mask_dir, exist_ok=True)
    version_path = _get_cache_version_path(sky_mask_dir)
    stale = True
    if os.path.exists(version_path):
        with open(version_path, "r", encoding="utf-8") as marker:
            stale = marker.read().strip() != _SKYSEG_CACHE_VERSION
    if stale:
        print(
            f"Sky mask cache at {sky_mask_dir} uses an older format; "
            "regenerating masks with ImageNet-normalized skyseg input"
        )
        with open(version_path, "w", encoding="utf-8") as marker:
            marker.write(_SKYSEG_CACHE_VERSION)
    return stale
def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray,
) -> np.ndarray:
    """
    Run ONNX sky segmentation on a BGR image and return an 8-bit score map.
    """
    resized = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32)
    # ImageNet normalization.
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    normalized = (rgb / 255.0 - imagenet_mean) / imagenet_std
    # NCHW layout; note (H, W) = (input_size[1], input_size[0]).
    batch = normalized.transpose(2, 0, 1)
    batch = batch.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    raw = np.array(onnx_session.run([output_name], {input_name: batch})).squeeze()
    # Min-max rescale to [0, 255]; the epsilon guards constant outputs.
    lo, hi = np.min(raw), np.max(raw)
    span = max(hi - lo, 1e-8)
    scaled = (raw - lo) / span * 255.0
    return scaled.astype(np.uint8)
def _mask_to_float(mask: np.ndarray) -> np.ndarray:
mask = mask.astype(np.float32)
if mask.size == 0:
return mask
return np.clip(mask, 0.0, 1.0)
def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
mask = np.asarray(mask)
if mask.dtype == np.uint8:
return mask
mask = mask.astype(np.float32)
if mask.size > 0 and mask.max() <= 1.0:
mask = mask * 255.0
return np.clip(mask, 0.0, 255.0).astype(np.uint8)
def _result_map_to_non_sky_conf(result_map: np.ndarray) -> np.ndarray:
    """Invert a skyseg score map into a non-sky confidence map.

    The raw skyseg map is higher on sky and lower on non-sky, so the
    non-sky confidence is its complement in [0, 1].
    """
    sky_score = _mask_to_float(result_map)
    return 1.0 - sky_score
def segment_sky_from_array(
    image: np.ndarray,
    skyseg_session,
    target_h: int,
    target_w: int
) -> np.ndarray:
    """
    Segment sky from an in-memory image array using the ONNX model.

    Args:
        image: Image as numpy array, (H, W, 3) or (3, H, W), values in
            [0, 1] or [0, 255]
        skyseg_session: ONNX runtime inference session
        target_h: Target output height
        target_w: Target output width

    Returns:
        Continuous non-sky confidence map in [0, 1].
    """
    rgb = _image_to_rgb_uint8(image)
    # run_skyseg expects BGR input (cv2 convention).
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    score_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, bgr)
    score_map = cv2.resize(score_map, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    return _result_map_to_non_sky_conf(score_map)
def segment_sky(
    image_path: str,
    skyseg_session,
    output_path: Optional[str] = None
) -> np.ndarray:
    """
    Segment sky from an image file using the ONNX model.

    Args:
        image_path: Path to the input image
        skyseg_session: ONNX runtime inference session
        output_path: Optional path to save the mask

    Returns:
        Continuous non-sky confidence map in [0, 1].

    Raises:
        ValueError: If the image cannot be read from image_path.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Failed to read image: {image_path}")
    score_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, bgr)
    # Upsample the score map back to the source resolution.
    score_map = cv2.resize(score_map, (bgr.shape[1], bgr.shape[0]), interpolation=cv2.INTER_LINEAR)
    non_sky_conf = _result_map_to_non_sky_conf(score_map)
    if output_path is not None:
        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        cv2.imwrite(output_path, _mask_to_uint8(non_sky_conf))
    return non_sky_conf
def _list_image_files(image_folder: str) -> list[str]:
image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
return [f for f in image_files if os.path.splitext(f.lower())[1] in image_extensions]
def _image_to_rgb_uint8(image: np.ndarray) -> np.ndarray:
if image.ndim == 3 and image.shape[0] == 3 and image.shape[-1] != 3:
image = image.transpose(1, 2, 0)
if image.ndim != 3 or image.shape[2] != 3:
raise ValueError(f"Expected image with shape (H, W, 3) or (3, H, W), got {image.shape}")
if image.dtype != np.uint8:
image = image.astype(np.float32)
if image.max() <= 1.0:
image = image * 255.0
image = np.clip(image, 0.0, 255.0).astype(np.uint8)
return image
def _get_mask_filename(image_paths: Optional[list[str]], index: int) -> str:
if image_paths is not None and index < len(image_paths):
return os.path.basename(image_paths[index])
return f"frame_{index:06d}.png"
def _save_sky_mask_visualization(
    image: np.ndarray,
    sky_mask: np.ndarray,
    output_path: str,
) -> None:
    """Write a side-by-side panel (image | mask | overlay) to output_path.

    Pixels classified as sky (mask at or below the soft threshold) are
    tinted red in the overlay.
    """
    rgb = _image_to_rgb_uint8(image)
    if sky_mask.shape[:2] != rgb.shape[:2]:
        # Nearest-neighbor keeps the mask's hard edges when resizing.
        sky_mask = cv2.resize(
            sky_mask,
            (rgb.shape[1], rgb.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    mask_u8 = _mask_to_uint8(sky_mask)
    mask_panel = np.repeat(mask_u8[..., None], 3, axis=2)
    blended = rgb.astype(np.float32).copy()
    is_sky = _mask_to_float(sky_mask) <= _SKYSEG_SOFT_THRESHOLD
    tint = np.array([255, 64, 64], dtype=np.float32)
    blended[is_sky] = blended[is_sky] * 0.35 + tint * 0.65
    blended = np.clip(blended, 0.0, 255.0).astype(np.uint8)
    panel = np.concatenate([rgb, mask_panel, blended], axis=1)
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    cv2.imwrite(output_path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
def load_or_create_sky_masks(
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
    target_shape: Optional[Tuple[int, int]] = None,
    num_frames: Optional[int] = None,
) -> Optional[np.ndarray]:
    """
    Load cached sky masks or generate them with the ONNX model.

    Args:
        image_folder: Folder containing input images.
        image_paths: Optional explicit image file list, in the exact order to process.
        images: Optional image array with shape (S, 3, H, W) or (S, H, W, 3).
        skyseg_model_path: Path to the sky segmentation ONNX model.
        sky_mask_dir: Optional directory for cached raw masks.
        sky_mask_visualization_dir: Optional directory for side-by-side visualizations.
        target_shape: Optional output mask shape (H, W) after resizing.
        num_frames: Optional maximum number of frames to process.

    Returns:
        Sky masks with shape (S, H, W), or None if sky segmentation could not run.
    """
    # Preconditions: onnxruntime must be importable and at least one image
    # source must be given; otherwise degrade gracefully with a warning.
    if onnxruntime is None:
        print("Warning: onnxruntime not available, skipping sky segmentation")
        return None
    if image_folder is None and image_paths is None and images is None:
        print("Warning: Neither image_folder/image_paths nor images provided, skipping sky segmentation")
        return None
    if not os.path.exists(skyseg_model_path):
        print(f"Sky segmentation model not found at {skyseg_model_path}, downloading...")
        try:
            download_skyseg_model(skyseg_model_path)
        except Exception as e:
            # Download failure disables the feature rather than crashing.
            print(f"Warning: Failed to download sky segmentation model: {e}")
            return None
    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_masks = []
    if sky_mask_visualization_dir is not None:
        os.makedirs(sky_mask_visualization_dir, exist_ok=True)
        print(f"Saving sky mask visualizations to {sky_mask_visualization_dir}")
    if images is not None:
        # --- Branch 1: segment from an in-memory image array ---
        # image_paths (if derivable) is used only for cache file naming.
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)
        num_images = images.shape[0]
        if num_frames is not None:
            num_images = min(num_images, num_frames)
        if image_paths is not None:
            image_paths = image_paths[:num_images]
        if sky_mask_dir is None and image_folder is not None:
            # Default cache location next to the image folder.
            sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
        refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
        print("Generating sky masks from image array...")
        for i in tqdm(range(num_images)):
            image_rgb = _image_to_rgb_uint8(images[i])
            image_h, image_w = image_rgb.shape[:2]
            image_name = _get_mask_filename(image_paths, i)
            mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None
            if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath):
                # Cached path: reuse the stored mask, but regenerate when it
                # is unreadable or does not match the current image size.
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                if sky_mask is None:
                    print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
                    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
                elif sky_mask.shape[:2] != (image_h, image_w):
                    print(
                        f"Cached sky mask shape {sky_mask.shape[:2]} does not match resized image "
                        f"shape {(image_h, image_w)} for {image_name}; regenerating it"
                    )
                    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
            else:
                sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                if mask_filepath is not None:
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
            if sky_mask_visualization_dir is not None:
                _save_sky_mask_visualization(
                    image_rgb,
                    sky_mask,
                    os.path.join(sky_mask_visualization_dir, image_name),
                )
            if target_shape is not None and sky_mask.shape[:2] != target_shape:
                sky_mask = cv2.resize(
                    sky_mask,
                    (target_shape[1], target_shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )
            sky_masks.append(_mask_to_float(sky_mask))
    else:
        # --- Branch 2: segment directly from image files on disk ---
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)
        if images is None and image_paths is not None:
            if len(image_paths) == 0:
                print("Warning: No image files provided, skipping sky segmentation")
                return None
        if num_frames is not None:
            image_paths = image_paths[:num_frames]
        if sky_mask_dir is None:
            if image_folder is None:
                # Derive the folder from the first path for cache placement.
                image_folder = os.path.dirname(image_paths[0])
            sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
        refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
        print("Generating sky masks from image files...")
        for image_path in tqdm(image_paths):
            image_name = os.path.basename(image_path)
            mask_filepath = os.path.join(sky_mask_dir, image_name)
            if not refresh_cache and os.path.exists(mask_filepath):
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                if sky_mask is None:
                    print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
                    sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
            else:
                sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
            # NOTE(review): segment_sky raises rather than returning None, so
            # this guard looks unreachable — confirm before relying on it.
            if sky_mask is None:
                print(f"Warning: Failed to produce sky mask for {image_path}, skipping frame")
                continue
            if sky_mask_visualization_dir is not None:
                image_bgr = cv2.imread(image_path)
                if image_bgr is not None:
                    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                    _save_sky_mask_visualization(
                        image_rgb,
                        sky_mask,
                        os.path.join(sky_mask_visualization_dir, image_name),
                    )
            if target_shape is not None and sky_mask.shape[:2] != target_shape:
                sky_mask = cv2.resize(
                    sky_mask,
                    (target_shape[1], target_shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )
            sky_masks.append(_mask_to_float(sky_mask))
    if len(sky_masks) == 0:
        print("Warning: No sky masks generated, skipping sky segmentation")
        return None
    try:
        return np.stack(sky_masks, axis=0)
    except ValueError:
        # Mixed mask shapes cannot be stacked; fall back to an object array.
        return np.array(sky_masks, dtype=object)
def apply_sky_segmentation(
    conf: np.ndarray,
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.

    Args:
        conf: Confidence scores with shape (S, H, W)
        image_folder: Path to the folder containing input images (optional if images provided)
        image_paths: Optional explicit image file list in processing order
        images: Image array with shape (S, 3, H, W) or (S, H, W, 3) (optional if image_folder provided)
        skyseg_model_path: Path to the sky segmentation ONNX model
        sky_mask_dir: Optional directory for cached raw masks
        sky_mask_visualization_dir: Optional directory for side-by-side mask visualization images

    Returns:
        Updated confidence scores with sky regions masked out
    """
    S, H, W = conf.shape
    sky_mask_array = load_or_create_sky_masks(
        image_folder=image_folder,
        image_paths=image_paths,
        images=images,
        skyseg_model_path=skyseg_model_path,
        sky_mask_dir=sky_mask_dir,
        sky_mask_visualization_dir=sky_mask_visualization_dir,
        target_shape=(H, W),
        num_frames=S,
    )
    if sky_mask_array is None:
        # Sky segmentation could not run; leave confidences untouched.
        return conf
    if sky_mask_array.shape[0] < S:
        print(
            f"Warning: Only {sky_mask_array.shape[0]} sky masks generated for {S} frames; "
            "leaving the remaining frames unmasked"
        )
        # Pad the missing frames with ONES (non-sky). Padding with zeros
        # would fail the (value > threshold) test below and wipe those
        # frames' confidence entirely — the opposite of "unmasked".
        padded = np.ones((S, H, W), dtype=sky_mask_array.dtype)
        padded[: sky_mask_array.shape[0]] = sky_mask_array
        sky_mask_array = padded
    elif sky_mask_array.shape[0] > S:
        sky_mask_array = sky_mask_array[:S]
    # Binarize the soft masks: pixels at or below the threshold are sky and
    # have their confidence zeroed.
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    conf = conf * sky_mask_binary
    print("Sky segmentation applied successfully")
    return conf
def download_skyseg_model(output_path: str = "skyseg.onnx") -> str:
    """
    Download sky segmentation model from HuggingFace.

    Args:
        output_path: Path to save the model

    Returns:
        Path to the downloaded model

    Raises:
        requests.exceptions.HTTPError: If the download request fails.
    """
    import requests
    url = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"
    print(f"Downloading sky segmentation model from {url}...")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    # Content-Length drives the progress bar; 0 when the server omits it.
    total_size = int(response.headers.get('content-length', 0))
    progress = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading")
    with open(output_path, 'wb') as model_file, progress:
        for chunk in response.iter_content(chunk_size=8192):
            model_file.write(chunk)
            progress.update(len(chunk))
    print(f"Model saved to {output_path}")
    return output_path

206
lingbot_map/vis/utils.py Normal file
View File

@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Visualization utility functions for colorization and color bars.
"""
import dataclasses
from typing import Optional, Tuple
import numpy as np
import torch
import cv2
import matplotlib.cm as cm
@dataclasses.dataclass
class CameraState:
    """Camera state for rendering."""
    # Vertical field of view in radians.
    fov: float
    # Width / height aspect ratio.
    aspect: float
    # Camera-to-world transform.
    c2w: np.ndarray

    def get_K(self, img_wh: Tuple[int, int]) -> np.ndarray:
        """Get camera intrinsic matrix from FOV and image size."""
        width, height = img_wh
        # Pinhole focal length derived from the vertical FOV; the principal
        # point sits at the image center.
        focal = height / (2.0 * np.tan(self.fov / 2.0))
        return np.array([
            [focal, 0.0, width / 2.0],
            [0.0, focal, height / 2.0],
            [0.0, 0.0, 1.0],
        ])
def get_vertical_colorbar(
    h: int,
    vmin: float,
    vmax: float,
    cmap_name: str = "jet",
    label: Optional[str] = None,
    cbar_precision: int = 2
) -> np.ndarray:
    """
    Create a vertical colorbar image.

    Args:
        h: Height in pixels of the returned image
        vmin: Minimum value
        vmax: Maximum value
        cmap_name: Colormap name
        label: Optional label for the colorbar
        cbar_precision: Decimal precision for tick labels

    Returns:
        Colorbar image as a float32 numpy array (h, W, 3) in [0, 1]
    """
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import matplotlib as mpl
    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)
    # matplotlib >= 3.9 removed matplotlib.cm.get_cmap(); prefer the
    # colormap registry (available since 3.5) and fall back to the
    # legacy API on older versions.
    try:
        cmap = mpl.colormaps[cmap_name]
    except AttributeError:
        cmap = cm.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )
    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        # Strip the trailing ".0" that str(np.round(...)) leaves on integers.
        tick_label = [x[:-2] for x in tick_label]
    cb1.set_ticklabels(tick_label)
    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)
    # Rasterize the figure into an RGBA buffer and drop the alpha channel.
    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()
    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    im = im[:, :, :3].astype(np.float32) / 255.0
    # Resize to the requested height, preserving the aspect ratio.
    if h != im.shape[0]:
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
    return im
def colorize_np(
    x: np.ndarray,
    cmap_name: str = "jet",
    mask: Optional[np.ndarray] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False,
    cbar_precision: int = 2,
) -> np.ndarray:
    """
    Turn a grayscale image into a color image.

    Args:
        x: Input grayscale image [H, W]
        cmap_name: Colormap name
        mask: Optional boolean mask image [H, W]; unmasked pixels render white
        range: Value range for scaling (min, max), automatic if None
        append_cbar: Whether to append a colorbar
        cbar_in_image: Draw the colorbar over the right edge of the image
            instead of concatenating it beside the image
        cbar_precision: Colorbar tick precision

    Returns:
        Colorized image [H, W, 3]
    """
    # Work on a copy so the caller's array is never mutated (the masked
    # branch below writes vmin into out-of-mask pixels).
    x = np.array(x)
    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        # vmin is the smallest nonzero masked value; guard against an
        # all-zero selection, which would make np.min crash on empty input.
        masked_vals = x[mask]
        nonzero_vals = masked_vals[np.nonzero(masked_vals)]
        vmin = np.min(nonzero_vals) if nonzero_vals.size > 0 else np.min(masked_vals)
        vmax = np.max(masked_vals)
        x[np.logical_not(mask)] = vmin
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6  # avoid a zero-width range
    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)
    # matplotlib >= 3.9 removed matplotlib.cm.get_cmap(); prefer the
    # colormap registry and fall back to the legacy API on older versions.
    import matplotlib as mpl
    try:
        cmap = mpl.colormaps[cmap_name]
    except AttributeError:
        cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]
    if mask is not None:
        # Blend unmasked pixels toward white.
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)
    if append_cbar:
        # Only render the (relatively expensive) colorbar when requested;
        # the original code built it unconditionally.
        cbar = get_vertical_colorbar(
            h=x.shape[0],
            vmin=vmin,
            vmax=vmax,
            cmap_name=cmap_name,
            cbar_precision=cbar_precision,
        )
        if cbar_in_image:
            x_new[:, -cbar.shape[1]:, :] = cbar
        else:
            x_new = np.concatenate(
                (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
            )
    return x_new
def colorize(
    x: torch.Tensor,
    cmap_name: str = "jet",
    mask: Optional[torch.Tensor] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False
) -> torch.Tensor:
    """
    Turn a grayscale image into a color image (PyTorch tensor version).

    Args:
        x: Grayscale image tensor [H, W] or [B, H, W]
        cmap_name: Colormap name
        mask: Optional mask tensor [H, W] or [B, H, W]; thresholded at 0.99
        range: Value range for scaling
        append_cbar: Whether to append colorbar
        cbar_in_image: Put colorbar inside image

    Returns:
        Colorized tensor on the same device as the input; batch dimension
        is squeezed away when the input was a single [H, W] image.
    """
    device = x.device
    x = x.cpu().numpy()
    if mask is not None:
        mask = mask.cpu().numpy() > 0.99
        kernel = np.ones((3, 3), np.uint8)
    # Promote a single image to a batch of one so both cases share one loop.
    if x.ndim == 2:
        x = x[None]
        if mask is not None:
            mask = mask[None]
    out = []
    for i, x_ in enumerate(x):
        mask_ = None
        if mask is not None:
            # Erode this frame's 2D mask once to shave off boundary pixels.
            # (The previous code eroded the whole, possibly 3D, mask array on
            # every iteration and passed it unsliced to colorize_np.)
            mask_ = cv2.erode(mask[i].astype(np.uint8), kernel, iterations=1).astype(bool)
        x_ = colorize_np(x_, cmap_name, mask_, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
    # squeeze(0) is a no-op for true batches; it only drops the dim we added.
    out = torch.stack(out).squeeze(0)
    return out

View File

@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Quick visualization wrapper for GCT predictions using Viser.
"""
import time
import threading
from typing import List, Optional
import numpy as np
import viser
import viser.transforms as tf
from tqdm.auto import tqdm
from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from lingbot_map.vis.sky_segmentation import apply_sky_segmentation
def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: Optional[str] = None,
) -> "viser.ViserServer":
    """
    Visualize predicted 3D points and camera poses with viser.
    This is a simplified wrapper for quick visualization without the full
    PointCloudViewer controls.
    Args:
        pred_dict: Dictionary containing predictions with keys:
            - images: (S, 3, H, W) - Input images
            - world_points: (S, H, W, 3)
            - world_points_conf: (S, H, W)
            - depth: (S, H, W, 1)
            - depth_conf: (S, H, W)
            - extrinsic: (S, 3, 4)
            - intrinsic: (S, 3, 3)
        port: Port number for the viser server
        init_conf_threshold: Initial percentage of low-confidence points to filter out
        use_point_map: Whether to visualize world_points or use depth-based points
        background_mode: Whether to run the server in background thread
        mask_sky: Whether to apply sky segmentation to filter out sky points
        image_folder: Path to the folder containing input images (for sky segmentation)
    Returns:
        viser.ViserServer: The viser server instance
    """
    print(f"Starting viser server on port {port}")
    server = viser.ViserServer(host="0.0.0.0", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
    # Unpack prediction dict
    images = pred_dict["images"]  # (S, 3, H, W)
    world_points_map = pred_dict["world_points"]  # (S, H, W, 3)
    conf_map = pred_dict["world_points_conf"]  # (S, H, W)
    depth_map = pred_dict["depth"]  # (S, H, W, 1)
    depth_conf = pred_dict["depth_conf"]  # (S, H, W)
    extrinsics_cam = pred_dict["extrinsic"]  # (S, 3, 4)
    intrinsics_cam = pred_dict["intrinsic"]  # (S, 3, 3)
    # Compute world points from depth if not using the precomputed point map
    if not use_point_map:
        world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
        conf = depth_conf
    else:
        world_points = world_points_map
        conf = conf_map
    # Apply sky segmentation if enabled: zeroes confidence on sky pixels so
    # they are dropped by the confidence filter below.
    if mask_sky and image_folder is not None:
        conf = apply_sky_segmentation(conf, image_folder)
    # Convert images from (S, 3, H, W) to (S, H, W, 3)
    colors = images.transpose(0, 2, 3, 1)  # now (S, H, W, 3)
    shape = world_points.shape
    S: int = shape[0]
    H: int = shape[1]
    W: int = shape[2]
    # Flatten per-pixel data into per-point arrays.
    # NOTE(review): colors are assumed to be floats in [0, 1] here — the
    # *255 conversion below overflows for uint8 inputs; confirm upstream.
    points = world_points.reshape(-1, 3)
    colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
    conf_flat = conf.reshape(-1)
    # Random sample points if too many (keeps the viewer responsive).
    # `indices` stays None when no subsampling happened.
    indices = None
    if points.shape[0] > 6000000:
        print(f"Too many points ({points.shape[0]}), randomly sampling 6M points")
        indices = np.random.choice(points.shape[0], size=6000000, replace=False)
        points = points[indices]
        colors_flat = colors_flat[indices]
        conf_flat = conf_flat[indices]
    # Invert world-to-camera extrinsics to get camera-to-world poses,
    # then drop the homogeneous bottom row -> (S, 3, 4).
    cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
    cam_to_world = cam_to_world_mat[:, :3, :]
    # Compute scene center and recenter.  Note this subtracts from the
    # camera translations in place (mutates cam_to_world's backing array).
    scene_center = np.mean(points, axis=0)
    points_centered = points - scene_center
    cam_to_world[..., -1] -= scene_center
    # Store frame indices for filtering; must be subsampled with the same
    # `indices` as the points so the two arrays stay aligned.
    frame_indices = (
        np.repeat(np.arange(S), H * W)[indices]
        if indices is not None
        else np.repeat(np.arange(S), H * W)
    )
    # Build the viser GUI
    gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True)
    gui_points_conf = server.gui.add_slider(
        "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold
    )
    gui_frame_selector = server.gui.add_dropdown(
        "Show Points from Frames",
        options=["All"] + [str(i) for i in range(S)],
        initial_value="All"
    )
    # Create the main point cloud, pre-filtered by the initial confidence
    # percentile.  NOTE(review): the extra `> 0.1` floor here differs from
    # the `> 1e-5` floor used in update_point_cloud below — confirm intent.
    init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
    init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1)
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points_centered[init_conf_mask],
        colors=colors_flat[init_conf_mask],
        point_size=0.0005,
        point_shape="circle",
    )
    # Handles kept so the GUI callbacks can toggle/remove them later.
    frames: List[viser.FrameHandle] = []
    frustums: List[viser.CameraFrustumHandle] = []
    def visualize_frames(extrinsics, images_: np.ndarray) -> None:
        """Add camera frames and frustums to the scene."""
        # Clear any previously added handles before re-adding.
        for f in frames:
            f.remove()
        frames.clear()
        for fr in frustums:
            fr.remove()
        frustums.clear()
        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            # Clicking a frustum flies every connected client's camera to
            # that frame's pose.
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position
        for img_id in tqdm(range(S)):
            cam2world_3x4 = extrinsics[img_id]
            T_world_camera = tf.SE3.from_matrix(cam2world_3x4)
            frame_axis = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame_axis)
            # Convert (3, H, W) float image to (H, W, 3) uint8 for display.
            img = images_[img_id]
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            h, w = img.shape[:2]
            # Heuristic display focal length (1.1 * H), used only for the
            # frustum widget — presumably not the true intrinsics; confirm.
            fy = 1.1 * h
            fov = 2 * np.arctan2(h / 2, fy)
            frustum_cam = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum",
                fov=fov,
                aspect=w / h,
                scale=0.05,
                image=img,
                line_width=1.0
            )
            frustums.append(frustum_cam)
            attach_callback(frustum_cam, frame_axis)
    def update_point_cloud() -> None:
        """Update point cloud based on current GUI selections."""
        # Re-derive the absolute threshold from the slider's percentile.
        current_percentage = gui_points_conf.value
        threshold_val = np.percentile(conf_flat, current_percentage)
        print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")
        conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)
        if gui_frame_selector.value == "All":
            frame_mask = np.ones_like(conf_mask, dtype=bool)
        else:
            selected_idx = int(gui_frame_selector.value)
            frame_mask = frame_indices == selected_idx
        combined_mask = conf_mask & frame_mask
        point_cloud.points = points_centered[combined_mask]
        point_cloud.colors = colors_flat[combined_mask]
    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()
    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()
    @gui_show_frames.on_update
    def _(_) -> None:
        # Toggle visibility without removing/re-adding the handles.
        for f in frames:
            f.visible = gui_show_frames.value
        for fr in frustums:
            fr.visible = gui_show_frames.value
    # Add camera frames.  Defensive tensor check: cam_to_world is produced
    # by numpy ops above, but upstream inputs may have been torch tensors.
    import torch
    if torch.is_tensor(cam_to_world):
        cam_to_world_np = cam_to_world.cpu().numpy()
    else:
        cam_to_world_np = cam_to_world
    visualize_frames(cam_to_world_np, images)
    print("Starting viser server...")
    if background_mode:
        # Keep a daemon thread alive so the server persists after return;
        # the loop itself does no work (viser serves from its own threads).
        def server_loop():
            while True:
                time.sleep(0.001)
        thread = threading.Thread(target=server_loop, daemon=True)
        thread.start()
    else:
        # Block forever in foreground mode; `return server` below is only
        # reached when background_mode is True.
        while True:
            time.sleep(0.01)
    return server