# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
from PIL import Image
from torchvision import transforms as TF
from tqdm.auto import tqdm


def load_and_preprocess_images_square(image_path_list, target_size=1024):
    """
    Load and preprocess images by center padding to square and resizing to target size.
    Also returns the position information of original pixels after transformation.

    Args:
        image_path_list (list): List of paths to image files
        target_size (int, optional): Target size for both width and height. Defaults to 1024.

    Returns:
        tuple: (
            torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
            torch.Tensor: Tensor of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image
        )

    Raises:
        ValueError: If the input list is empty
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    images = []
    original_coords = []  # Coordinates of each original image within the resized square canvas
    to_tensor = TF.ToTensor()

    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)

        # Convert to RGB
        img = img.convert("RGB")

        # Get original dimensions
        width, height = img.size

        # Make the image square by padding the shorter dimension
        max_dim = max(width, height)

        # Calculate padding
        left = (max_dim - width) // 2
        top = (max_dim - height) // 2

        # Calculate scale factor for resizing
        scale = target_size / max_dim

        # Calculate final coordinates of original image in target space
        x1 = left * scale
        y1 = top * scale
        x2 = (left + width) * scale
        y2 = (top + height) * scale

        # Store original image coordinates and dimensions
        original_coords.append(np.array([x1, y1, x2, y2, width, height]))

        # Create a new black square image and paste original
        square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
        square_img.paste(img, (left, top))

        # Resize to target size
        square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)

        # Convert to tensor
        img_tensor = to_tensor(square_img)
        images.append(img_tensor)

    # Stack all images
    images = torch.stack(images)
    original_coords = torch.from_numpy(np.array(original_coords)).float()

    # Add a batch dimension if a single image was given, to ensure shape (1, 3, H, W)
    if len(image_path_list) == 1:
        if images.dim() == 3:
            images = images.unsqueeze(0)
            original_coords = original_coords.unsqueeze(0)

    return images, original_coords
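
# A minimal usage sketch (an illustration, not part of the original module). It
# assumes two hypothetical image files, "frame_000.png" and "frame_001.png",
# exist on disk, and shows how the returned `original_coords` can be used to map
# a point in the resized square back to the original pixel grid.
def _demo_square_preprocessing():
    paths = ["frame_000.png", "frame_001.png"]  # hypothetical paths
    images, coords = load_and_preprocess_images_square(paths, target_size=1024)
    print(images.shape)  # torch.Size([2, 3, 1024, 1024])

    # Map a point (u, v) in the 1024x1024 square of the first view back to the
    # original image: subtract the padding offset, then undo the resize scale.
    x1, y1, x2, y2, width, height = coords[0].tolist()
    u, v = 512.0, 512.0  # center of the square image
    inv_scale = width / (x2 - x1)  # inverse of target_size / max_dim
    orig_u = (u - x1) * inv_scale
    orig_v = (v - y1) * inv_scale
    print(orig_u, orig_v)
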
def load_and_preprocess_images(
    image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16
):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching,
    but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files
        fx (list, optional): Per-image focal lengths along x, normalized by image width
        fy (list, optional): Per-image focal lengths along y, normalized by image height
        cx (list, optional): Per-image principal point x, normalized by image width
        cy (list, optional): Per-image principal point y, normalized by image height
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to image_size pixels and center crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension image_size pixels
              and padding the smaller dimension to reach a square shape.
        image_size (int, optional): Target size in pixels. Defaults to 512.
        patch_size (int, optional): Output dimensions are rounded to multiples of this value. Defaults to 16.

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W).
            If fx is provided, the intrinsics lists (rescaled to pixel units of the
            preprocessed images) are returned as well.

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": The function ensures width equals image_size while maintaining
          aspect ratio, and height is center-cropped if larger than image_size
        - When mode="pad": The function ensures the largest dimension equals image_size while
          maintaining aspect ratio, and the smaller dimension is padded to reach a square
          shape (image_size x image_size)
        - Dimensions are adjusted to be divisible by patch_size for compatibility with model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    target_size = image_size
    to_tensor = TF.ToTensor()

    def _load_one(idx_path):
        i, image_path = idx_path
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)
        img = img.convert("RGB")

        width, height = img.size

        # Convert normalized intrinsics to pixel units at the original resolution
        fx_val = fy_val = cx_val = cy_val = None
        if fx is not None:
            fx_val = fx[i] * width
            fy_val = fy[i] * height
            cx_val = cx[i] * width
            cy_val = cy[i] * height

        if mode == "pad":
            # Scale the longer side to target_size; round the shorter side to a multiple of patch_size
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / patch_size) * patch_size
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / patch_size) * patch_size
        else:  # mode == "crop"
            # Scale the width to target_size; round the height to a multiple of patch_size
            new_width = target_size
            new_height = round(height * (new_width / width) / patch_size) * patch_size

        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)

        # Center crop the height if it is larger than the target size
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]

        # Rescale the intrinsics to the resized image; the principal point is
        # reset to the image center
        if fx is not None:
            fx_val = fx_val * new_width / width
            fy_val = fy_val * new_height / height
            cx_val = img.shape[2] / 2
            cy_val = img.shape[1] / 2

        if mode == "pad":
            # Pad with white pixels to a target_size x target_size square
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )

        return i, img, (fx_val, fy_val, cx_val, cy_val)

    # Load images in parallel with a progress bar
    num_workers = min(16, len(image_path_list))
    results = [None] * len(image_path_list)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futures = pool.map(_load_one, enumerate(image_path_list))
        for i, img, calib in tqdm(futures, total=len(image_path_list), desc="Loading images"):
            results[i] = img
            if fx is not None:
                # Update the intrinsics lists in place with pixel-unit values
                fx[i], fy[i], cx[i], cy[i] = calib

    images = results
    shapes = set((img.shape[1], img.shape[2]) for img in images)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")

        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # concatenate images

    # Ensure correct shape when a single image was given
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    if fx is not None:
        return images, fx, fy, cx, cy
    return images
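
# A minimal usage sketch (an illustration, not part of the original module),
# assuming three hypothetical image files exist on disk. It shows the intrinsics
# convention used above: fx/cx are normalized by image width and fy/cy by image
# height, and the lists are updated in place with pixel-unit values for the
# preprocessed images.
if __name__ == "__main__":
    paths = ["view_0.jpg", "view_1.jpg", "view_2.jpg"]  # hypothetical paths

    # Without intrinsics, a single batched tensor is returned.
    images = load_and_preprocess_images(paths, mode="crop", image_size=512)
    print(images.shape)  # (3, 3, H, 512), H a multiple of patch_size and at most 512

    # With normalized intrinsics, one entry per image.
    fx = [1.2, 1.2, 1.2]  # focal_x / width
    fy = [1.6, 1.6, 1.6]  # focal_y / height
    cx = [0.5, 0.5, 0.5]  # principal_x / width
    cy = [0.5, 0.5, 0.5]  # principal_y / height
    images, fx, fy, cx, cy = load_and_preprocess_images(
        paths, fx=fx, fy=fy, cx=cx, cy=cy, mode="pad", image_size=512
    )
    print(images.shape)  # (3, 3, 512, 512) in pad mode
    print(fx[0], fy[0], cx[0], cy[0])  # pixel-unit intrinsics for the first view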