first commit
This commit is contained in:
59
lingbot_map/vis/__init__.py
Normal file
59
lingbot_map/vis/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
GCT Visualization Module
|
||||
|
||||
This module provides visualization utilities for 3D reconstruction results:
|
||||
- PointCloudViewer: Interactive point cloud viewer with camera visualization
|
||||
- viser_wrapper: Quick visualization wrapper for predictions
|
||||
- predictions_to_glb: Export predictions to GLB 3D format
|
||||
- Colorization and utility functions
|
||||
|
||||
Usage:
|
||||
from lingbot_map.vis import PointCloudViewer, viser_wrapper, predictions_to_glb
|
||||
|
||||
# Interactive visualization
|
||||
viewer = PointCloudViewer(pred_dict=predictions, port=8080)
|
||||
viewer.run()
|
||||
|
||||
# Quick visualization
|
||||
viser_wrapper(predictions, port=8080)
|
||||
|
||||
# Export to GLB
|
||||
scene = predictions_to_glb(predictions)
|
||||
scene.export("output.glb")
|
||||
"""
|
||||
|
||||
from lingbot_map.vis.point_cloud_viewer import PointCloudViewer
|
||||
from lingbot_map.vis.viser_wrapper import viser_wrapper
|
||||
from lingbot_map.vis.utils import CameraState, colorize, colorize_np, get_vertical_colorbar
|
||||
from lingbot_map.vis.sky_segmentation import (
|
||||
apply_sky_segmentation,
|
||||
download_skyseg_model,
|
||||
load_or_create_sky_masks,
|
||||
segment_sky,
|
||||
)
|
||||
from lingbot_map.vis.glb_export import predictions_to_glb
|
||||
|
||||
# Public names re-exported by ``from lingbot_map.vis import ...``;
# groups mirror the import statements above.
__all__ = [
    # Main viewer
    "PointCloudViewer",
    # Quick visualization
    "viser_wrapper",
    # GLB export
    "predictions_to_glb",
    # Utilities
    "CameraState",
    "colorize",
    "colorize_np",
    "get_vertical_colorbar",
    # Sky segmentation
    "apply_sky_segmentation",
    "segment_sky",
    "download_skyseg_model",
    "load_or_create_sky_masks",
]
|
||||
509
lingbot_map/vis/glb_export.py
Normal file
509
lingbot_map/vis/glb_export.py
Normal file
@@ -0,0 +1,509 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
GLB 3D export utilities for GCT predictions.
|
||||
"""
|
||||
|
||||
import os
|
||||
import copy
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import matplotlib
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lingbot_map.vis.sky_segmentation import (
|
||||
_SKYSEG_INPUT_SIZE,
|
||||
_SKYSEG_SOFT_THRESHOLD,
|
||||
_mask_to_float,
|
||||
_mask_to_uint8,
|
||||
_result_map_to_non_sky_conf,
|
||||
)
|
||||
|
||||
try:
|
||||
import trimesh
|
||||
except ImportError:
|
||||
trimesh = None
|
||||
print("trimesh not found. GLB export will not work.")
|
||||
|
||||
|
||||
def predictions_to_glb(
    predictions: dict,
    conf_thres: float = 50.0,
    filter_by_frames: str = "all",
    mask_black_bg: bool = False,
    mask_white_bg: bool = False,
    show_cam: bool = True,
    mask_sky: bool = False,
    target_dir: Optional[str] = None,
    prediction_mode: str = "Predicted Pointmap",
) -> "trimesh.Scene":
    """
    Converts GCT predictions to a 3D scene represented as a GLB file.

    Args:
        predictions: Dictionary containing model predictions with keys:
            - world_points: 3D point coordinates (S, H, W, 3)
            - world_points_conf: Confidence scores (S, H, W)
            - images: Input images (S, H, W, 3) or (S, 3, H, W)
            - extrinsic: Camera extrinsic matrices (S, 3, 4)
        conf_thres: Percentage of low-confidence points to filter out
        filter_by_frames: Frame filter specification ("all" or frame index)
        mask_black_bg: Mask out black background pixels
        mask_white_bg: Mask out white background pixels
        show_cam: Include camera visualization
        mask_sky: Apply sky segmentation mask
        target_dir: Output directory for intermediate files
        prediction_mode: "Predicted Pointmap" or "Predicted Depthmap"

    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras

    Raises:
        ValueError: If input predictions structure is invalid
        ImportError: If trimesh is not available
    """
    if trimesh is None:
        raise ImportError("trimesh is required for GLB export. Install with: pip install trimesh")

    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")

    # Callers may pass None to mean "use a default filtering level".
    if conf_thres is None:
        conf_thres = 10.0

    print("Building GLB scene")

    # Parse frame filter: "all"/"All" keeps everything; otherwise the text
    # before the first ":" is parsed as a frame index (silently ignored on
    # parse failure).
    selected_frame_idx = None
    if filter_by_frames != "all" and filter_by_frames != "All":
        try:
            selected_frame_idx = int(filter_by_frames.split(":")[0])
        except (ValueError, IndexError):
            pass

    # Select prediction source
    if "Pointmap" in prediction_mode:
        print("Using Pointmap Branch")
        if "world_points" in predictions:
            pred_world_points = predictions["world_points"]
            # Missing confidence defaults to all-ones (keep everything).
            pred_world_points_conf = predictions.get(
                "world_points_conf", np.ones_like(pred_world_points[..., 0])
            )
        else:
            print("Warning: world_points not found, falling back to depth-based points")
            pred_world_points = predictions["world_points_from_depth"]
            pred_world_points_conf = predictions.get(
                "depth_conf", np.ones_like(pred_world_points[..., 0])
            )
    else:
        print("Using Depthmap and Camera Branch")
        pred_world_points = predictions["world_points_from_depth"]
        pred_world_points_conf = predictions.get(
            "depth_conf", np.ones_like(pred_world_points[..., 0])
        )

    images = predictions["images"]
    camera_matrices = predictions["extrinsic"]

    # Apply sky segmentation if enabled (zeroes confidence on sky pixels).
    if mask_sky and target_dir is not None:
        pred_world_points_conf = _apply_sky_mask(
            pred_world_points_conf, target_dir, images
        )

    # Apply frame filter ([None] keeps a leading frame axis of size 1).
    if selected_frame_idx is not None:
        pred_world_points = pred_world_points[selected_frame_idx][None]
        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        camera_matrices = camera_matrices[selected_frame_idx][None]

    # Prepare vertices and colors
    vertices_3d = pred_world_points.reshape(-1, 3)

    # Handle different image formats
    if images.ndim == 4 and images.shape[1] == 3:  # NCHW format
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:
        colors_rgb = images
    # NOTE(review): assumes image values are floats in [0, 1] — confirm
    # upstream; uint8 input would overflow after the * 255 here.
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)

    # Apply confidence filtering: conf_thres is a percentile, so e.g. 50
    # drops the lower half of confidences; near-zero scores always drop.
    conf = pred_world_points_conf.reshape(-1)
    conf_threshold = np.percentile(conf, conf_thres) if conf_thres > 0 else 0.0
    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)

    # Apply background masking
    if mask_black_bg:
        # Keep pixels whose summed RGB is at least 16 (not near-black).
        black_bg_mask = colors_rgb.sum(axis=1) >= 16
        conf_mask = conf_mask & black_bg_mask

    if mask_white_bg:
        # Drop pixels where all three channels are near-white (> 240).
        white_bg_mask = ~(
            (colors_rgb[:, 0] > 240) &
            (colors_rgb[:, 1] > 240) &
            (colors_rgb[:, 2] > 240)
        )
        conf_mask = conf_mask & white_bg_mask

    vertices_3d = vertices_3d[conf_mask]
    colors_rgb = colors_rgb[conf_mask]

    # Handle empty point cloud: fall back to a single white point so the
    # resulting scene stays valid/exportable.
    if vertices_3d is None or np.asarray(vertices_3d).size == 0:
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1
    else:
        # Robust scene extent: distance between the per-axis 5th and 95th
        # percentile corners (resistant to outlier points).
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)
        scene_scale = np.linalg.norm(upper_percentile - lower_percentile)

    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")

    # Build scene
    scene_3d = trimesh.Scene()
    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
    scene_3d.add_geometry(point_cloud_data)

    # Prepare camera matrices: pad (S, 3, 4) extrinsics to homogeneous 4x4.
    num_cameras = len(camera_matrices)
    extrinsics_matrices = np.zeros((num_cameras, 4, 4))
    extrinsics_matrices[:, :3, :4] = camera_matrices
    extrinsics_matrices[:, 3, 3] = 1

    # Add cameras, each tinted with a distinct rainbow color.
    if show_cam:
        for i in range(num_cameras):
            world_to_camera = extrinsics_matrices[i]
            camera_to_world = np.linalg.inv(world_to_camera)
            rgba_color = colormap(i / num_cameras)
            current_color = tuple(int(255 * x) for x in rgba_color[:3])
            integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)

    # Align scene to the first camera's frame (OpenGL convention).
    scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)

    print("GLB Scene built")
    return scene_3d
