244 lines
8.9 KiB
Python
244 lines
8.9 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
import torch
|
|
from PIL import Image
|
|
from torchvision import transforms as TF
|
|
from tqdm.auto import tqdm
|
|
import numpy as np
|
|
|
|
|
|
def load_and_preprocess_images_square(image_path_list, target_size=1024):
    """
    Load images, center-pad each one to a square (black fill), and resize to ``target_size``.

    Also returns where the original (unpadded) pixels land after the transformation.

    Args:
        image_path_list (list): List of paths to image files.
        target_size (int, optional): Output width and height in pixels. Defaults to 1024.

    Returns:
        tuple: (
            torch.Tensor: Batched tensor of preprocessed images with shape
                (N, 3, target_size, target_size), as produced by ``ToTensor``.
            torch.Tensor: float32 tensor of shape (N, 6) holding
                [x1, y1, x2, y2, width, height] for each image. (x1, y1)-(x2, y2) is the box
                covered by the original pixels inside the resized square, and (width, height)
                is the original image size.
        )

    Raises:
        ValueError: If the input list is empty.
    """
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    images = []
    original_coords = []
    to_tensor = TF.ToTensor()

    for image_path in image_path_list:
        # Context manager so the underlying file handle is released promptly.
        with Image.open(image_path) as src:
            # Any transparency (RGBA/LA/PA bands, or a palette/L/RGB transparency key) is
            # composited onto white; a plain convert("RGB") would drop the alpha and expose
            # whatever colour sits under transparent pixels.
            if "A" in src.getbands() or "transparency" in src.info:
                rgba = src.convert("RGBA")
                background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
                img = Image.alpha_composite(background, rgba).convert("RGB")
            else:
                img = src.convert("RGB")

        width, height = img.size

        # Pad the shorter side so the image becomes a max_dim x max_dim square.
        max_dim = max(width, height)
        left = (max_dim - width) // 2
        top = (max_dim - height) // 2

        # Uniform scale from the padded square to the target square.
        scale = target_size / max_dim

        # Box of the original pixels in target space, plus the original size.
        original_coords.append(
            [left * scale, top * scale, (left + width) * scale, (top + height) * scale, width, height]
        )

        square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
        square_img.paste(img, (left, top))
        square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)
        images.append(to_tensor(square_img))

    # torch.stack always yields (N, 3, H, W), even for a single image.
    images = torch.stack(images)
    original_coords = torch.from_numpy(np.array(original_coords, dtype=np.float64)).float()

    return images, original_coords
