remove sky mask
This commit is contained in:
@@ -32,27 +32,16 @@ def _get_cache_version_path(sky_mask_dir: str) -> str:
|
||||
return os.path.join(sky_mask_dir, ".skyseg_cache_version")
|
||||
|
||||
|
||||
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool:
|
||||
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> None:
|
||||
"""Ensure the sky mask cache directory exists and write the version stamp."""
|
||||
if sky_mask_dir is None:
|
||||
return False
|
||||
|
||||
return
|
||||
os.makedirs(sky_mask_dir, exist_ok=True)
|
||||
version_path = _get_cache_version_path(sky_mask_dir)
|
||||
refresh_cache = True
|
||||
if os.path.exists(version_path):
|
||||
with open(version_path, "r", encoding="utf-8") as f:
|
||||
refresh_cache = f.read().strip() != _SKYSEG_CACHE_VERSION
|
||||
|
||||
if refresh_cache:
|
||||
print(
|
||||
f"Sky mask cache at {sky_mask_dir} uses an older format; "
|
||||
"regenerating masks with ImageNet-normalized skyseg input"
|
||||
)
|
||||
if not os.path.exists(version_path):
|
||||
with open(version_path, "w", encoding="utf-8") as f:
|
||||
f.write(_SKYSEG_CACHE_VERSION)
|
||||
|
||||
return refresh_cache
|
||||
|
||||
|
||||
def run_skyseg(
|
||||
onnx_session,
|
||||
@@ -279,7 +268,7 @@ def load_or_create_sky_masks(
|
||||
|
||||
if sky_mask_dir is None and image_folder is not None:
|
||||
sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
|
||||
refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
|
||||
_prepare_sky_mask_cache(sky_mask_dir)
|
||||
|
||||
print("Generating sky masks from image array...")
|
||||
for i in tqdm(range(num_images)):
|
||||
@@ -288,17 +277,12 @@ def load_or_create_sky_masks(
|
||||
image_name = _get_mask_filename(image_paths, i)
|
||||
mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None
|
||||
|
||||
if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath):
|
||||
if mask_filepath is not None and os.path.exists(mask_filepath):
|
||||
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
|
||||
if sky_mask is None:
|
||||
print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
|
||||
sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
|
||||
cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
|
||||
elif sky_mask.shape[:2] != (image_h, image_w):
|
||||
print(
|
||||
f"Cached sky mask shape {sky_mask.shape[:2]} does not match resized image "
|
||||
f"shape {(image_h, image_w)} for {image_name}; regenerating it"
|
||||
)
|
||||
if sky_mask is not None and sky_mask.shape[:2] == (image_h, image_w):
|
||||
# Reuse cached mask
|
||||
pass
|
||||
else:
|
||||
sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
|
||||
cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
|
||||
else:
|
||||
@@ -338,14 +322,14 @@ def load_or_create_sky_masks(
|
||||
if image_folder is None:
|
||||
image_folder = os.path.dirname(image_paths[0])
|
||||
sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
|
||||
refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
|
||||
_prepare_sky_mask_cache(sky_mask_dir)
|
||||
|
||||
print("Generating sky masks from image files...")
|
||||
for image_path in tqdm(image_paths):
|
||||
image_name = os.path.basename(image_path)
|
||||
mask_filepath = os.path.join(sky_mask_dir, image_name)
|
||||
|
||||
if not refresh_cache and os.path.exists(mask_filepath):
|
||||
if os.path.exists(mask_filepath):
|
||||
sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
|
||||
if sky_mask is None:
|
||||
print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
|
||||
|
||||
Reference in New Issue
Block a user