|
||||
|
||||
|
||||
def _apply_sky_mask(
    conf: np.ndarray,
    target_dir: str,
    images: np.ndarray
) -> np.ndarray:
    """Zero out confidence for sky pixels using cached or fresh sky masks.

    Args:
        conf: Per-pixel confidence scores of shape (S, H, W).
        target_dir: Directory containing an ``images/`` subfolder; masks are
            cached under ``target_dir/sky_masks``.
        images: Image array, used only as a shape fallback when ``conf`` has
            no ``shape`` attribute.

    Returns:
        ``conf`` multiplied by a binary non-sky mask (sky pixels -> 0).
        Returned unchanged when onnxruntime or the images folder is missing.
    """
    try:
        import onnxruntime
    except ImportError:
        print("Warning: onnxruntime not available, skipping sky masking")
        return conf

    target_dir_images = os.path.join(target_dir, "images")
    if not os.path.exists(target_dir_images):
        print(f"Warning: Images directory not found at {target_dir_images}")
        return conf

    image_list = sorted(os.listdir(target_dir_images))
    S, H, W = conf.shape if hasattr(conf, "shape") else (len(images), images.shape[1], images.shape[2])

    skyseg_model_path = "skyseg.onnx"
    if not os.path.exists(skyseg_model_path):
        print("Downloading skyseg.onnx...")
        download_file_from_url(
            "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
            skyseg_model_path
        )

    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_mask_list = []

    for image_name in image_list[:S]:
        image_filepath = os.path.join(target_dir_images, image_name)
        mask_filepath = os.path.join(target_dir, "sky_masks", image_name)

        sky_mask = None
        if os.path.exists(mask_filepath):
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
            if sky_mask is not None:
                # BUGFIX: cached masks are stored as uint8 in [0, 255]
                # (see segment_sky / _mask_to_uint8). The previous code fed
                # raw byte values to _mask_to_float, whose clip to [0, 1]
                # turned every nonzero pixel into 1.0 and silently disabled
                # masking on the cache-hit path. Rescale back to [0, 1].
                sky_mask = sky_mask.astype(np.float32) / 255.0
        if sky_mask is None:
            # Cache miss (or unreadable cached file): compute and cache.
            sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)

        # Match the confidence-map resolution if needed.
        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_LINEAR)

        sky_mask_list.append(_mask_to_float(sky_mask))

    sky_mask_array = np.array(sky_mask_list)
    # Soft mask -> binary: values above the threshold count as non-sky.
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    return conf * sky_mask_binary
|
||||
|
||||
|
||||
def integrate_camera_into_scene(
    scene: "trimesh.Scene",
    transform: np.ndarray,
    face_colors: Tuple[int, int, int],
    scene_scale: float,
    frustum_thickness: float = 1.0,
):
    """
    Integrates a camera mesh into the 3D scene.

    The camera is rendered as a 4-sided cone (a pyramid frustum); several
    slightly rotated/scaled copies ("shells") of the cone are stacked and
    stitched together to give the wireframe visible thickness.

    Args:
        scene: The 3D scene to add the camera model
        transform: Transformation matrix for camera positioning
        face_colors: RGB color tuple for the camera
        scene_scale: Scale of the scene
        frustum_thickness: Multiplier for frustum edge thickness (>1 = thicker)
    """
    # Camera size proportional to overall scene extent.
    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1

    # Rotate the 4-sided cone 45° about Z (so its base reads as a square
    # frustum) and shift it back along Z by its height.
    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    rot_45_degree[2, 3] = -cam_height

    opengl_transform = get_opengl_conversion_matrix()
    complete_transform = transform @ opengl_transform @ rot_45_degree
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)

    # Build thicker frustum by stacking rotated copies
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()

    # Base pair of shells: the cone itself plus a slightly rotated, slightly
    # smaller copy.
    shell_scales = [1.0, 0.95]
    shell_transforms = [np.eye(4), slight_rotation]
    # Add extra shells for thickness
    if frustum_thickness > 1.0:
        n_extra = max(1, int(frustum_thickness - 1))
        for k in range(1, n_extra + 1):
            # Progressively rotated and scaled copies (one each way around Z)
            angle = 2.0 + k * 2.0
            scale = 1.0 + k * 0.02
            rot = np.eye(4)
            rot[:3, :3] = Rotation.from_euler("z", angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot)
            rot_neg = np.eye(4)
            rot_neg[:3, :3] = Rotation.from_euler("z", -angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot_neg)

    # Stack all shell vertex sets, then place the whole assembly in the scene.
    vertices_parts = []
    for s, t_mat in zip(shell_scales, shell_transforms):
        vertices_parts.append(
            transform_points(t_mat, s * camera_cone_shape.vertices)
        )
    vertices_combined = np.concatenate(vertices_parts)
    vertices_transformed = transform_points(complete_transform, vertices_combined)

    # Faces stitch consecutive shells together (see compute_camera_faces_multi).
    mesh_faces = compute_camera_faces_multi(camera_cone_shape, len(shell_scales))
    camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)
