474 lines
17 KiB
Python
474 lines
17 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""
|
|
Sky segmentation utilities for filtering sky points from point clouds.
|
|
"""
|
|
|
|
import glob
|
|
import os
|
|
from typing import Optional, Tuple
|
|
|
|
import numpy as np
|
|
import cv2
|
|
from tqdm.auto import tqdm
|
|
|
|
try:
|
|
import onnxruntime
|
|
except ImportError:
|
|
onnxruntime = None
|
|
print("onnxruntime not found. Sky segmentation may not work.")
|
|
|
|
|
|
# Model input size as (width, height): run_skyseg passes it to cv2.resize as
# dsize and reshapes the tensor to (1, 3, size[1], size[0]).
_SKYSEG_INPUT_SIZE = (320, 320)

# Non-sky confidence at or below this value is treated as sky (pixels above it
# are kept by apply_sky_segmentation; pixels at/below it are tinted in the
# visualization overlay).
_SKYSEG_CACHE_VERSION_NOTE = None  # placeholder removed below; kept out of API
del _SKYSEG_CACHE_VERSION_NOTE
_SKYSEG_SOFT_THRESHOLD = 0.1

# Marker stored in each mask cache directory; caches carrying any other value
# are regenerated. Bump it whenever the on-disk mask content changes.
# v4: 8-bit masks are rescaled from 0-255 to [0, 1] when converted to float
# instead of being clipped (which collapsed every non-zero score to 1.0), so
# caches written by v3 contain collapsed masks and must be rebuilt.
_SKYSEG_CACHE_VERSION = "imagenet_norm_softmap_inverted_v4"


def _get_cache_version_path(sky_mask_dir: str) -> str:
    """Return the path of the cache-version marker file inside ``sky_mask_dir``."""
    return os.path.join(sky_mask_dir, ".skyseg_cache_version")


def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool:
    """
    Create the mask cache directory if needed and decide whether it must be rebuilt.

    Args:
        sky_mask_dir: Cache directory, or None when caching is disabled.

    Returns:
        True when cached masks must be regenerated (the version marker is
        missing or holds a different version), False when the cache is current
        or caching is disabled.

    NOTE(review): the marker is (re)written immediately, before any mask is
    regenerated, so an interrupted regeneration leaves a mix of old and new
    masks that later runs will treat as current.
    """
    if sky_mask_dir is None:
        return False

    # Only warn about an outdated cache when there was something cached before;
    # a freshly created, empty directory has nothing to regenerate.
    had_entries = os.path.isdir(sky_mask_dir) and len(os.listdir(sky_mask_dir)) > 0
    os.makedirs(sky_mask_dir, exist_ok=True)
    version_path = _get_cache_version_path(sky_mask_dir)

    refresh_cache = True
    if os.path.exists(version_path):
        with open(version_path, "r", encoding="utf-8") as f:
            refresh_cache = f.read().strip() != _SKYSEG_CACHE_VERSION

    if refresh_cache:
        if had_entries:
            print(
                f"Sky mask cache at {sky_mask_dir} uses an older format; "
                "regenerating masks with ImageNet-normalized skyseg input"
            )
        with open(version_path, "w", encoding="utf-8") as f:
            f.write(_SKYSEG_CACHE_VERSION)

    return refresh_cache


def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray,
) -> np.ndarray:
    """
    Run ONNX sky segmentation on a BGR image and return an 8-bit score map.

    The image is resized to ``input_size`` (width, height), converted to RGB,
    normalized with ImageNet mean/std and fed to the model as a
    (1, 3, height, width) float32 tensor. The first model output is squeezed
    and min-max normalized per image to 0-255 (fractions truncated), so the
    result is relative: the lowest score in the image maps to 0 and the
    highest to 255.
    """
    resize_image = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB).astype(np.float32)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    x = (x / 255.0 - mean) / std
    x = x.transpose(2, 0, 1)
    x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    # Guard against a constant map (max == min) to avoid dividing by zero.
    denom = max(max_value - min_value, 1e-8)
    onnx_result = (onnx_result - min_value) / denom
    onnx_result *= 255.0
    return onnx_result.astype(np.uint8)


def _mask_to_float(mask: np.ndarray) -> np.ndarray:
    """
    Convert a mask to a float32 array in [0, 1].

    uint8 masks are 0-255 encoded in this module (run_skyseg output and masks
    written with ``_mask_to_uint8``), so they are rescaled by 1/255. Masks of
    any other dtype are assumed to already be on a [0, 1] scale and are only
    clipped.
    """
    arr = np.asarray(mask)
    is_8bit = arr.dtype == np.uint8
    arr = arr.astype(np.float32)
    if arr.size == 0:
        return arr
    if is_8bit:
        # Clipping alone would turn every non-zero 8-bit score into 1.0.
        arr = arr / 255.0
    return np.clip(arr, 0.0, 1.0)
|
|
|
|
|
|
def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
    """
    Encode a mask as uint8 for writing to disk.

    uint8 input is passed through untouched. Anything else is cast to float32;
    if its maximum is at most 1.0 it is treated as a [0, 1] map and scaled by
    255. The result is then clipped to [0, 255] and truncated to uint8.
    """
    arr = np.asarray(mask)
    if arr.dtype != np.uint8:
        values = arr.astype(np.float32)
        # Empty masks have no max(); leave them unscaled.
        looks_unit_range = values.size > 0 and values.max() <= 1.0
        if looks_unit_range:
            values = values * 255.0
        arr = np.clip(values, 0.0, 255.0).astype(np.uint8)
    return arr
|
|
|
|
|
|
def _result_map_to_non_sky_conf(result_map: np.ndarray) -> np.ndarray:
    """
    Convert a raw skyseg score map into a non-sky confidence map.

    The model scores sky pixels high and non-sky pixels low, so the score,
    once mapped to [0, 1] by ``_mask_to_float``, is flipped as ``1 - score``.
    """
    sky_score = _mask_to_float(result_map)
    non_sky_conf = 1.0 - sky_score
    return non_sky_conf
|
|
|
|
|
|
def segment_sky_from_array(
    image: np.ndarray,
    skyseg_session,
    target_h: int,
    target_w: int,
) -> np.ndarray:
    """
    Segment sky from an in-memory image with the ONNX model.

    Args:
        image: Image as a numpy array of shape (H, W, 3) or (3, H, W), with
            values in [0, 1] or [0, 255].
        skyseg_session: ONNX runtime inference session.
        target_h: Height of the returned map.
        target_w: Width of the returned map.

    Returns:
        Continuous non-sky confidence map in [0, 1] of shape (target_h, target_w).
    """
    # run_skyseg expects OpenCV's BGR channel order.
    bgr = cv2.cvtColor(_image_to_rgb_uint8(image), cv2.COLOR_RGB2BGR)
    raw_scores = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, bgr)
    resized_scores = cv2.resize(
        raw_scores, (target_w, target_h), interpolation=cv2.INTER_LINEAR
    )
    return _result_map_to_non_sky_conf(resized_scores)
|
|
|
|
|
|
def segment_sky(
    image_path: str,
    skyseg_session,
    output_path: Optional[str] = None,
) -> np.ndarray:
    """
    Segment sky from an image file with the ONNX model.

    Args:
        image_path: Path of the image to segment.
        skyseg_session: ONNX runtime inference session.
        output_path: If given, the mask is also written there as an 8-bit
            image (parent directories are created as needed).

    Returns:
        Continuous non-sky confidence map in [0, 1] at the image's resolution.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise ValueError(f"Failed to read image: {image_path}")

    height, width = bgr.shape[:2]
    raw_scores = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, bgr)
    raw_scores = cv2.resize(raw_scores, (width, height), interpolation=cv2.INTER_LINEAR)
    non_sky = _result_map_to_non_sky_conf(raw_scores)

    if output_path is None:
        return non_sky

    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    cv2.imwrite(output_path, _mask_to_uint8(non_sky))
    return non_sky
