diff --git a/README.md b/README.md index 9a07485..6cdb59d 100644 --- a/README.md +++ b/README.md @@ -186,6 +186,13 @@ python demo.py --model_path /path/to/checkpoint.pt \ --image_folder /path/to/images/ --use_sdpa ``` +### Running on Limited GPU Memory + +If you run into out-of-memory issues, try one (or both) of the following: + +- **`--offload_to_cpu`** — offload per-frame predictions to CPU during inference (on by default; use `--no-offload_to_cpu` only if you have memory to spare). +- **`--num_scale_frames 2`** — reduce the number of bidirectional scale frames from the default 8 down to 2, which shrinks the activation peak of the initial scale phase. + # 📜 License This project is released under the Apache License 2.0. See [LICENSE](LICENSE.txt) file for details. diff --git a/demo.py b/demo.py index c55d3f5..ae3d4c5 100644 --- a/demo.py +++ b/demo.py @@ -23,6 +23,11 @@ import glob import os import time +# Must be set before `import torch` / any CUDA init. Reduces the reserved-vs-allocated +# memory gap by letting the caching allocator grow segments on demand instead of +# pre-reserving fixed-size blocks. +os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + import cv2 import numpy as np import torch @@ -113,7 +118,7 @@ def load_model(args, device): enable_3d_rope=args.enable_3d_rope, max_frame_num=args.max_frame_num, kv_cache_sliding_window=args.kv_cache_sliding_window, - kv_cache_scale_frames=args.kv_cache_scale_frames, + kv_cache_scale_frames=args.num_scale_frames, kv_cache_cross_frame_special=True, kv_cache_include_scale_frames=True, use_sdpa=args.use_sdpa, @@ -247,7 +252,7 @@ def main(): # Streaming options parser.add_argument("--enable_3d_rope", action="store_true", default=True) parser.add_argument("--max_frame_num", type=int, default=1024) - parser.add_argument("--num_scale_frames", type=int, default=4) + parser.add_argument("--num_scale_frames", type=int, default=8) parser.add_argument( "--keyframe_interval", type=int, @@ -255,7 +260,6 @@ def main(): help="Streaming only. Every N-th frame after scale frames is kept as a keyframe. 1 = every frame.", ) parser.add_argument("--kv_cache_sliding_window", type=int, default=64) - parser.add_argument("--kv_cache_scale_frames", type=int, default=8) parser.add_argument("--use_sdpa", action="store_true", default=False, help="Use SDPA backend (no flashinfer needed). Default: FlashInfer") parser.add_argument( @@ -265,7 +269,6 @@ def main(): help="Offload per-frame predictions to CPU during inference to cut GPU peak memory. " "Use --no-offload_to_cpu to keep outputs on GPU.", ) - # Windowed options parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)") parser.add_argument("--overlap_size", type=int, default=16, help="Overlap between windows") @@ -313,13 +316,21 @@ def main(): model = load_model(args, device) print(f"Total load time: {time.time() - t0:.1f}s") - # Keep model in its loaded dtype — autocast handles bf16/fp16 for the ops - # that benefit from it and keeps LayerNorm / reductions in fp32. + # Pick inference dtype; autocast still runs for the ops that need fp32 (e.g. LayerNorm). if torch.cuda.is_available(): dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 else: dtype = torch.float32 + # Cast the aggregator (DINOv2-style trunk) to the inference dtype to remove the + # redundant fp32 master weight copy + autocast bf16 weight cache (~2-3 GB saved, + # no measurable quality change). gct_base._predict_* upcasts inputs to fp32 and + # runs each head under `autocast(enabled=False)`, so camera/depth/point heads + # keep fp32 weights automatically. + if dtype != torch.float32 and getattr(model, "aggregator", None) is not None: + print(f"Casting aggregator to {dtype} (heads kept in fp32)") + model.aggregator = model.aggregator.to(dtype=dtype) + images = images.to(device) num_frames = images.shape[0] print(f"Input: {num_frames} frames, shape {tuple(images.shape)}")