|
||||
|
||||
|
||||
def apply_scene_alignment(
    scene_3d: "trimesh.Scene",
    extrinsics_matrices: np.ndarray
) -> "trimesh.Scene":
    """
    Align the scene to the first camera's frame.

    The scene is moved into the first camera's coordinate system, converted
    to OpenGL axis conventions, and rotated 180 degrees about Y so it faces
    the viewer.

    Args:
        scene_3d: The 3D scene to be aligned
        extrinsics_matrices: Camera extrinsic (world-to-camera) 4x4 matrices

    Returns:
        The aligned 3D scene (transformed in place and returned)
    """
    # 180-degree flip about the Y axis.
    flip_about_y = np.eye(4)
    flip_about_y[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()

    # camera-to-world of frame 0, then OpenGL conversion, then the flip.
    first_cam_to_world = np.linalg.inv(extrinsics_matrices[0])
    alignment = first_cam_to_world @ get_opengl_conversion_matrix() @ flip_about_y

    scene_3d.apply_transform(alignment)
    return scene_3d
|
||||
|
||||
|
||||
def get_opengl_conversion_matrix() -> np.ndarray:
    """Return the 4x4 OpenGL conversion matrix (Y and Z axes negated)."""
    return np.diag([1.0, -1.0, -1.0, 1.0])
|
||||
|
||||
|
||||
def transform_points(
    transformation: np.ndarray,
    points: np.ndarray,
    dim: Optional[int] = None
) -> np.ndarray:
    """
    Apply a homogeneous transformation to a batch of points.

    Args:
        transformation: 4x4 (or batched ...x4x4) transformation matrix
        points: Points to be transformed, shape (..., 3)
        dim: Number of output coordinates to keep (defaults to the input's)

    Returns:
        Transformed points with the same leading shape as the input
    """
    pts = np.asarray(points)
    leading_shape = pts.shape[:-1]
    out_dim = dim if dim else pts.shape[-1]

    # Work with the transpose so points can right-multiply the matrix.
    mat_t = transformation.swapaxes(-1, -2)
    linear_part = pts @ mat_t[..., :-1, :]
    moved = linear_part + mat_t[..., -1:, :]

    return moved[..., :out_dim].reshape(*leading_shape, out_dim)
|
||||
|
||||
|
||||
def compute_camera_faces(cone_shape: "trimesh.Trimesh") -> np.ndarray:
    """Compute faces for a camera mesh built from three stacked cone copies.

    Faces touching the apex (vertex 0) are dropped, side faces are stitched
    between the base cone and its two offset copies, and every triangle is
    duplicated with reversed winding so the mesh is visible from both sides.
    """
    tris = []
    shell_size = len(cone_shape.vertices)

    for face in cone_shape.faces:
        if 0 in face:
            continue  # skip apex faces: only the frustum outline is drawn
        a, b, c = face
        a1, b1, c1 = face + shell_size
        a2, b2, c2 = face + 2 * shell_size

        tris += [
            (a, b, b1),
            (a, a1, c),
            (c1, b, c),
            (a, b, b2),
            (a, a2, c),
            (c2, b, c),
        ]

    # Double-sided: append every triangle with reversed winding.
    tris += [(z, y, x) for x, y, z in tris]
    return np.array(tris)
|
||||
|
||||
|
||||
def compute_camera_faces_multi(cone_shape: "trimesh.Trimesh", num_shells: int) -> np.ndarray:
    """Compute faces for a camera mesh made of ``num_shells`` stacked cones.

    Each consecutive pair of vertex shells is stitched together to form the
    frustum edges; apex faces (touching vertex 0) are skipped, and all
    triangles are duplicated with reversed winding for double-sided rendering.
    """
    tris = []
    shell_size = len(cone_shape.vertices)

    for shell in range(num_shells - 1):
        base_off = shell * shell_size
        next_off = (shell + 1) * shell_size
        for face in cone_shape.faces:
            if 0 in face:
                continue  # apex faces are not part of the frustum outline
            a, b, c = face
            tris += [
                (a + base_off, b + base_off, b + next_off),
                (a + base_off, a + next_off, c + base_off),
                (c + next_off, b + base_off, c + base_off),
            ]

    # Double-sided: append every triangle with reversed winding.
    tris += [(z, y, x) for x, y, z in tris]
    return np.array(tris)
|
||||
|
||||
|
||||
def segment_sky(
    image_path: str,
    onnx_session,
    mask_filename: str
) -> np.ndarray:
    """
    Segments sky from an image using an ONNX model.

    Args:
        image_path: Path to input image
        onnx_session: ONNX runtime session with loaded model
        mask_filename: Path to save the output mask

    Returns:
        Continuous non-sky confidence map in [0, 1]

    Raises:
        ValueError: If the image cannot be read
    """
    image = cv2.imread(image_path)
    if image is None:
        # Fail loudly (matches lingbot_map.vis.sky_segmentation.segment_sky)
        # instead of crashing later with an opaque AttributeError on .shape.
        raise ValueError(f"Failed to read image: {image_path}")

    result_map = run_skyseg(onnx_session, _SKYSEG_INPUT_SIZE, image)
    # Upsample the low-res score map back to the source resolution.
    result_map_original = cv2.resize(
        result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR
    )
    output_mask = _result_map_to_non_sky_conf(result_map_original)

    # Cache the mask. Guard against a bare filename with no directory part,
    # where os.makedirs("") would raise FileNotFoundError.
    mask_dir = os.path.dirname(mask_filename)
    if mask_dir:
        os.makedirs(mask_dir, exist_ok=True)
    cv2.imwrite(mask_filename, _mask_to_uint8(output_mask))
    return output_mask
|
||||
|
||||
|
||||
def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray
) -> np.ndarray:
    """
    Runs sky segmentation inference using ONNX model.

    Args:
        onnx_session: ONNX runtime session
        input_size: Target size for model input (width, height)
        image: Input image in BGR format

    Returns:
        Segmentation mask as uint8 scores in [0, 255]
    """
    # cv2.resize does not modify its input, so no defensive copy is needed.
    resize_image = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    # ImageNet normalization, matching the model's training preprocessing.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    x = x.transpose(2, 0, 1)
    # NCHW layout. cv2.resize produced (H, W) = (input_size[1], input_size[0]),
    # so the reshape must be (-1, 3, H, W): the previous (W, H) order was only
    # correct for square inputs. Fixed for consistency with
    # lingbot_map.vis.sky_segmentation.run_skyseg.
    x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    # Min-max normalize to [0, 255]; guard the denominator so a constant
    # score map does not divide by zero (same guard as the sky_segmentation
    # module's copy of this function).
    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    onnx_result = (onnx_result - min_value) / max(max_value - min_value, 1e-8)
    onnx_result *= 255
    return onnx_result.astype("uint8")