|
|
|
|
|
|
def _list_image_files(image_folder: str) -> list[str]:
    """
    Return the image files directly inside ``image_folder``, sorted by path.

    Files are selected by a case-insensitive extension match. The folder path
    is escaped before globbing so that names containing glob metacharacters
    (``*``, ``?``, ``[``) are taken literally. Sub-directories whose names end
    in an image extension are skipped.
    """
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
    # Without escaping, a folder such as "scene[1]" would be globbed as the
    # character class "scene1" and list the wrong directory.
    pattern = os.path.join(glob.escape(image_folder), "*")
    return [
        path
        for path in sorted(glob.glob(pattern))
        if os.path.splitext(path.lower())[1] in image_extensions and os.path.isfile(path)
    ]
|
|
|
|
|
|
def _image_to_rgb_uint8(image: np.ndarray) -> np.ndarray:
    """
    Normalize an RGB image to a channel-last (H, W, 3) uint8 array.

    Channel-first (3, H, W) input is transposed; channel-last input is used
    as-is. uint8 data is returned without conversion. Any other dtype is cast
    to float32, multiplied by 255 when its maximum is at most 1.0 (a [0, 1]
    image), clipped to [0, 255] and truncated to uint8.

    Raises:
        ValueError: If the image is not three-channel in either layout.
    """
    channel_first = image.ndim == 3 and image.shape[-1] != 3 and image.shape[0] == 3
    if channel_first:
        image = image.transpose(1, 2, 0)

    is_hwc3 = image.ndim == 3 and image.shape[2] == 3
    if not is_hwc3:
        raise ValueError(f"Expected image with shape (H, W, 3) or (3, H, W), got {image.shape}")

    if image.dtype == np.uint8:
        return image

    pixels = image.astype(np.float32)
    if pixels.max() <= 1.0:
        pixels = pixels * 255.0
    return np.clip(pixels, 0.0, 255.0).astype(np.uint8)
