# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. """ GLB 3D export utilities for GCT predictions. """ import os import copy from typing import Optional, Tuple import numpy as np import cv2 import matplotlib from scipy.spatial.transform import Rotation from lingbot_map.vis.sky_segmentation import ( _SKYSEG_INPUT_SIZE, _SKYSEG_SOFT_THRESHOLD, _mask_to_float, _mask_to_uint8, _result_map_to_non_sky_conf, ) try: import trimesh except ImportError: trimesh = None print("trimesh not found. GLB export will not work.") def predictions_to_glb( predictions: dict, conf_thres: float = 50.0, filter_by_frames: str = "all", mask_black_bg: bool = False, mask_white_bg: bool = False, show_cam: bool = True, mask_sky: bool = False, target_dir: Optional[str] = None, prediction_mode: str = "Predicted Pointmap", ) -> "trimesh.Scene": """ Converts GCT predictions to a 3D scene represented as a GLB file. Args: predictions: Dictionary containing model predictions with keys: - world_points: 3D point coordinates (S, H, W, 3) - world_points_conf: Confidence scores (S, H, W) - images: Input images (S, H, W, 3) or (S, 3, H, W) - extrinsic: Camera extrinsic matrices (S, 3, 4) conf_thres: Percentage of low-confidence points to filter out filter_by_frames: Frame filter specification ("all" or frame index) mask_black_bg: Mask out black background pixels mask_white_bg: Mask out white background pixels show_cam: Include camera visualization mask_sky: Apply sky segmentation mask target_dir: Output directory for intermediate files prediction_mode: "Predicted Pointmap" or "Predicted Depthmap" Returns: trimesh.Scene: Processed 3D scene containing point cloud and cameras Raises: ValueError: If input predictions structure is invalid ImportError: If trimesh is not available """ if trimesh is None: raise ImportError("trimesh is required for GLB export. Install with: pip install trimesh") if not isinstance(predictions, dict): raise ValueError("predictions must be a dictionary") if conf_thres is None: conf_thres = 10.0 print("Building GLB scene") # Parse frame filter selected_frame_idx = None if filter_by_frames != "all" and filter_by_frames != "All": try: selected_frame_idx = int(filter_by_frames.split(":")[0]) except (ValueError, IndexError): pass # Select prediction source if "Pointmap" in prediction_mode: print("Using Pointmap Branch") if "world_points" in predictions: pred_world_points = predictions["world_points"] pred_world_points_conf = predictions.get( "world_points_conf", np.ones_like(pred_world_points[..., 0]) ) else: print("Warning: world_points not found, falling back to depth-based points") pred_world_points = predictions["world_points_from_depth"] pred_world_points_conf = predictions.get( "depth_conf", np.ones_like(pred_world_points[..., 0]) ) else: print("Using Depthmap and Camera Branch") pred_world_points = predictions["world_points_from_depth"] pred_world_points_conf = predictions.get( "depth_conf", np.ones_like(pred_world_points[..., 0]) ) images = predictions["images"] camera_matrices = predictions["extrinsic"] # Apply sky segmentation if enabled if mask_sky and target_dir is not None: pred_world_points_conf = _apply_sky_mask( pred_world_points_conf, target_dir, images ) # Apply frame filter if selected_frame_idx is not None: pred_world_points = pred_world_points[selected_frame_idx][None] pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None] images = images[selected_frame_idx][None] camera_matrices = camera_matrices[selected_frame_idx][None] # Prepare vertices and colors vertices_3d = pred_world_points.reshape(-1, 3) # Handle different image formats if images.ndim == 4 and images.shape[1] == 3: # NCHW format colors_rgb = np.transpose(images, (0, 2, 3, 1)) else: colors_rgb = images colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8) # Apply confidence filtering conf = pred_world_points_conf.reshape(-1) conf_threshold = np.percentile(conf, conf_thres) if conf_thres > 0 else 0.0 conf_mask = (conf >= conf_threshold) & (conf > 1e-5) # Apply background masking if mask_black_bg: black_bg_mask = colors_rgb.sum(axis=1) >= 16 conf_mask = conf_mask & black_bg_mask if mask_white_bg: white_bg_mask = ~( (colors_rgb[:, 0] > 240) & (colors_rgb[:, 1] > 240) & (colors_rgb[:, 2] > 240) ) conf_mask = conf_mask & white_bg_mask vertices_3d = vertices_3d[conf_mask] colors_rgb = colors_rgb[conf_mask] # Handle empty point cloud if vertices_3d is None or np.asarray(vertices_3d).size == 0: vertices_3d = np.array([[1, 0, 0]]) colors_rgb = np.array([[255, 255, 255]]) scene_scale = 1 else: lower_percentile = np.percentile(vertices_3d, 5, axis=0) upper_percentile = np.percentile(vertices_3d, 95, axis=0) scene_scale = np.linalg.norm(upper_percentile - lower_percentile) colormap = matplotlib.colormaps.get_cmap("gist_rainbow") # Build scene scene_3d = trimesh.Scene() point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) scene_3d.add_geometry(point_cloud_data) # Prepare camera matrices num_cameras = len(camera_matrices) extrinsics_matrices = np.zeros((num_cameras, 4, 4)) extrinsics_matrices[:, :3, :4] = camera_matrices extrinsics_matrices[:, 3, 3] = 1 # Add cameras if show_cam: for i in range(num_cameras): world_to_camera = extrinsics_matrices[i] camera_to_world = np.linalg.inv(world_to_camera) rgba_color = colormap(i / num_cameras) current_color = tuple(int(255 * x) for x in rgba_color[:3]) integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale) # Align scene scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices) print("GLB Scene built") return scene_3d def _apply_sky_mask( conf: np.ndarray, target_dir: str, images: np.ndarray ) -> np.ndarray: """Apply sky segmentation mask to confidence scores.""" try: import onnxruntime except ImportError: print("Warning: onnxruntime not available, skipping sky masking") return conf target_dir_images = os.path.join(target_dir, "images") if not os.path.exists(target_dir_images): print(f"Warning: Images directory not found at {target_dir_images}") return conf image_list = sorted(os.listdir(target_dir_images)) S, H, W = conf.shape if hasattr(conf, "shape") else (len(images), images.shape[1], images.shape[2]) skyseg_model_path = "skyseg.onnx" if not os.path.exists(skyseg_model_path): print("Downloading skyseg.onnx...") download_file_from_url( "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", skyseg_model_path ) skyseg_session = onnxruntime.InferenceSession(skyseg_model_path) sky_mask_list = [] for i, image_name in enumerate(image_list[:S]): image_filepath = os.path.join(target_dir_images, image_name) mask_filepath = os.path.join(target_dir, "sky_masks", image_name) if os.path.exists(mask_filepath): sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) else: sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath) if sky_mask.shape[0] != H or sky_mask.shape[1] != W: sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_LINEAR) sky_mask_list.append(_mask_to_float(sky_mask)) sky_mask_array = np.array(sky_mask_list) sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32) return conf * sky_mask_binary def integrate_camera_into_scene( scene: "trimesh.Scene", transform: np.ndarray, face_colors: Tuple[int, int, int], scene_scale: float, frustum_thickness: float = 1.0, ): """ Integrates a camera mesh into the 3D scene. Args: scene: The 3D scene to add the camera model transform: Transformation matrix for camera positioning face_colors: RGB color tuple for the camera scene_scale: Scale of the scene frustum_thickness: Multiplier for frustum edge thickness (>1 = thicker) """ cam_width = scene_scale * 0.05 cam_height = scene_scale * 0.1 rot_45_degree = np.eye(4) rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix() rot_45_degree[2, 3] = -cam_height opengl_transform = get_opengl_conversion_matrix() complete_transform = transform @ opengl_transform @ rot_45_degree camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4) # Build thicker frustum by stacking rotated copies slight_rotation = np.eye(4) slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix() shell_scales = [1.0, 0.95] shell_transforms = [np.eye(4), slight_rotation] # Add extra shells for thickness if frustum_thickness > 1.0: n_extra = max(1, int(frustum_thickness - 1)) for k in range(1, n_extra + 1): # Progressively rotated and scaled copies angle = 2.0 + k * 2.0 scale = 1.0 + k * 0.02 rot = np.eye(4) rot[:3, :3] = Rotation.from_euler("z", angle, degrees=True).as_matrix() shell_scales.append(scale) shell_transforms.append(rot) rot_neg = np.eye(4) rot_neg[:3, :3] = Rotation.from_euler("z", -angle, degrees=True).as_matrix() shell_scales.append(scale) shell_transforms.append(rot_neg) vertices_parts = [] for s, t_mat in zip(shell_scales, shell_transforms): vertices_parts.append( transform_points(t_mat, s * camera_cone_shape.vertices) ) vertices_combined = np.concatenate(vertices_parts) vertices_transformed = transform_points(complete_transform, vertices_combined) mesh_faces = compute_camera_faces_multi(camera_cone_shape, len(shell_scales)) camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces) camera_mesh.visual.face_colors[:, :3] = face_colors scene.add_geometry(camera_mesh) def apply_scene_alignment( scene_3d: "trimesh.Scene", extrinsics_matrices: np.ndarray ) -> "trimesh.Scene": """ Aligns the 3D scene based on the extrinsics of the first camera. Args: scene_3d: The 3D scene to be aligned extrinsics_matrices: Camera extrinsic matrices Returns: Aligned 3D scene """ opengl_conversion_matrix = get_opengl_conversion_matrix() align_rotation = np.eye(4) align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix() initial_transformation = ( np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation ) scene_3d.apply_transform(initial_transformation) return scene_3d def get_opengl_conversion_matrix() -> np.ndarray: """Returns the OpenGL conversion matrix (flips Y and Z axes).""" matrix = np.identity(4) matrix[1, 1] = -1 matrix[2, 2] = -1 return matrix def transform_points( transformation: np.ndarray, points: np.ndarray, dim: Optional[int] = None ) -> np.ndarray: """ Applies a 4x4 transformation to a set of points. Args: transformation: Transformation matrix points: Points to be transformed dim: Dimension for reshaping the result Returns: Transformed points """ points = np.asarray(points) initial_shape = points.shape[:-1] dim = dim or points.shape[-1] transformation = transformation.swapaxes(-1, -2) points = points @ transformation[..., :-1, :] + transformation[..., -1:, :] return points[..., :dim].reshape(*initial_shape, dim) def compute_camera_faces(cone_shape: "trimesh.Trimesh") -> np.ndarray: """Computes the faces for the camera mesh.""" faces_list = [] num_vertices_cone = len(cone_shape.vertices) for face in cone_shape.faces: if 0 in face: continue v1, v2, v3 = face v1_offset, v2_offset, v3_offset = face + num_vertices_cone v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone faces_list.extend([ (v1, v2, v2_offset), (v1, v1_offset, v3), (v3_offset, v2, v3), (v1, v2, v2_offset_2), (v1, v1_offset_2, v3), (v3_offset_2, v2, v3), ]) faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list] return np.array(faces_list) def compute_camera_faces_multi(cone_shape: "trimesh.Trimesh", num_shells: int) -> np.ndarray: """Computes faces for a camera mesh with multiple shells (for thicker frustums). Connects each consecutive pair of vertex shells to form the frustum edges. """ faces_list = [] nv = len(cone_shape.vertices) for s in range(num_shells - 1): off_a = s * nv off_b = (s + 1) * nv for face in cone_shape.faces: if 0 in face: continue v1, v2, v3 = face faces_list.extend([ (v1 + off_a, v2 + off_a, v2 + off_b), (v1 + off_a, v1 + off_b, v3 + off_a), (v3 + off_b, v2 + off_a, v3 + off_a), ]) faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list] return np.array(faces_list) def segment_sky( image_path: str, onnx_session, mask_filename: str ) -> np.ndarray: """ Segments sky from an image using an ONNX model. Args: image_path: Path to input image onnx_session: ONNX runtime session with loaded model mask_filename: Path to save the output mask Returns: Continuous non-sky confidence map in [0, 1] """ image = cv2.imread(image_path) result_map = run_skyseg(onnx_session, _SKYSEG_INPUT_SIZE, image) result_map_original = cv2.resize( result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR ) output_mask = _result_map_to_non_sky_conf(result_map_original) os.makedirs(os.path.dirname(mask_filename), exist_ok=True) cv2.imwrite(mask_filename, _mask_to_uint8(output_mask)) return output_mask def run_skyseg( onnx_session, input_size: Tuple[int, int], image: np.ndarray ) -> np.ndarray: """ Runs sky segmentation inference using ONNX model. Args: onnx_session: ONNX runtime session input_size: Target size for model input (width, height) image: Input image in BGR format Returns: Segmentation mask """ temp_image = copy.deepcopy(image) resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1])) x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) x = np.array(x, dtype=np.float32) mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] x = (x / 255 - mean) / std x = x.transpose(2, 0, 1) x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32") input_name = onnx_session.get_inputs()[0].name output_name = onnx_session.get_outputs()[0].name onnx_result = onnx_session.run([output_name], {input_name: x}) onnx_result = np.array(onnx_result).squeeze() min_value = np.min(onnx_result) max_value = np.max(onnx_result) onnx_result = (onnx_result - min_value) / (max_value - min_value) onnx_result *= 255 return onnx_result.astype("uint8") def download_file_from_url(url: str, filename: str): """Downloads a file from a URL, handling redirects.""" import requests try: response = requests.get(url, allow_redirects=False) response.raise_for_status() if response.status_code == 302: redirect_url = response.headers["Location"] response = requests.get(redirect_url, stream=True) response.raise_for_status() else: print(f"Unexpected status code: {response.status_code}") return with open(filename, "wb") as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) print(f"Downloaded {filename} successfully.") except requests.exceptions.RequestException as e: print(f"Error downloading file: {e}")