|
||||
|
||||
|
||||
def download_file_from_url(url: str, filename: str):
    """Downloads a file from a URL to *filename*, following one redirect.

    Args:
        url: Source URL (e.g. a Hugging Face ``resolve`` link, which
            typically answers with a 302 redirect to the actual blob).
        filename: Local path to write the downloaded bytes to.

    Errors are reported via print; nothing is raised to the caller.
    """
    import requests

    try:
        response = requests.get(url, allow_redirects=False)
        response.raise_for_status()

        if response.status_code == 302:
            # Follow the single redirect manually and stream the real payload.
            redirect_url = response.headers["Location"]
            response = requests.get(redirect_url, stream=True)
            response.raise_for_status()
        # Any other 2xx means the server answered directly; the previous
        # version discarded such responses and downloaded nothing.

        with open(filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Downloaded {filename} successfully.")

    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
|
||||
1780
lingbot_map/vis/point_cloud_viewer.py
Normal file
1780
lingbot_map/vis/point_cloud_viewer.py
Normal file
File diff suppressed because it is too large
Load Diff
473
lingbot_map/vis/sky_segmentation.py
Normal file
473
lingbot_map/vis/sky_segmentation.py
Normal file
@@ -0,0 +1,473 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Sky segmentation utilities for filtering sky points from point clouds.
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
try:
|
||||
import onnxruntime
|
||||
except ImportError:
|
||||
onnxruntime = None
|
||||
print("onnxruntime not found. Sky segmentation may not work.")
|
||||
|
||||
|
||||
_SKYSEG_INPUT_SIZE = (320, 320)
|
||||
_SKYSEG_SOFT_THRESHOLD = 0.1
|
||||
_SKYSEG_CACHE_VERSION = "imagenet_norm_softmap_inverted_v3"
|
||||
|
||||
|
||||
def _get_cache_version_path(sky_mask_dir: str) -> str:
|
||||
return os.path.join(sky_mask_dir, ".skyseg_cache_version")
|
||||
|
||||
|
||||
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool:
    """Ensure the sky-mask cache directory exists and is format-compatible.

    Returns True when cached masks must be regenerated (no version marker,
    or a marker from an older format); False when the cache can be reused.
    A missing ``sky_mask_dir`` means no caching at all, hence False.
    """
    if sky_mask_dir is None:
        return False

    os.makedirs(sky_mask_dir, exist_ok=True)
    marker_path = _get_cache_version_path(sky_mask_dir)

    stale = True
    if os.path.exists(marker_path):
        with open(marker_path, "r", encoding="utf-8") as marker:
            stale = marker.read().strip() != _SKYSEG_CACHE_VERSION

    if stale:
        print(
            f"Sky mask cache at {sky_mask_dir} uses an older format; "
            "regenerating masks with ImageNet-normalized skyseg input"
        )
        # Stamp the current version so the next run reuses the cache.
        with open(marker_path, "w", encoding="utf-8") as marker:
            marker.write(_SKYSEG_CACHE_VERSION)

    return stale
|
||||
|
||||
|
||||
def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray,
) -> np.ndarray:
    """
    Run ONNX sky segmentation on a BGR image and return an 8-bit score map.
    """
    width, height = input_size
    resized = cv2.resize(image, dsize=(width, height))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32)

    # ImageNet normalization, then HWC -> NCHW for the ONNX model.
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    tensor = ((rgb / 255.0 - imagenet_mean) / imagenet_std).transpose(2, 0, 1)
    tensor = tensor.reshape(-1, 3, height, width).astype("float32")

    in_name = onnx_session.get_inputs()[0].name
    out_name = onnx_session.get_outputs()[0].name
    raw = np.array(onnx_session.run([out_name], {in_name: tensor})).squeeze()

    # Min-max normalize to [0, 255]; the guarded denominator avoids a
    # division by zero when the score map is constant.
    min_value = np.min(raw)
    max_value = np.max(raw)
    denom = max(max_value - min_value, 1e-8)
    scaled = (raw - min_value) / denom
    return (scaled * 255.0).astype(np.uint8)
|
||||
|
||||
|
||||
def _mask_to_float(mask: np.ndarray) -> np.ndarray:
|
||||
mask = mask.astype(np.float32)
|
||||
if mask.size == 0:
|
||||
return mask
|
||||
return np.clip(mask, 0.0, 1.0)
|
||||
|
||||
|
||||
def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
|
||||
mask = np.asarray(mask)
|
||||
if mask.dtype == np.uint8:
|
||||
return mask
|
||||
mask = mask.astype(np.float32)
|
||||
if mask.size > 0 and mask.max() <= 1.0:
|
||||
mask = mask * 255.0
|
||||
return np.clip(mask, 0.0, 255.0).astype(np.uint8)
|
||||
|
||||
|
||||
def _result_map_to_non_sky_conf(result_map: np.ndarray) -> np.ndarray:
    """Invert a skyseg score map (high = sky) into non-sky confidence."""
    sky_confidence = _mask_to_float(result_map)
    return 1.0 - sky_confidence