|
|
|
|
|
|
def _get_mask_filename(image_paths: Optional[list[str]], index: int) -> str:
    """
    Choose the cache/visualization filename for frame ``index``.

    Uses the basename of ``image_paths[index]`` when that entry exists;
    otherwise falls back to a zero-padded ``frame_XXXXXX.png`` name.
    """
    has_source_name = image_paths is not None and index < len(image_paths)
    if not has_source_name:
        return f"frame_{index:06d}.png"
    return os.path.basename(image_paths[index])
|
|
|
|
|
|
def _save_sky_mask_visualization(
    image: np.ndarray,
    sky_mask: np.ndarray,
    output_path: str,
) -> None:
    """
    Write a side-by-side [image | mask | overlay] panel to ``output_path``.

    The mask is resized to the image with nearest-neighbour interpolation when
    the sizes differ. In the overlay, pixels whose non-sky confidence is at or
    below ``_SKYSEG_SOFT_THRESHOLD`` are blended 65% toward RGB (255, 64, 64).
    Parent directories of ``output_path`` are created as needed.
    """
    rgb = _image_to_rgb_uint8(image)
    height, width = rgb.shape[:2]
    if sky_mask.shape[:2] != (height, width):
        sky_mask = cv2.resize(sky_mask, (width, height), interpolation=cv2.INTER_NEAREST)

    gray = _mask_to_uint8(sky_mask)
    mask_panel = np.repeat(gray[..., None], 3, axis=2)

    is_sky = _mask_to_float(sky_mask) <= _SKYSEG_SOFT_THRESHOLD
    tint = np.array([255, 64, 64], dtype=np.float32)
    blended = rgb.astype(np.float32)  # astype already returns a fresh copy
    blended[is_sky] = blended[is_sky] * 0.35 + tint * 0.65
    blended = np.clip(blended, 0.0, 255.0).astype(np.uint8)

    panel = np.concatenate([rgb, mask_panel, blended], axis=1)

    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # OpenCV writes BGR, so convert the RGB panel back before saving.
    cv2.imwrite(output_path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
|
|
|
|
|
|
def _fit_mask_to_target(
    sky_mask: np.ndarray,
    target_shape: Optional[Tuple[int, int]],
) -> np.ndarray:
    """Resize ``sky_mask`` to ``target_shape`` (H, W) when given and different, then convert it with ``_mask_to_float``."""
    if target_shape is not None and tuple(sky_mask.shape[:2]) != tuple(target_shape):
        sky_mask = cv2.resize(
            sky_mask,
            (target_shape[1], target_shape[0]),
            interpolation=cv2.INTER_LINEAR,
        )
    return _mask_to_float(sky_mask)


def _load_or_segment_array_frame(
    image_rgb: np.ndarray,
    skyseg_session,
    mask_filepath: Optional[str],
    refresh_cache: bool,
    image_name: str,
) -> np.ndarray:
    """
    Return the mask for one in-memory frame, reusing a cached mask when valid.

    A cached mask is reused only when caching is enabled, the cache is current,
    the file exists and decodes, and its (H, W) matches the frame. Otherwise
    the mask is regenerated and, when ``mask_filepath`` is set, written back.
    Cached masks come back as uint8; regenerated ones as float confidence maps.
    """
    image_h, image_w = image_rgb.shape[:2]
    if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath):
        cached = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        if cached is None:
            print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
        elif cached.shape[:2] != (image_h, image_w):
            print(
                f"Cached sky mask shape {cached.shape[:2]} does not match resized image "
                f"shape {(image_h, image_w)} for {image_name}; regenerating it"
            )
        else:
            return cached

    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
    if mask_filepath is not None:
        cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
    return sky_mask


