update demo
This commit is contained in:
41
demo.py
41
demo.py
@@ -39,7 +39,12 @@ from lingbot_map.utils.load_fn import load_and_preprocess_images
|
||||
|
||||
def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png",
|
||||
first_k=None, stride=1, image_size=518, patch_size=14, num_workers=8):
|
||||
"""Load images from folder or video and preprocess into a tensor."""
|
||||
"""Load images from folder or video and preprocess into a tensor.
|
||||
|
||||
Returns:
|
||||
(images, paths, resolved_image_folder): preprocessed tensor, file paths,
|
||||
and the folder containing the source images (for sky mask caching etc.).
|
||||
"""
|
||||
if video_path is not None:
|
||||
video_name = os.path.splitext(os.path.basename(video_path))[0]
|
||||
out_dir = os.path.join(os.path.dirname(video_path), f"{video_name}_frames")
|
||||
@@ -63,6 +68,7 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png
|
||||
pbar.close()
|
||||
cap.release()
|
||||
paths = saved
|
||||
resolved_folder = out_dir
|
||||
print(f"Extracted {len(paths)} frames from video ({total_frames} total, interval={interval})")
|
||||
else:
|
||||
exts = image_ext.split(",")
|
||||
@@ -70,11 +76,12 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png
|
||||
for ext in exts:
|
||||
paths.extend(glob.glob(os.path.join(image_folder, f"*{ext}")))
|
||||
paths = sorted(paths)
|
||||
resolved_folder = image_folder
|
||||
|
||||
if stride > 1:
|
||||
paths = paths[::stride]
|
||||
if first_k is not None and first_k > 0:
|
||||
paths = paths[:first_k]
|
||||
if stride > 1:
|
||||
paths = paths[::stride]
|
||||
|
||||
print(f"Loading {len(paths)} images...")
|
||||
images = load_and_preprocess_images(
|
||||
@@ -85,7 +92,7 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png
|
||||
)
|
||||
h, w = images.shape[-2:]
|
||||
print(f"Preprocessed images to {w}x{h} using canonical crop mode")
|
||||
return images, paths
|
||||
return images, paths, resolved_folder
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -261,8 +268,14 @@ def main():
|
||||
parser.add_argument("--port", type=int, default=8080)
|
||||
parser.add_argument("--conf_threshold", type=float, default=1.5)
|
||||
parser.add_argument("--downsample_factor", type=int, default=10)
|
||||
parser.add_argument("--point_size", type=float, default=0.0007)
|
||||
parser.add_argument("--point_size", type=float, default=0.00001)
|
||||
parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")
|
||||
parser.add_argument("--sky_mask_dir", type=str, default=None,
|
||||
help="Directory for cached sky masks (default: <image_folder>_sky_masks/)")
|
||||
parser.add_argument("--sky_mask_visualization_dir", type=str, default=None,
|
||||
help="Save sky mask visualizations (original | mask | overlay) to this directory")
|
||||
parser.add_argument("--export_preprocessed", type=str, default=None,
|
||||
help="Export stride-sampled, resized/cropped images to this folder")
|
||||
|
||||
args = parser.parse_args()
|
||||
assert args.image_folder or args.video_path, \
|
||||
@@ -272,11 +285,24 @@ def main():
|
||||
|
||||
# ── Load images & model ──────────────────────────────────────────────────
|
||||
t0 = time.time()
|
||||
images, paths = load_images(
|
||||
images, paths, resolved_image_folder = load_images(
|
||||
image_folder=args.image_folder, video_path=args.video_path,
|
||||
fps=args.fps, first_k=args.first_k, stride=args.stride,
|
||||
image_size=args.image_size, patch_size=args.patch_size,
|
||||
)
|
||||
|
||||
# Export preprocessed images if requested
|
||||
if args.export_preprocessed:
|
||||
os.makedirs(args.export_preprocessed, exist_ok=True)
|
||||
print(f"Exporting {images.shape[0]} preprocessed images to {args.export_preprocessed}...")
|
||||
for i in range(images.shape[0]):
|
||||
img = (images[i].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
cv2.imwrite(
|
||||
os.path.join(args.export_preprocessed, f"{i:06d}.png"),
|
||||
cv2.cvtColor(img, cv2.COLOR_RGB2BGR),
|
||||
)
|
||||
print(f"Exported to {args.export_preprocessed}")
|
||||
|
||||
model = load_model(args, device)
|
||||
print(f"Total load time: {time.time() - t0:.1f}s")
|
||||
|
||||
@@ -330,6 +356,9 @@ def main():
|
||||
downsample_factor=args.downsample_factor,
|
||||
point_size=args.point_size,
|
||||
mask_sky=args.mask_sky,
|
||||
image_folder=resolved_image_folder,
|
||||
sky_mask_dir=args.sky_mask_dir,
|
||||
sky_mask_visualization_dir=args.sky_mask_visualization_dir,
|
||||
)
|
||||
print(f"3D viewer at http://localhost:{args.port}")
|
||||
viewer.run()
|
||||
|
||||
Reference in New Issue
Block a user