remove sky mask

This commit is contained in:
LinZhuoChen
2026-04-16 18:53:54 +08:00
parent c7e49e1cbe
commit 1317fbb7b3
648 changed files with 30 additions and 54 deletions

View File

@@ -32,27 +32,16 @@ def _get_cache_version_path(sky_mask_dir: str) -> str:
return os.path.join(sky_mask_dir, ".skyseg_cache_version")
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool:
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> None:
"""Ensure the sky mask cache directory exists and write the version stamp."""
if sky_mask_dir is None:
return False
return
os.makedirs(sky_mask_dir, exist_ok=True)
version_path = _get_cache_version_path(sky_mask_dir)
refresh_cache = True
if os.path.exists(version_path):
with open(version_path, "r", encoding="utf-8") as f:
refresh_cache = f.read().strip() != _SKYSEG_CACHE_VERSION
if refresh_cache:
print(
f"Sky mask cache at {sky_mask_dir} uses an older format; "
"regenerating masks with ImageNet-normalized skyseg input"
)
if not os.path.exists(version_path):
with open(version_path, "w", encoding="utf-8") as f:
f.write(_SKYSEG_CACHE_VERSION)
return refresh_cache
def run_skyseg(
onnx_session,
@@ -279,7 +268,7 @@ def load_or_create_sky_masks(
if sky_mask_dir is None and image_folder is not None:
sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
_prepare_sky_mask_cache(sky_mask_dir)
print("Generating sky masks from image array...")
for i in tqdm(range(num_images)):
@@ -288,17 +277,12 @@ def load_or_create_sky_masks(
image_name = _get_mask_filename(image_paths, i)
mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None
if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath):
if mask_filepath is not None and os.path.exists(mask_filepath):
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
if sky_mask is None:
print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
elif sky_mask.shape[:2] != (image_h, image_w):
print(
f"Cached sky mask shape {sky_mask.shape[:2]} does not match resized image "
f"shape {(image_h, image_w)} for {image_name}; regenerating it"
)
if sky_mask is not None and sky_mask.shape[:2] == (image_h, image_w):
# Reuse cached mask
pass
else:
sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
else:
@@ -338,14 +322,14 @@ def load_or_create_sky_masks(
if image_folder is None:
image_folder = os.path.dirname(image_paths[0])
sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
_prepare_sky_mask_cache(sky_mask_dir)
print("Generating sky masks from image files...")
for image_path in tqdm(image_paths):
image_name = os.path.basename(image_path)
mask_filepath = os.path.join(sky_mask_dir, image_name)
if not refresh_cache and os.path.exists(mask_filepath):
if os.path.exists(mask_filepath):
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
if sky_mask is None:
print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")