def _load_or_segment_file_frame(
    image_path: str,
    skyseg_session,
    mask_filepath: str,
    refresh_cache: bool,
) -> Optional[np.ndarray]:
    """
    Return the mask for one image file, or None if the image cannot be read.

    A readable cached mask is reused when the cache is current; otherwise the
    mask is regenerated by ``segment_sky``, which also writes it to
    ``mask_filepath``.
    """
    if not refresh_cache and os.path.exists(mask_filepath):
        cached = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        if cached is not None:
            return cached
        print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")

    try:
        return segment_sky(image_path, skyseg_session, mask_filepath)
    except ValueError as exc:
        # segment_sky raises ValueError when OpenCV cannot read the image.
        print(f"Warning: {exc}")
        return None


def _stack_masks(sky_masks: list) -> np.ndarray:
    """Stack equally shaped masks into (S, H, W); otherwise return a 1-D object array of masks."""
    try:
        return np.stack(sky_masks, axis=0)
    except ValueError:
        # np.array(list, dtype=object) itself raises when ragged masks share a
        # leading dimension (it tries to broadcast them), so fill element-wise.
        ragged = np.empty(len(sky_masks), dtype=object)
        for i, mask in enumerate(sky_masks):
            ragged[i] = mask
        return ragged


def load_or_create_sky_masks(
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
    target_shape: Optional[Tuple[int, int]] = None,
    num_frames: Optional[int] = None,
) -> Optional[np.ndarray]:
    """
    Load cached sky masks or generate them with the ONNX model.

    Masks are non-sky confidence maps: high values mark non-sky pixels and low
    values mark sky.

    Args:
        image_folder: Folder containing input images.
        image_paths: Optional explicit image file list, in the exact order to process.
        images: Optional image array with shape (S, 3, H, W) or (S, H, W, 3).
        skyseg_model_path: Path to the sky segmentation ONNX model (downloaded if missing).
        sky_mask_dir: Optional directory for cached raw masks. Defaults to
            "<image folder>_sky_masks" when an image folder is known.
        sky_mask_visualization_dir: Optional directory for side-by-side visualizations.
        target_shape: Optional output mask shape (H, W) after resizing.
        num_frames: Optional maximum number of frames to process.

    Returns:
        Float masks in [0, 1] with shape (S, H, W); a 1-D object array of
        per-frame masks if their shapes differ (only possible without
        ``target_shape``); or None if sky segmentation could not run. When an
        image file cannot be read and ``target_shape`` is given, that frame gets
        an all-ones (unmasked) mask so frames stay aligned; without
        ``target_shape`` the frame is skipped.
    """
    if onnxruntime is None:
        print("Warning: onnxruntime not available, skipping sky segmentation")
        return None

    if image_folder is None and image_paths is None and images is None:
        print("Warning: Neither image_folder/image_paths nor images provided, skipping sky segmentation")
        return None

    if not os.path.exists(skyseg_model_path):
        print(f"Sky segmentation model not found at {skyseg_model_path}, downloading...")
        try:
            download_skyseg_model(skyseg_model_path)
        except Exception as e:
            print(f"Warning: Failed to download sky segmentation model: {e}")
            return None

    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_masks: list = []

    if sky_mask_visualization_dir is not None:
        os.makedirs(sky_mask_visualization_dir, exist_ok=True)
        print(f"Saving sky mask visualizations to {sky_mask_visualization_dir}")

    if image_paths is None and image_folder is not None:
        image_paths = _list_image_files(image_folder)

    if images is not None:
        num_images = images.shape[0]
        if num_frames is not None:
            num_images = min(num_images, num_frames)
            if image_paths is not None:
                image_paths = image_paths[:num_images]

        if sky_mask_dir is None and image_folder is not None:
            sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
        refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)

        print("Generating sky masks from image array...")
        for i in tqdm(range(num_images)):
            image_rgb = _image_to_rgb_uint8(images[i])
            image_name = _get_mask_filename(image_paths, i)
            mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None

            sky_mask = _load_or_segment_array_frame(
                image_rgb, skyseg_session, mask_filepath, refresh_cache, image_name
            )

            if sky_mask_visualization_dir is not None:
                _save_sky_mask_visualization(
                    image_rgb,
                    sky_mask,
                    os.path.join(sky_mask_visualization_dir, image_name),
                )

            sky_masks.append(_fit_mask_to_target(sky_mask, target_shape))

    else:
        if image_paths is None or len(image_paths) == 0:
            print("Warning: No image files provided, skipping sky segmentation")
            return None

        # Derive the cache dir before slicing so num_frames <= 0 cannot make
        # image_paths[0] raise IndexError.
        if sky_mask_dir is None:
            if image_folder is None:
                image_folder = os.path.dirname(image_paths[0])
            sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"

        if num_frames is not None:
            image_paths = image_paths[:num_frames]
        refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)

        print("Generating sky masks from image files...")
        for image_path in tqdm(image_paths):
            image_name = os.path.basename(image_path)
            mask_filepath = os.path.join(sky_mask_dir, image_name)

            sky_mask = _load_or_segment_file_frame(
                image_path, skyseg_session, mask_filepath, refresh_cache
            )

            if sky_mask is None:
                if target_shape is None:
                    print(f"Warning: Failed to produce sky mask for {image_path}, skipping frame")
                    continue
                # Keep one mask per frame so later frames stay aligned with
                # their confidence maps; all-ones means "no sky removed".
                print(f"Warning: Failed to produce sky mask for {image_path}, leaving the frame unmasked")
                sky_masks.append(np.ones(tuple(target_shape), dtype=np.float32))
                continue

            if sky_mask_visualization_dir is not None:
                image_bgr = cv2.imread(image_path)
                if image_bgr is not None:
                    _save_sky_mask_visualization(
                        cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB),
                        sky_mask,
                        os.path.join(sky_mask_visualization_dir, image_name),
                    )

            sky_masks.append(_fit_mask_to_target(sky_mask, target_shape))

    if len(sky_masks) == 0:
        print("Warning: No sky masks generated, skipping sky segmentation")
        return None

    return _stack_masks(sky_masks)
