# lingbot-map/lingbot_map/utils/load_fn.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
from PIL import Image
from torchvision import transforms as TF
from tqdm.auto import tqdm

def load_and_preprocess_images_square(image_path_list, target_size=1024):
    """
    Load and preprocess images by center padding to square and resizing to target size.
    Also returns the position information of original pixels after transformation.

    Args:
        image_path_list (list): List of paths to image files
        target_size (int, optional): Target size for both width and height. Defaults to 1024.

    Returns:
        tuple: (
            torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
            torch.Tensor: Tensor of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image
        )

    Raises:
        ValueError: If the input list is empty
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    images = []
    original_coords = []  # One [x1, y1, x2, y2, width, height] entry per image
    to_tensor = TF.ToTensor()

    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)

        # Convert to RGB
        img = img.convert("RGB")

        # Get original dimensions
        width, height = img.size

        # Make the image square by padding the shorter dimension
        max_dim = max(width, height)

        # Calculate padding
        left = (max_dim - width) // 2
        top = (max_dim - height) // 2

        # Calculate scale factor for resizing
        scale = target_size / max_dim

        # Calculate final coordinates of original image in target space
        x1 = left * scale
        y1 = top * scale
        x2 = (left + width) * scale
        y2 = (top + height) * scale

        # Store original image coordinates and size
        original_coords.append(np.array([x1, y1, x2, y2, width, height]))

        # Create a new black square image and paste the original into its center
        square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
        square_img.paste(img, (left, top))

        # Resize to target size
        square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)

        # Convert to tensor
        img_tensor = to_tensor(square_img)
        images.append(img_tensor)

    # Stack all images
    images = torch.stack(images)
    original_coords = torch.from_numpy(np.array(original_coords)).float()

    # Add a batch dimension if a single image came through as (3, H, W)
    if len(image_path_list) == 1:
        if images.dim() == 3:
            images = images.unsqueeze(0)
            original_coords = original_coords.unsqueeze(0)

    return images, original_coords
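

# Usage sketch (hedged: the file names below are hypothetical placeholders).
# `original_coords` lets you map predictions made in the padded square back to
# the original pixel grid:
#
#   images, coords = load_and_preprocess_images_square(["a.jpg", "b.jpg"])
#   # images: (2, 3, 1024, 1024); coords: (2, 6) rows of [x1, y1, x2, y2, width, height]
#   x1, y1, x2, y2, w, h = coords[0].tolist()
#   orig_x = (square_x - x1) * w / (x2 - x1)  # square_x: an x coordinate in 1024-space
#   orig_y = (square_y - y1) * h / (y2 - y1)
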
def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16):
    """
    A quick-start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files
        fx, fy, cx, cy (list, optional): Per-image camera intrinsics, normalized by the
            original image width/height. When provided, they are rescaled in place to
            pixel units in the preprocessed images and returned alongside the image tensor.
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to image_size pixels and center-crops the height if needed.
            - "pad": Preserves all pixels by making the largest dimension image_size pixels
              and padding the smaller dimension to reach a square shape.
        image_size (int, optional): Target size in pixels. Defaults to 512.
        patch_size (int, optional): Resized dimensions are rounded to multiples of this value. Defaults to 16.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W),
        or the tuple (images, fx, fy, cx, cy) when intrinsics are provided.

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": The function ensures width=image_size while maintaining aspect ratio,
          and the height is center-cropped if larger than image_size
        - When mode="pad": The function ensures the largest dimension is image_size while maintaining
          aspect ratio, and the smaller dimension is padded to reach a square (image_size x image_size)
        - Dimensions are adjusted to be divisible by patch_size for compatibility with model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    target_size = image_size
    to_tensor = TF.ToTensor()

    def _load_one(idx_path):
        i, image_path = idx_path
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)
        img = img.convert("RGB")
        width, height = img.size

        # Convert normalized intrinsics to pixel units in the original image
        fx_val = fy_val = cx_val = cy_val = None
        if fx is not None:
            fx_val = fx[i] * width
            fy_val = fy[i] * height
            cx_val = cx[i] * width
            cy_val = cy[i] * height

        # Choose new dimensions, rounded to multiples of patch_size
        if mode == "pad":
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / patch_size) * patch_size
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / patch_size) * patch_size
        else:  # crop
            new_width = target_size
            new_height = round(height * (new_width / width) / patch_size) * patch_size

        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)

        # Center-crop the height if it exceeds the target size
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        # Rescale intrinsics to the resized image and reset the principal
        # point to the current image center
        if fx is not None:
            fx_val = fx_val * new_width / width
            fy_val = fy_val * new_height / height
            cx_val = img.shape[2] / 2
            cy_val = img.shape[1] / 2

        # Pad with white (value=1.0) to a square of target_size
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
        return i, img, (fx_val, fy_val, cx_val, cy_val)

    # Parallel load with progress bar
    num_workers = min(16, len(image_path_list))
    results = [None] * len(image_path_list)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = pool.map(_load_one, enumerate(image_path_list))
        for i, img, calib in tqdm(futures, total=len(image_path_list), desc="Loading images"):
            results[i] = img
            if fx is not None:
                fx[i], fy[i], cx[i], cy[i] = calib
    images = results

    shapes = set((img.shape[1], img.shape[2]) for img in images)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")

        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # concatenate images

    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    if fx is not None:
        return images, fx, fy, cx, cy
    return images
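

# Usage sketch (hedged: the paths and normalized intrinsics below are made-up
# placeholders for illustration, not shipped with this module).
if __name__ == "__main__":
    paths = ["example_0.png", "example_1.png"]  # hypothetical image files
    fx = [0.9, 0.9]  # focal lengths as fractions of original width/height
    fy = [1.2, 1.2]
    cx = [0.5, 0.5]  # principal point at the image center (normalized)
    cy = [0.5, 0.5]
    images, fx, fy, cx, cy = load_and_preprocess_images(
        paths, fx=fx, fy=fy, cx=cx, cy=cy, mode="pad", image_size=512, patch_size=16
    )
    print(images.shape)  # expected: (2, 3, 512, 512) in "pad" mode
    print(fx[0], fy[0], cx[0], cy[0])  # intrinsics rescaled to pixel units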