|
|
|
|
|
|
def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16):
    """
    A quick start function to load and preprocess images for model input.

    This assumes the images should have the same shape for easier batching, but our model can also
    work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files.
        fx, fy, cx, cy (indexable sequences, optional): Per-image intrinsics expressed as fractions
            of the ORIGINAL image size. fx and cx are multiplied by the original width, and fy and cy
            by the original height. Pass all four to enable intrinsics handling; they are ignored when
            ``fx`` is None. Each must have at least ``len(image_path_list)`` entries.
            When given, they are overwritten IN PLACE with the pixel-space intrinsics of the returned
            images (and also returned), so they must support item assignment (e.g. list / ndarray).
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Resizes to width ``image_size`` and center-crops the height if it
              exceeds ``image_size``.
            - "pad": Resizes so the largest dimension is ``image_size``, then pads the smaller
              dimension with white to reach a square shape.
        image_size (int, optional): Target width ("crop") or largest dimension ("pad"). Defaults to 512.
        patch_size (int, optional): The other resized dimension is rounded to a multiple of this
            (at least one patch). Defaults to 16.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W).
        When ``fx`` is given, returns the tuple ``(images, fx, fy, cx, cy)`` instead, with the
        updated pixel-space intrinsics.

    Raises:
        ValueError: If the input list is empty, if mode is invalid, or if fx is given while any of
            fy/cx/cy is missing or shorter than the image list.

    Notes:
        - Images with different final dimensions are center-padded with white (value=1.0) to a common
          size, and a warning is printed.
        - The supplied principal point is carried through the resize, crop and padding steps.
    """
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    num_images = len(image_path_list)
    use_intrinsics = fx is not None
    if use_intrinsics:
        # Fail fast here instead of with an opaque error inside a worker thread.
        for name, values in (("fx", fx), ("fy", fy), ("cx", cx), ("cy", cy)):
            if values is None:
                raise ValueError(f"{name} must be provided together with fx")
            if len(values) < num_images:
                raise ValueError(f"{name} has {len(values)} entries but {num_images} images were given")

    target_size = image_size
    to_tensor = TF.ToTensor()

    def _center_pad(img, h_padding, w_padding):
        """Center-pad a (C, H, W) tensor with white; return it and the (left, top) offsets added."""
        pad_top = h_padding // 2
        pad_bottom = h_padding - pad_top
        pad_left = w_padding // 2
        pad_right = w_padding - pad_left
        img = torch.nn.functional.pad(
            img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
        )
        return img, pad_left, pad_top

    def _load_one(idx_path):
        """Load, resize and crop/pad one image; return (index, tensor, pixel intrinsics or None)."""
        i, image_path = idx_path
        # Context manager so the underlying file handle is released promptly.
        with Image.open(image_path) as src:
            # Composite any transparency onto white instead of silently dropping the alpha.
            if "A" in src.getbands() or "transparency" in src.info:
                rgba = src.convert("RGBA")
                background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
                img = Image.alpha_composite(background, rgba).convert("RGB")
            else:
                img = src.convert("RGB")

        width, height = img.size

        # "crop" always fixes the width; "pad" fixes whichever side is longer.
        # max(patch_size, ...) keeps extreme aspect ratios from rounding a side to 0.
        if mode == "pad" and height > width:
            new_height = target_size
            new_width = max(patch_size, round(width * (new_height / height) / patch_size) * patch_size)
        else:
            new_width = target_size
            new_height = max(patch_size, round(height * (new_width / width) / patch_size) * patch_size)

        img = to_tensor(img.resize((new_width, new_height), Image.Resampling.BICUBIC))

        # Translation (in output pixels) applied by the crop/pad step; shifts the principal point.
        shift_x = shift_y = 0
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]
            shift_y = -start_y
        elif mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                img, shift_x, shift_y = _center_pad(img, h_padding, w_padding)

        calib = None
        if use_intrinsics:
            # Normalized -> original pixels -> resized pixels, then the crop/pad translation.
            calib = [
                fx[i] * width * new_width / width,
                fy[i] * height * new_height / height,
                cx[i] * width * new_width / width + shift_x,
                cy[i] * height * new_height / height + shift_y,
            ]
        return i, img, calib

    # Parallel load with progress bar (decoding/resizing is largely I/O / native code).
    num_workers = min(16, num_images)
    images = [None] * num_images
    calibs = [None] * num_images
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        results = pool.map(_load_one, enumerate(image_path_list))
        for i, img, calib in tqdm(results, total=num_images, desc="Loading images"):
            images[i] = img
            calibs[i] = calib

    shapes = {(img.shape[1], img.shape[2]) for img in images}

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        for i in range(num_images):
            h_padding = max_height - images[i].shape[1]
            w_padding = max_width - images[i].shape[2]
            if h_padding > 0 or w_padding > 0:
                images[i], pad_left, pad_top = _center_pad(images[i], h_padding, w_padding)
                # The padding moves the image content, so the principal point moves with it.
                if calibs[i] is not None:
                    calibs[i][2] += pad_left
                    calibs[i][3] += pad_top

    # torch.stack always yields (N, 3, H, W), even for a single image.
    images = torch.stack(images)

    if use_intrinsics:
        # Write back only after every image loaded, so a failure leaves the caller's arrays intact.
        for i, (fx_i, fy_i, cx_i, cy_i) in enumerate(calibs):
            fx[i], fy[i], cx[i], cy[i] = fx_i, fy_i, cx_i, cy_i
        return images, fx, fy, cx, cy
    return images
|