|
||||
|
||||
|
||||
def segment_sky_from_array(
    image: np.ndarray,
    skyseg_session,
    target_h: int,
    target_w: int
) -> np.ndarray:
    """
    Segment sky from an image array using ONNX model.

    Args:
        image: Input image as numpy array (H, W, 3) or (3, H, W), values in [0, 1] or [0, 255]
        skyseg_session: ONNX runtime inference session
        target_h: Target output height
        target_w: Target output width

    Returns:
        Continuous non-sky confidence map in [0, 1].
    """
    # Normalize layout/dtype to (H, W, 3) uint8, then convert to BGR since
    # run_skyseg expects OpenCV-style BGR input.
    image_rgb = _image_to_rgb_uint8(image)
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    result_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, image_bgr)
    # Resize the low-res score map to the requested output resolution.
    result_map = cv2.resize(result_map, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    # Invert: the raw skyseg map scores sky high; we return non-sky confidence.
    return _result_map_to_non_sky_conf(result_map)
|
||||
|
||||
|
||||
def segment_sky(
    image_path: str,
    skyseg_session,
    output_path: Optional[str] = None
) -> np.ndarray:
    """
    Segment sky from an image file using the ONNX model.

    Args:
        image_path: Path to the input image
        skyseg_session: ONNX runtime inference session
        output_path: Optional path to save the mask (written as 8-bit)

    Returns:
        Continuous non-sky confidence map in [0, 1].

    Raises:
        ValueError: If the image cannot be read.
    """
    frame = cv2.imread(image_path)
    if frame is None:
        raise ValueError(f"Failed to read image: {image_path}")

    # Score at model resolution, then upsample back to the source size.
    score_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, frame)
    height, width = frame.shape[0], frame.shape[1]
    score_map = cv2.resize(score_map, (width, height), interpolation=cv2.INTER_LINEAR)
    mask = _result_map_to_non_sky_conf(score_map)

    if output_path is not None:
        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        cv2.imwrite(output_path, _mask_to_uint8(mask))

    return mask
|
||||
|
||||
|
||||
def _list_image_files(image_folder: str) -> list[str]:
|
||||
image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
|
||||
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
|
||||
return [f for f in image_files if os.path.splitext(f.lower())[1] in image_extensions]
|
||||
|
||||
|
||||
def _image_to_rgb_uint8(image: np.ndarray) -> np.ndarray:
|
||||
if image.ndim == 3 and image.shape[0] == 3 and image.shape[-1] != 3:
|
||||
image = image.transpose(1, 2, 0)
|
||||
|
||||
if image.ndim != 3 or image.shape[2] != 3:
|
||||
raise ValueError(f"Expected image with shape (H, W, 3) or (3, H, W), got {image.shape}")
|
||||
|
||||
if image.dtype != np.uint8:
|
||||
image = image.astype(np.float32)
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = np.clip(image, 0.0, 255.0).astype(np.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def _get_mask_filename(image_paths: Optional[list[str]], index: int) -> str:
|
||||
if image_paths is not None and index < len(image_paths):
|
||||
return os.path.basename(image_paths[index])
|
||||
return f"frame_{index:06d}.png"
|
||||
|
||||
|
||||
def _save_sky_mask_visualization(
    image: np.ndarray,
    sky_mask: np.ndarray,
    output_path: str,
) -> None:
    """Write a side-by-side QA panel (image | mask | red-tinted overlay).

    Pixels classified as sky (non-sky confidence at or below the soft
    threshold) are blended toward red in the overlay. The panel is saved
    with cv2.imwrite, hence the final RGB->BGR conversion.
    """
    image_rgb = _image_to_rgb_uint8(image)
    # Masks may come at model resolution; match the image before compositing.
    if sky_mask.shape[:2] != image_rgb.shape[:2]:
        sky_mask = cv2.resize(
            sky_mask,
            (image_rgb.shape[1], image_rgb.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )

    mask_uint8 = _mask_to_uint8(sky_mask)
    # Grayscale mask replicated to 3 channels so it can sit in the panel.
    mask_rgb = np.repeat(mask_uint8[..., None], 3, axis=2)
    overlay = image_rgb.astype(np.float32).copy()
    # Sky = low non-sky confidence; blend 35% image with 65% red tint.
    sky_pixels = _mask_to_float(sky_mask) <= _SKYSEG_SOFT_THRESHOLD
    overlay[sky_pixels] = overlay[sky_pixels] * 0.35 + np.array([255, 64, 64], dtype=np.float32) * 0.65
    overlay = np.clip(overlay, 0.0, 255.0).astype(np.uint8)

    panel = np.concatenate([image_rgb, mask_rgb, overlay], axis=1)
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    cv2.imwrite(output_path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
|
||||
|
||||
|
||||
def load_or_create_sky_masks(
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
    target_shape: Optional[Tuple[int, int]] = None,
    num_frames: Optional[int] = None,
) -> Optional[np.ndarray]:
    """
    Load cached sky masks or generate them with the ONNX model.

    Args:
        image_folder: Folder containing input images.
        image_paths: Optional explicit image file list, in the exact order to process.
        images: Optional image array with shape (S, 3, H, W) or (S, H, W, 3).
        skyseg_model_path: Path to the sky segmentation ONNX model.
        sky_mask_dir: Optional directory for cached raw masks.
        sky_mask_visualization_dir: Optional directory for side-by-side visualizations.
        target_shape: Optional output mask shape (H, W) after resizing.
        num_frames: Optional maximum number of frames to process.

    Returns:
        Sky masks with shape (S, H, W), or None if sky segmentation could not run.
        If per-frame mask shapes disagree after processing, an object-dtype array
        of per-frame masks is returned instead of a stacked (S, H, W) array.
    """
    # onnxruntime is an optional dependency; bail out gracefully if missing.
    if onnxruntime is None:
        print("Warning: onnxruntime not available, skipping sky segmentation")
        return None

    # Need at least one image source to run on.
    if image_folder is None and image_paths is None and images is None:
        print("Warning: Neither image_folder/image_paths nor images provided, skipping sky segmentation")
        return None

    # Fetch the model on demand; failure is non-fatal (returns None).
    if not os.path.exists(skyseg_model_path):
        print(f"Sky segmentation model not found at {skyseg_model_path}, downloading...")
        try:
            download_skyseg_model(skyseg_model_path)
        except Exception as e:
            print(f"Warning: Failed to download sky segmentation model: {e}")
            return None

    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_masks = []

    if sky_mask_visualization_dir is not None:
        os.makedirs(sky_mask_visualization_dir, exist_ok=True)
        print(f"Saving sky mask visualizations to {sky_mask_visualization_dir}")

    # Branch 1: segment directly from an in-memory image array.
    if images is not None:
        # image_paths (if derivable) is only used to name cached mask files.
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)

        num_images = images.shape[0]
        if num_frames is not None:
            num_images = min(num_images, num_frames)
            if image_paths is not None:
                image_paths = image_paths[:num_images]

        # Default the cache dir next to the source folder.
        # NOTE(review): if both sky_mask_dir and image_folder are None,
        # _prepare_sky_mask_cache receives None — confirm it tolerates that.
        if sky_mask_dir is None and image_folder is not None:
            sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
        refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)

        print("Generating sky masks from image array...")
        for i in tqdm(range(num_images)):
            image_rgb = _image_to_rgb_uint8(images[i])
            image_h, image_w = image_rgb.shape[:2]
            image_name = _get_mask_filename(image_paths, i)
            mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None

            # Reuse a cached mask when possible; regenerate on read failure
            # or when the cached mask no longer matches the image resolution.
            if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath):
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                if sky_mask is None:
                    print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
                    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
                elif sky_mask.shape[:2] != (image_h, image_w):
                    print(
                        f"Cached sky mask shape {sky_mask.shape[:2]} does not match resized image "
                        f"shape {(image_h, image_w)} for {image_name}; regenerating it"
                    )
                    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
            else:
                sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                if mask_filepath is not None:
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))

            if sky_mask_visualization_dir is not None:
                _save_sky_mask_visualization(
                    image_rgb,
                    sky_mask,
                    os.path.join(sky_mask_visualization_dir, image_name),
                )

            # Resize to the requested output resolution (e.g. the confidence map's).
            if target_shape is not None and sky_mask.shape[:2] != target_shape:
                sky_mask = cv2.resize(
                    sky_mask,
                    (target_shape[1], target_shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )

            sky_masks.append(_mask_to_float(sky_mask))

    # Branch 2: segment by reading image files from disk.
    else:
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)

        if images is None and image_paths is not None:
            if len(image_paths) == 0:
                print("Warning: No image files provided, skipping sky segmentation")
                return None

            if num_frames is not None:
                image_paths = image_paths[:num_frames]

            # Derive the cache dir from the images' own folder if needed.
            if sky_mask_dir is None:
                if image_folder is None:
                    image_folder = os.path.dirname(image_paths[0])
                sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
            refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)

            print("Generating sky masks from image files...")
            for image_path in tqdm(image_paths):
                image_name = os.path.basename(image_path)
                mask_filepath = os.path.join(sky_mask_dir, image_name)

                # Cached mask if readable; otherwise segment_sky writes the cache.
                if not refresh_cache and os.path.exists(mask_filepath):
                    sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                    if sky_mask is None:
                        print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
                        sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
                else:
                    sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)

                # A frame that still has no mask is skipped entirely, so the
                # result may contain fewer masks than input frames.
                if sky_mask is None:
                    print(f"Warning: Failed to produce sky mask for {image_path}, skipping frame")
                    continue

                if sky_mask_visualization_dir is not None:
                    image_bgr = cv2.imread(image_path)
                    if image_bgr is not None:
                        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                        _save_sky_mask_visualization(
                            image_rgb,
                            sky_mask,
                            os.path.join(sky_mask_visualization_dir, image_name),
                        )

                if target_shape is not None and sky_mask.shape[:2] != target_shape:
                    sky_mask = cv2.resize(
                        sky_mask,
                        (target_shape[1], target_shape[0]),
                        interpolation=cv2.INTER_LINEAR,
                    )

                sky_masks.append(_mask_to_float(sky_mask))

    if len(sky_masks) == 0:
        print("Warning: No sky masks generated, skipping sky segmentation")
        return None

    # Stack into (S, H, W); fall back to a ragged object array if shapes differ.
    try:
        return np.stack(sky_masks, axis=0)
    except ValueError:
        return np.array(sky_masks, dtype=object)