|
|
|
|
|
|
def apply_sky_segmentation(
    conf: np.ndarray,
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.

    Masks from ``load_or_create_sky_masks`` are non-sky confidence maps; pixels
    whose value is at or below ``_SKYSEG_SOFT_THRESHOLD`` are treated as sky
    and their confidence is set to zero.

    Args:
        conf: Confidence scores with shape (S, H, W)
        image_folder: Path to the folder containing input images (optional if images provided)
        image_paths: Optional explicit image file list in processing order
        images: Image array with shape (S, 3, H, W) or (S, H, W, 3) (optional if image_folder provided)
        skyseg_model_path: Path to the sky segmentation ONNX model
        sky_mask_dir: Optional directory for cached raw masks
        sky_mask_visualization_dir: Optional directory for side-by-side mask visualization images

    Returns:
        Updated confidence scores with sky regions masked out. ``conf`` is
        returned unchanged if no masks could be produced; frames without a
        mask are left unmasked.
    """
    S, H, W = conf.shape

    sky_mask_array = load_or_create_sky_masks(
        image_folder=image_folder,
        image_paths=image_paths,
        images=images,
        skyseg_model_path=skyseg_model_path,
        sky_mask_dir=sky_mask_dir,
        sky_mask_visualization_dir=sky_mask_visualization_dir,
        target_shape=(H, W),
        num_frames=S,
    )
    if sky_mask_array is None:
        return conf

    num_masks = sky_mask_array.shape[0]
    if num_masks < S:
        print(
            f"Warning: Only {num_masks} sky masks generated for {S} frames; "
            "leaving the remaining frames unmasked"
        )
        # Masks hold non-sky confidence (high = keep), so the padding must be
        # ones; zeros would classify every pixel of those frames as sky.
        padded = np.ones((S, H, W), dtype=sky_mask_array.dtype)
        padded[:num_masks] = sky_mask_array
        sky_mask_array = padded
    elif num_masks > S:
        sky_mask_array = sky_mask_array[:S]

    # Keep pixels whose non-sky confidence exceeds the threshold.
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    conf = conf * sky_mask_binary

    print("Sky segmentation applied successfully")
    return conf
|
|
|
|
|
|
def download_skyseg_model(output_path: str = "skyseg.onnx") -> str:
    """
    Download sky segmentation model from HuggingFace.

    The file is streamed into "<output_path>.part" and only moved to
    ``output_path`` once the download completes. An interrupted or failed
    download therefore never leaves a truncated model that later runs would
    treat as present.

    Args:
        output_path: Path to save the model (parent directories are created)

    Returns:
        Path to the downloaded model

    Raises:
        requests.HTTPError / requests.RequestException: On network or HTTP failure.
    """
    import requests

    url = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"

    print(f"Downloading sky segmentation model from {url}...")
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    tmp_path = output_path + ".part"
    try:
        # The timeout bounds connecting and each read so a stalled server
        # cannot hang the pipeline forever; the context manager closes the
        # streamed connection.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(tmp_path, 'wb') as f:
                with tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") as pbar:
                    for chunk in response.iter_content(chunk_size=8192):
                        if not chunk:
                            continue
                        f.write(chunk)
                        pbar.update(len(chunk))

        os.replace(tmp_path, output_path)
    finally:
        # Only present here if something failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"Model saved to {output_path}")
    return output_path
|