update demo

This commit is contained in:
LinZhuoChen
2026-04-16 18:53:26 +08:00
parent 843d9ec31d
commit c7e49e1cbe
1578 changed files with 82 additions and 9 deletions

41
demo.py
View File

@@ -39,7 +39,12 @@ from lingbot_map.utils.load_fn import load_and_preprocess_images
def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png",
first_k=None, stride=1, image_size=518, patch_size=14, num_workers=8):
"""Load images from folder or video and preprocess into a tensor."""
"""Load images from folder or video and preprocess into a tensor.
Returns:
(images, paths, resolved_image_folder): preprocessed tensor, file paths,
and the folder containing the source images (for sky mask caching etc.).
"""
if video_path is not None:
video_name = os.path.splitext(os.path.basename(video_path))[0]
out_dir = os.path.join(os.path.dirname(video_path), f"{video_name}_frames")
@@ -63,6 +68,7 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png
pbar.close()
cap.release()
paths = saved
resolved_folder = out_dir
print(f"Extracted {len(paths)} frames from video ({total_frames} total, interval={interval})")
else:
exts = image_ext.split(",")
@@ -70,11 +76,12 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png
for ext in exts:
paths.extend(glob.glob(os.path.join(image_folder, f"*{ext}")))
paths = sorted(paths)
resolved_folder = image_folder
if stride > 1:
paths = paths[::stride]
if first_k is not None and first_k > 0:
paths = paths[:first_k]
if stride > 1:
paths = paths[::stride]
print(f"Loading {len(paths)} images...")
images = load_and_preprocess_images(
@@ -85,7 +92,7 @@ def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png
)
h, w = images.shape[-2:]
print(f"Preprocessed images to {w}x{h} using canonical crop mode")
return images, paths
return images, paths, resolved_folder
# =============================================================================
@@ -261,8 +268,14 @@ def main():
parser.add_argument("--port", type=int, default=8080)
parser.add_argument("--conf_threshold", type=float, default=1.5)
parser.add_argument("--downsample_factor", type=int, default=10)
parser.add_argument("--point_size", type=float, default=0.0007)
parser.add_argument("--point_size", type=float, default=0.00001)
parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")
parser.add_argument("--sky_mask_dir", type=str, default=None,
help="Directory for cached sky masks (default: <image_folder>_sky_masks/)")
parser.add_argument("--sky_mask_visualization_dir", type=str, default=None,
help="Save sky mask visualizations (original | mask | overlay) to this directory")
parser.add_argument("--export_preprocessed", type=str, default=None,
help="Export stride-sampled, resized/cropped images to this folder")
args = parser.parse_args()
assert args.image_folder or args.video_path, \
@@ -272,11 +285,24 @@ def main():
# ── Load images & model ──────────────────────────────────────────────────
t0 = time.time()
images, paths = load_images(
images, paths, resolved_image_folder = load_images(
image_folder=args.image_folder, video_path=args.video_path,
fps=args.fps, first_k=args.first_k, stride=args.stride,
image_size=args.image_size, patch_size=args.patch_size,
)
# Export preprocessed images if requested
if args.export_preprocessed:
os.makedirs(args.export_preprocessed, exist_ok=True)
print(f"Exporting {images.shape[0]} preprocessed images to {args.export_preprocessed}...")
for i in range(images.shape[0]):
img = (images[i].permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8)
cv2.imwrite(
os.path.join(args.export_preprocessed, f"{i:06d}.png"),
cv2.cvtColor(img, cv2.COLOR_RGB2BGR),
)
print(f"Exported to {args.export_preprocessed}")
model = load_model(args, device)
print(f"Total load time: {time.time() - t0:.1f}s")
@@ -330,6 +356,9 @@ def main():
downsample_factor=args.downsample_factor,
point_size=args.point_size,
mask_sky=args.mask_sky,
image_folder=resolved_image_folder,
sky_mask_dir=args.sky_mask_dir,
sky_mask_visualization_dir=args.sky_mask_visualization_dir,
)
print(f"3D viewer at http://localhost:{args.port}")
viewer.run()