|
||||
|
||||
|
||||
def apply_sky_segmentation(
    conf: np.ndarray,
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.

    Sky pixels (soft mask value <= _SKYSEG_SOFT_THRESHOLD) have their
    confidence zeroed; all other pixels keep their original confidence.

    Args:
        conf: Confidence scores with shape (S, H, W)
        image_folder: Path to the folder containing input images (optional if images provided)
        image_paths: Optional explicit image file list in processing order
        images: Image array with shape (S, 3, H, W) or (S, H, W, 3) (optional if image_folder provided)
        skyseg_model_path: Path to the sky segmentation ONNX model
        sky_mask_dir: Optional directory for cached raw masks
        sky_mask_visualization_dir: Optional directory for side-by-side mask visualization images

    Returns:
        Updated confidence scores with sky regions masked out; returned
        unchanged if sky segmentation could not run.
    """
    S, H, W = conf.shape

    sky_mask_array = load_or_create_sky_masks(
        image_folder=image_folder,
        image_paths=image_paths,
        images=images,
        skyseg_model_path=skyseg_model_path,
        sky_mask_dir=sky_mask_dir,
        sky_mask_visualization_dir=sky_mask_visualization_dir,
        target_shape=(H, W),
        num_frames=S,
    )
    if sky_mask_array is None:
        return conf

    if sky_mask_array.shape[0] < S:
        print(
            f"Warning: Only {sky_mask_array.shape[0]} sky masks generated for {S} frames; "
            "leaving the remaining frames unmasked"
        )
        # Pad with ones ("not sky") so the uncovered frames really are left
        # unmasked.  The previous zero padding made (0 > threshold) == False,
        # which wiped out the confidence of every frame without a mask —
        # the opposite of the warning above.
        padded = np.ones((S, H, W), dtype=sky_mask_array.dtype)
        padded[: sky_mask_array.shape[0]] = sky_mask_array
        sky_mask_array = padded
    elif sky_mask_array.shape[0] > S:
        sky_mask_array = sky_mask_array[:S]

    # NOTE(review): if load_or_create_sky_masks returned an object-dtype
    # (ragged) array, the comparison below will not broadcast — confirm
    # callers always pass target_shape-consistent inputs.
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    conf = conf * sky_mask_binary

    print("Sky segmentation applied successfully")
    return conf
|
||||
|
||||
|
||||
def download_skyseg_model(output_path: str = "skyseg.onnx") -> str:
    """
    Download sky segmentation model from HuggingFace.

    Args:
        output_path: Path to save the model

    Returns:
        Path to the downloaded model

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If the connection or a read stalls past the timeout.
    """
    import requests

    url = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"

    print(f"Downloading sky segmentation model from {url}...")
    # A timeout prevents hanging forever on a stalled connection; the
    # context manager guarantees the HTTP connection is released.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))

        with open(output_path, 'wb') as f:
            with tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
                        pbar.update(len(chunk))

    print(f"Model saved to {output_path}")
    return output_path
|
||||
206
lingbot_map/vis/utils.py
Normal file
206
lingbot_map/vis/utils.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Visualization utility functions for colorization and color bars.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
import matplotlib.cm as cm
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CameraState:
    """Camera state for rendering."""

    # Vertical field of view in radians.
    fov: float
    # Width / height aspect ratio.
    aspect: float
    # Camera-to-world transform.  NOTE(review): exact shape (3x4 vs 4x4)
    # is not visible here — confirm against callers.
    c2w: np.ndarray

    def get_K(self, img_wh: Tuple[int, int]) -> np.ndarray:
        """Get camera intrinsic matrix from FOV and image size."""
        width, height = img_wh
        # fov is treated as the vertical FOV: focal = (H / 2) / tan(fov / 2).
        focal = height / 2.0 / np.tan(self.fov / 2.0)
        return np.array(
            [
                [focal, 0.0, width / 2.0],
                [0.0, focal, height / 2.0],
                [0.0, 0.0, 1.0],
            ]
        )
|
||||
|
||||
|
||||
def get_vertical_colorbar(
    h: int,
    vmin: float,
    vmax: float,
    cmap_name: str = "jet",
    label: Optional[str] = None,
    cbar_precision: int = 2
) -> np.ndarray:
    """
    Create a vertical colorbar image.

    Args:
        h: Height in pixels of the returned image
        vmin: Minimum value
        vmax: Maximum value
        cmap_name: Colormap name
        label: Optional label for the colorbar
        cbar_precision: Decimal precision for tick labels

    Returns:
        Colorbar image as a float32 numpy array (h, W, 3) with values in [0, 1]
    """
    # Imported lazily so the module can be used without an Agg backend loaded.
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import matplotlib as mpl

    fig = Figure(figsize=(2, 8), dpi=100)
    # Pushes the axes left so the tick labels fit inside the rendered figure.
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)

    ax = fig.add_subplot(111)
    # NOTE(review): cm.get_cmap is deprecated in matplotlib >= 3.7 (removed in
    # 3.9); consider mpl.colormaps[cmap_name] when the pinned version allows.
    cmap = cm.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    # Six evenly spaced ticks across [vmin, vmax].
    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )

    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        # np.round still prints a trailing ".0" at precision 0 — strip it.
        tick_label = [x[:-2] for x in tick_label]

    cb1.set_ticklabels(tick_label)
    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)

    # Rasterize the figure to an RGBA byte buffer.
    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()

    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    # Drop alpha and normalize to float [0, 1].
    im = im[:, :, :3].astype(np.float32) / 255.0

    # Rescale to the requested height, preserving aspect ratio.
    if h != im.shape[0]:
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)

    return im
|
||||
|
||||
|
||||
def colorize_np(
    x: np.ndarray,
    cmap_name: str = "jet",
    mask: Optional[np.ndarray] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False,
    cbar_precision: int = 2,
) -> np.ndarray:
    """
    Turn a grayscale image into a color image.

    Args:
        x: Input grayscale image [H, W]
        cmap_name: Colormap name
        mask: Optional boolean mask image [H, W]; pixels outside the mask
            are rendered white
        range: Value range for scaling (min, max), automatic if None
        append_cbar: Whether to append colorbar
        cbar_in_image: Draw the colorbar over the image's right edge
            instead of widening the image
        cbar_precision: Colorbar tick precision

    Returns:
        Colorized image [H, W, 3], float in [0, 1]
    """
    # Work on a copy: the masked branch below overwrites out-of-mask pixels,
    # and the original implementation mutated the caller's array in place.
    x = x.copy()

    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        # Auto-range from the masked, nonzero values only.
        # NOTE(review): raises if every masked value is zero (empty min) —
        # confirm callers guarantee nonzero content under the mask.
        vmin = np.min(x[mask][np.nonzero(x[mask])])
        vmax = np.max(x[mask])
        x[np.logical_not(mask)] = vmin
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6

    # Normalize to [0, 1] for the colormap.
    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)

    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]

    if mask is not None:
        # Blend toward white outside the mask.
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)

    if not append_cbar:
        return x_new

    # Only pay for the matplotlib-rendered colorbar when it is requested;
    # the original built it unconditionally.
    cbar = get_vertical_colorbar(
        h=x.shape[0],
        vmin=vmin,
        vmax=vmax,
        cmap_name=cmap_name,
        cbar_precision=cbar_precision,
    )

    if cbar_in_image:
        x_new[:, -cbar.shape[1]:, :] = cbar
    else:
        x_new = np.concatenate(
            (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
        )
    return x_new
|
||||
|
||||
|
||||
def colorize(
    x: torch.Tensor,
    cmap_name: str = "jet",
    mask: Optional[torch.Tensor] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False
) -> torch.Tensor:
    """
    Turn a grayscale image into a color image (PyTorch tensor version).

    Args:
        x: Grayscale image tensor [H, W] or [B, H, W]
        cmap_name: Colormap name
        mask: Optional mask tensor [H, W] or [B, H, W]
        range: Value range for scaling
        append_cbar: Whether to append colorbar
        cbar_in_image: Put colorbar inside image

    Returns:
        Colorized float tensor on the same device as ``x``:
        [H, W', 3] for a single image, [B, H, W', 3] for a batch
    """
    device = x.device
    x = x.cpu().numpy()
    if mask is not None:
        # Binarize the soft mask before eroding.
        mask = mask.cpu().numpy() > 0.99
    kernel = np.ones((3, 3), np.uint8)

    # Promote single images to a batch of one.
    if x.ndim == 2:
        x = x[None]
        if mask is not None:
            mask = mask[None]

    out = []
    for i, x_ in enumerate(x):
        if mask is not None:
            # Erode each frame's OWN mask exactly once.  The previous code
            # eroded the whole (B, H, W) batch mask cumulatively on every
            # iteration and passed it un-indexed to colorize_np, which
            # breaks for batches larger than one.
            mask_ = cv2.erode(mask[i].astype(np.uint8), kernel, iterations=1).astype(bool)
        else:
            mask_ = None

        x_ = colorize_np(x_, cmap_name, mask_, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
    # squeeze(0) restores the unbatched shape when the input was [H, W].
    out = torch.stack(out).squeeze(0)
    return out
|
||||
248
lingbot_map/vis/viser_wrapper.py
Normal file
248
lingbot_map/vis/viser_wrapper.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Quick visualization wrapper for GCT predictions using Viser.
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import viser
|
||||
import viser.transforms as tf
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
|
||||
from lingbot_map.vis.sky_segmentation import apply_sky_segmentation
|
||||
|
||||
|
||||
def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: Optional[str] = None,
):
    """
    Visualize predicted 3D points and camera poses with viser.

    This is a simplified wrapper for quick visualization without the full
    PointCloudViewer controls.

    Args:
        pred_dict: Dictionary containing predictions with keys:
            - images: (S, 3, H, W) - Input images
            - world_points: (S, H, W, 3)
            - world_points_conf: (S, H, W)
            - depth: (S, H, W, 1)
            - depth_conf: (S, H, W)
            - extrinsic: (S, 3, 4)
            - intrinsic: (S, 3, 3)
        port: Port number for the viser server
        init_conf_threshold: Initial percentage of low-confidence points to filter out
        use_point_map: Whether to visualize world_points or use depth-based points
        background_mode: Whether to run the server in background thread;
            when False this function blocks forever and never returns
        mask_sky: Whether to apply sky segmentation to filter out sky points
        image_folder: Path to the folder containing input images (for sky segmentation)

    Returns:
        viser.ViserServer: The viser server instance (background mode only;
        the foreground path blocks in an endless sleep loop)
    """
    print(f"Starting viser server on port {port}")

    server = viser.ViserServer(host="0.0.0.0", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    # Unpack prediction dict
    images = pred_dict["images"]  # (S, 3, H, W)
    world_points_map = pred_dict["world_points"]  # (S, H, W, 3)
    conf_map = pred_dict["world_points_conf"]  # (S, H, W)

    depth_map = pred_dict["depth"]  # (S, H, W, 1)
    depth_conf = pred_dict["depth_conf"]  # (S, H, W)

    extrinsics_cam = pred_dict["extrinsic"]  # (S, 3, 4)
    intrinsics_cam = pred_dict["intrinsic"]  # (S, 3, 3)

    # Compute world points from depth if not using the precomputed point map
    if not use_point_map:
        world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
        conf = depth_conf
    else:
        world_points = world_points_map
        conf = conf_map

    # Apply sky segmentation if enabled (zeroes confidence over sky regions)
    if mask_sky and image_folder is not None:
        conf = apply_sky_segmentation(conf, image_folder)

    # Convert images from (S, 3, H, W) to (S, H, W, 3)
    colors = images.transpose(0, 2, 3, 1)  # now (S, H, W, 3)
    shape = world_points.shape
    S: int = shape[0]
    H: int = shape[1]
    W: int = shape[2]

    # Flatten all frames into one big point list
    points = world_points.reshape(-1, 3)
    # assumes image values are in [0, 1] — TODO confirm against the pipeline
    colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
    conf_flat = conf.reshape(-1)

    # Random sample points if too many (keeps the browser responsive);
    # `indices` stays None when no subsampling happened
    indices = None
    if points.shape[0] > 6000000:
        print(f"Too many points ({points.shape[0]}), randomly sampling 6M points")
        indices = np.random.choice(points.shape[0], size=6000000, replace=False)
        points = points[indices]
        colors_flat = colors_flat[indices]
        conf_flat = conf_flat[indices]

    # world->camera extrinsics inverted to camera->world poses
    cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
    cam_to_world = cam_to_world_mat[:, :3, :]

    # Compute scene center and recenter (points and camera translations alike)
    scene_center = np.mean(points, axis=0)
    points_centered = points - scene_center
    cam_to_world[..., -1] -= scene_center

    # Store frame indices for filtering (subsampled identically to the points)
    frame_indices = (
        np.repeat(np.arange(S), H * W)[indices]
        if indices is not None
        else np.repeat(np.arange(S), H * W)
    )

    # Build the viser GUI
    gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True)
    gui_points_conf = server.gui.add_slider(
        "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold
    )
    gui_frame_selector = server.gui.add_dropdown(
        "Show Points from Frames",
        options=["All"] + [str(i) for i in range(S)],
        initial_value="All"
    )

    # Create the main point cloud, pre-filtered by the initial percentile.
    # NOTE(review): the absolute floor here is 0.1 but update_point_cloud()
    # uses 1e-5 — the first slider interaction changes the visible set even
    # at the same percentage; confirm whether that is intentional.
    init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
    init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1)
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points_centered[init_conf_mask],
        colors=colors_flat[init_conf_mask],
        point_size=0.0005,
        point_shape="circle",
    )

    # Handles kept so GUI callbacks can toggle/remove them later
    frames: List[viser.FrameHandle] = []
    frustums: List[viser.CameraFrustumHandle] = []

    def visualize_frames(extrinsics, images_: np.ndarray) -> None:
        """Add camera frames and frustums to the scene."""
        # Clear any previously added handles before re-adding
        for f in frames:
            f.remove()
        frames.clear()
        for fr in frustums:
            fr.remove()
        frustums.clear()

        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            # Clicking a frustum flies every connected client to that camera pose
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        for img_id in tqdm(range(S)):
            cam2world_3x4 = extrinsics[img_id]
            T_world_camera = tf.SE3.from_matrix(cam2world_3x4)

            frame_axis = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame_axis)

            img = images_[img_id]
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            h, w = img.shape[:2]

            # Nominal focal length for display only (not the predicted intrinsics)
            fy = 1.1 * h
            fov = 2 * np.arctan2(h / 2, fy)

            frustum_cam = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum",
                fov=fov,
                aspect=w / h,
                scale=0.05,
                image=img,
                line_width=1.0
            )
            frustums.append(frustum_cam)
            attach_callback(frustum_cam, frame_axis)

    def update_point_cloud() -> None:
        """Update point cloud based on current GUI selections."""
        current_percentage = gui_points_conf.value
        threshold_val = np.percentile(conf_flat, current_percentage)
        print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")

        # 1e-5 floor drops points whose confidence was zeroed (e.g. by sky masking)
        conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)

        if gui_frame_selector.value == "All":
            frame_mask = np.ones_like(conf_mask, dtype=bool)
        else:
            selected_idx = int(gui_frame_selector.value)
            frame_mask = frame_indices == selected_idx

        combined_mask = conf_mask & frame_mask
        point_cloud.points = points_centered[combined_mask]
        point_cloud.colors = colors_flat[combined_mask]

    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_show_frames.on_update
    def _(_) -> None:
        for f in frames:
            f.visible = gui_show_frames.value
        for fr in frustums:
            fr.visible = gui_show_frames.value

    # Add camera frames; poses may still be torch tensors at this point
    import torch
    if torch.is_tensor(cam_to_world):
        cam_to_world_np = cam_to_world.cpu().numpy()
    else:
        cam_to_world_np = cam_to_world
    visualize_frames(cam_to_world_np, images)

    print("Starting viser server...")
    if background_mode:
        # Idle daemon thread; viser serves from its own threads, this only
        # lets the function return while keeping the pattern symmetrical
        # with the blocking path below
        def server_loop():
            while True:
                time.sleep(0.001)

        thread = threading.Thread(target=server_loop, daemon=True)
        thread.start()
    else:
        # Foreground mode: block forever so the server stays alive;
        # the return below is only reachable in background mode
        while True:
            time.sleep(0.01)

    return server
|
||||
Reference in New Issue
Block a user