add window inference

Author: LinZhuoChen
Date: 2026-04-16 14:07:07 +08:00
parent 42de2badd2
commit 843d9ec31d

2 changed files with 1236 additions and 37 deletions

File diff suppressed because it is too large

@@ -4,9 +4,12 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+from concurrent.futures import ThreadPoolExecutor
 import torch
 from PIL import Image
 from torchvision import transforms as TF
+from tqdm.auto import tqdm
 import numpy as np
@@ -131,82 +134,76 @@ def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=No
     if mode not in ["crop", "pad"]:
         raise ValueError("Mode must be either 'crop' or 'pad'")
 
-    images = []
-    shapes = set()
-    to_tensor = TF.ToTensor()
     target_size = image_size
+    to_tensor = TF.ToTensor()
 
-    # First process all images and collect their shapes
-    for i, image_path in enumerate(image_path_list):
-        # Open image
+    def _load_one(idx_path):
+        i, image_path = idx_path
         img = Image.open(image_path)
 
-        # If there's an alpha channel, blend onto white background:
         if img.mode == "RGBA":
-            # Create white background
             background = Image.new("RGBA", img.size, (255, 255, 255, 255))
-            # Alpha composite onto the white background
             img = Image.alpha_composite(background, img)
 
-        # Now convert to "RGB" (this step assigns white for transparent areas)
         img = img.convert("RGB")
         width, height = img.size
 
+        fx_val = fy_val = cx_val = cy_val = None
         if fx is not None:
-            fx[i] = fx[i] * width
-            fy[i] = fy[i] * height
-            cx[i] = cx[i] * width
-            cy[i] = cy[i] * height
+            fx_val = fx[i] * width
+            fy_val = fy[i] * height
+            cx_val = cx[i] * width
+            cy_val = cy[i] * height
 
         if mode == "pad":
-            # Make the largest dimension 518px while maintaining aspect ratio
             if width >= height:
                 new_width = target_size
-                new_height = round(height * (new_width / width) / patch_size) * patch_size  # Make divisible by 14
+                new_height = round(height * (new_width / width) / patch_size) * patch_size
             else:
                 new_height = target_size
-                new_width = round(width * (new_height / height) / patch_size) * patch_size  # Make divisible by 14
-        else:  # mode == "crop"
-            # Original behavior: set width to 518px
+                new_width = round(width * (new_height / height) / patch_size) * patch_size
+        else:  # crop
             new_width = target_size
-            # Calculate height maintaining aspect ratio, divisible by 14
             new_height = round(height * (new_width / width) / patch_size) * patch_size
 
-        # Resize with new dimensions (width, height)
         img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
-        img = to_tensor(img)  # Convert to tensor (0, 1)
+        img = to_tensor(img)
 
-        # Center crop height if it's larger than 518 (only in crop mode)
         if mode == "crop" and new_height > target_size:
             start_y = (new_height - target_size) // 2
             img = img[:, start_y : start_y + target_size, :]
             if fx is not None:
-                fx[i] = fx[i] * new_width / width
-                fy[i] = fy[i] * new_height / height
-                cx[i] = img.shape[2] / 2
-                cy[i] = img.shape[1] / 2
+                fx_val = fx_val * new_width / width
+                fy_val = fy_val * new_height / height
+                cx_val = img.shape[2] / 2
+                cy_val = img.shape[1] / 2
 
-        # For pad mode, pad to make a square of target_size x target_size
         if mode == "pad":
             h_padding = target_size - img.shape[1]
             w_padding = target_size - img.shape[2]
             if h_padding > 0 or w_padding > 0:
                 pad_top = h_padding // 2
                 pad_bottom = h_padding - pad_top
                 pad_left = w_padding // 2
                 pad_right = w_padding - pad_left
-                # Pad with white (value=1.0)
                 img = torch.nn.functional.pad(
                     img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                 )
 
-        shapes.add((img.shape[1], img.shape[2]))
-        images.append(img)
+        return i, img, (fx_val, fy_val, cx_val, cy_val)
+
+    # Parallel load with progress bar
+    num_workers = min(16, len(image_path_list))
+    results = [None] * len(image_path_list)
+    with ThreadPoolExecutor(max_workers=num_workers) as pool:
+        futures = pool.map(_load_one, enumerate(image_path_list))
+        for i, img, calib in tqdm(futures, total=len(image_path_list), desc="Loading images"):
+            results[i] = img
+            if fx is not None:
+                fx[i], fy[i], cx[i], cy[i] = calib
+
+    images = results
+    shapes = set((img.shape[1], img.shape[2]) for img in images)
 
     # Check if we have different shapes
     # In theory our model can also work well with different shapes
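
For context, a minimal sketch of how the parallelized loader might be driven. The paths and intrinsics values below are placeholders, and passing mode as a keyword is an assumption (the hunk header truncates the signature); what is grounded in the diff is that fx/fy/cx/cy arrive normalized by image width/height and are overwritten in place with pixel-unit values when each worker's calib tuple is unpacked.

    # Hypothetical usage sketch -- paths, values, and the `mode` keyword
    # are illustrative assumptions, not part of this commit.
    import numpy as np

    image_paths = [f"frames/{i:04d}.png" for i in range(64)]

    # Normalized intrinsics (focal length and principal point divided by
    # image width/height); the loader rescales them to pixels in place.
    fx = np.full(len(image_paths), 0.9)
    fy = np.full(len(image_paths), 1.2)
    cx = np.full(len(image_paths), 0.5)
    cy = np.full(len(image_paths), 0.5)

    images = load_and_preprocess_images(
        image_paths, fx=fx, fy=fy, cx=cx, cy=cy, mode="crop"
    )

One design note: ThreadPoolExecutor.map already yields results in submission order, so writing into results[i] by index is defensive rather than required, and threads (rather than processes) are a reasonable fit here since Pillow's decode and resize largely release the GIL.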