first commit
This commit is contained in:
246
lingbot_map/utils/load_fn.py
Normal file
246
lingbot_map/utils/load_fn.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms as TF
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_and_preprocess_images_square(image_path_list, target_size=1024):
|
||||
"""
|
||||
Load and preprocess images by center padding to square and resizing to target size.
|
||||
Also returns the position information of original pixels after transformation.
|
||||
|
||||
Args:
|
||||
image_path_list (list): List of paths to image files
|
||||
target_size (int, optional): Target size for both width and height. Defaults to 518.
|
||||
|
||||
Returns:
|
||||
tuple: (
|
||||
torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
|
||||
torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image
|
||||
)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input list is empty
|
||||
"""
|
||||
# Check for empty list
|
||||
if len(image_path_list) == 0:
|
||||
raise ValueError("At least 1 image is required")
|
||||
|
||||
images = []
|
||||
original_coords = [] # Renamed from position_info to be more descriptive
|
||||
to_tensor = TF.ToTensor()
|
||||
|
||||
for image_path in image_path_list:
|
||||
# Open image
|
||||
img = Image.open(image_path)
|
||||
|
||||
# If there's an alpha channel, blend onto white background
|
||||
if img.mode == "RGBA":
|
||||
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
|
||||
img = Image.alpha_composite(background, img)
|
||||
|
||||
# Convert to RGB
|
||||
img = img.convert("RGB")
|
||||
|
||||
# Get original dimensions
|
||||
width, height = img.size
|
||||
|
||||
# Make the image square by padding the shorter dimension
|
||||
max_dim = max(width, height)
|
||||
|
||||
# Calculate padding
|
||||
left = (max_dim - width) // 2
|
||||
top = (max_dim - height) // 2
|
||||
|
||||
# Calculate scale factor for resizing
|
||||
scale = target_size / max_dim
|
||||
|
||||
# Calculate final coordinates of original image in target space
|
||||
x1 = left * scale
|
||||
y1 = top * scale
|
||||
x2 = (left + width) * scale
|
||||
y2 = (top + height) * scale
|
||||
|
||||
# Store original image coordinates and scale
|
||||
original_coords.append(np.array([x1, y1, x2, y2, width, height]))
|
||||
|
||||
# Create a new black square image and paste original
|
||||
square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
|
||||
square_img.paste(img, (left, top))
|
||||
|
||||
# Resize to target size
|
||||
square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)
|
||||
|
||||
# Convert to tensor
|
||||
img_tensor = to_tensor(square_img)
|
||||
images.append(img_tensor)
|
||||
|
||||
# Stack all images
|
||||
images = torch.stack(images)
|
||||
original_coords = torch.from_numpy(np.array(original_coords)).float()
|
||||
|
||||
# Add additional dimension if single image to ensure correct shape
|
||||
if len(image_path_list) == 1:
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0)
|
||||
original_coords = original_coords.unsqueeze(0)
|
||||
|
||||
return images, original_coords
|
||||
|
||||
|
||||
def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16):
|
||||
"""
|
||||
A quick start function to load and preprocess images for model input.
|
||||
This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
|
||||
|
||||
Args:
|
||||
image_path_list (list): List of paths to image files
|
||||
mode (str, optional): Preprocessing mode, either "crop" or "pad".
|
||||
- "crop" (default): Sets width to 518px and center crops height if needed.
|
||||
- "pad": Preserves all pixels by making the largest dimension 518px
|
||||
and padding the smaller dimension to reach a square shape.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input list is empty or if mode is invalid
|
||||
|
||||
Notes:
|
||||
- Images with different dimensions will be padded with white (value=1.0)
|
||||
- A warning is printed when images have different shapes
|
||||
- When mode="crop": The function ensures width=518px while maintaining aspect ratio
|
||||
and height is center-cropped if larger than 518px
|
||||
- When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio
|
||||
and the smaller dimension is padded to reach a square shape (518x518)
|
||||
- Dimensions are adjusted to be divisible by 14 for compatibility with model requirements
|
||||
"""
|
||||
# Check for empty list
|
||||
if len(image_path_list) == 0:
|
||||
raise ValueError("At least 1 image is required")
|
||||
|
||||
|
||||
|
||||
# Validate mode
|
||||
if mode not in ["crop", "pad"]:
|
||||
raise ValueError("Mode must be either 'crop' or 'pad'")
|
||||
|
||||
images = []
|
||||
shapes = set()
|
||||
to_tensor = TF.ToTensor()
|
||||
target_size = image_size
|
||||
|
||||
# First process all images and collect their shapes
|
||||
for i, image_path in enumerate(image_path_list):
|
||||
# Open image
|
||||
img = Image.open(image_path)
|
||||
|
||||
# If there's an alpha channel, blend onto white background:
|
||||
if img.mode == "RGBA":
|
||||
# Create white background
|
||||
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
|
||||
# Alpha composite onto the white background
|
||||
img = Image.alpha_composite(background, img)
|
||||
|
||||
# Now convert to "RGB" (this step assigns white for transparent areas)
|
||||
img = img.convert("RGB")
|
||||
|
||||
width, height = img.size
|
||||
|
||||
if fx is not None:
|
||||
fx[i] = fx[i] * width
|
||||
fy[i] = fy[i] * height
|
||||
cx[i] = cx[i] * width
|
||||
cy[i] = cy[i] * height
|
||||
|
||||
if mode == "pad":
|
||||
# Make the largest dimension 518px while maintaining aspect ratio
|
||||
if width >= height:
|
||||
new_width = target_size
|
||||
new_height = round(height * (new_width / width) / patch_size) * patch_size # Make divisible by 14
|
||||
else:
|
||||
new_height = target_size
|
||||
new_width = round(width * (new_height / height) / patch_size) * patch_size # Make divisible by 14
|
||||
|
||||
else: # mode == "crop"
|
||||
# Original behavior: set width to 518px
|
||||
new_width = target_size
|
||||
# Calculate height maintaining aspect ratio, divisible by 14
|
||||
new_height = round(height * (new_width / width) / patch_size) * patch_size
|
||||
|
||||
# Resize with new dimensions (width, height)
|
||||
img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
|
||||
img = to_tensor(img) # Convert to tensor (0, 1)
|
||||
|
||||
# Center crop height if it's larger than 518 (only in crop mode)
|
||||
if mode == "crop" and new_height > target_size:
|
||||
start_y = (new_height - target_size) // 2
|
||||
img = img[:, start_y : start_y + target_size, :]
|
||||
if fx is not None:
|
||||
fx[i] = fx[i] * new_width / width
|
||||
fy[i] = fy[i] * new_height / height
|
||||
|
||||
cx[i] = img.shape[2] / 2
|
||||
cy[i] = img.shape[1] / 2
|
||||
|
||||
# For pad mode, pad to make a square of target_size x target_size
|
||||
if mode == "pad":
|
||||
h_padding = target_size - img.shape[1]
|
||||
w_padding = target_size - img.shape[2]
|
||||
|
||||
if h_padding > 0 or w_padding > 0:
|
||||
pad_top = h_padding // 2
|
||||
pad_bottom = h_padding - pad_top
|
||||
pad_left = w_padding // 2
|
||||
pad_right = w_padding - pad_left
|
||||
|
||||
# Pad with white (value=1.0)
|
||||
img = torch.nn.functional.pad(
|
||||
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
|
||||
)
|
||||
|
||||
shapes.add((img.shape[1], img.shape[2]))
|
||||
images.append(img)
|
||||
|
||||
# Check if we have different shapes
|
||||
# In theory our model can also work well with different shapes
|
||||
if len(shapes) > 1:
|
||||
print(f"Warning: Found images with different shapes: {shapes}")
|
||||
# Find maximum dimensions
|
||||
max_height = max(shape[0] for shape in shapes)
|
||||
max_width = max(shape[1] for shape in shapes)
|
||||
|
||||
# Pad images if necessary
|
||||
padded_images = []
|
||||
for img in images:
|
||||
h_padding = max_height - img.shape[1]
|
||||
w_padding = max_width - img.shape[2]
|
||||
|
||||
if h_padding > 0 or w_padding > 0:
|
||||
pad_top = h_padding // 2
|
||||
pad_bottom = h_padding - pad_top
|
||||
pad_left = w_padding // 2
|
||||
pad_right = w_padding - pad_left
|
||||
|
||||
img = torch.nn.functional.pad(
|
||||
img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
|
||||
)
|
||||
padded_images.append(img)
|
||||
images = padded_images
|
||||
|
||||
images = torch.stack(images) # concatenate images
|
||||
|
||||
# Ensure correct shape when single image
|
||||
if len(image_path_list) == 1:
|
||||
# Verify shape is (1, C, H, W)
|
||||
if images.dim() == 3:
|
||||
images = images.unsqueeze(0)
|
||||
if fx is not None:
|
||||
return images, fx, fy, cx, cy
|
||||
return images
|
||||
Reference in New Issue
Block a user