update demo gpu usage and fix some bugs
demo.py (44 lines changed)
@@ -247,7 +247,7 @@ def main():
     # Streaming options
     parser.add_argument("--enable_3d_rope", action="store_true", default=True)
     parser.add_argument("--max_frame_num", type=int, default=1024)
-    parser.add_argument("--num_scale_frames", type=int, default=8)
+    parser.add_argument("--num_scale_frames", type=int, default=4)
     parser.add_argument(
         "--keyframe_interval",
         type=int,
@@ -258,6 +258,13 @@ def main():
     parser.add_argument("--kv_cache_scale_frames", type=int, default=8)
     parser.add_argument("--use_sdpa", action="store_true", default=False,
                         help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
+    parser.add_argument(
+        "--offload_to_cpu",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+        help="Offload per-frame predictions to CPU during inference to cut GPU peak memory. "
+        "Use --no-offload_to_cpu to keep outputs on GPU.",
+    )
 
     # Windowed options
     parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)")
@@ -306,10 +313,24 @@ def main():
     model = load_model(args, device)
     print(f"Total load time: {time.time() - t0:.1f}s")
 
+    # Keep model in its loaded dtype — autocast handles bf16/fp16 for the ops
+    # that benefit from it and keeps LayerNorm / reductions in fp32.
+    if torch.cuda.is_available():
+        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+    else:
+        dtype = torch.float32
+
     images = images.to(device)
     num_frames = images.shape[0]
     print(f"Input: {num_frames} frames, shape {tuple(images.shape)}")
     print(f"Mode: {args.mode}")
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        print(
+            f"GPU mem after load: "
+            f"alloc={torch.cuda.memory_allocated()/1e9:.2f} GB, "
+            f"reserved={torch.cuda.memory_reserved()/1e9:.2f} GB"
+        )
 
     if args.mode != "streaming" and args.keyframe_interval != 1:
         print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.")
@@ -321,16 +342,18 @@ def main():
     )
 
     # ── Inference ────────────────────────────────────────────────────────────
-    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
     print(f"Running {args.mode} inference (dtype={dtype})...")
     t0 = time.time()
 
+    output_device = torch.device("cpu") if args.offload_to_cpu else None
+
     with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
         if args.mode == "streaming":
             predictions = model.inference_streaming(
                 images,
                 num_scale_frames=args.num_scale_frames,
                 keyframe_interval=args.keyframe_interval,
+                output_device=output_device,
             )
         else: # windowed
             predictions = model.inference_windowed(
@@ -338,12 +361,27 @@ def main():
                 window_size=args.window_size,
                 overlap_size=args.overlap_size,
                 num_scale_frames=args.num_scale_frames,
+                output_device=output_device,
            )
 
     print(f"Inference done in {time.time() - t0:.1f}s")
+    if torch.cuda.is_available():
+        print(
+            f"GPU peak during inference: "
+            f"{torch.cuda.max_memory_allocated()/1e9:.2f} GB "
+            f"(reserved peak {torch.cuda.max_memory_reserved()/1e9:.2f} GB)"
+        )
 
     # ── Post-process ─────────────────────────────────────────────────────────
-    predictions, images_cpu = postprocess(predictions, images)
+    if args.offload_to_cpu:
+        del images
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        images_for_post = predictions["images"]  # already CPU
+    else:
+        images_for_post = images
+
+    predictions, images_cpu = postprocess(predictions, images_for_post)
 
     # ── Visualize ────────────────────────────────────────────────────────────
     try:
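
Note on the new flag: argparse.BooleanOptionalAction (Python 3.9+) automatically registers a paired --no-offload_to_cpu switch, which is why the help text can mention it without a second add_argument call. A minimal, self-contained sketch of that behaviour (only the flag name and default come from the diff; everything else is illustrative):

import argparse

# BooleanOptionalAction registers both --offload_to_cpu and --no-offload_to_cpu.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--offload_to_cpu",
    action=argparse.BooleanOptionalAction,
    default=True,
)

print(parser.parse_args([]).offload_to_cpu)                       # True (default)
print(parser.parse_args(["--no-offload_to_cpu"]).offload_to_cpu)  # False
print(parser.parse_args(["--offload_to_cpu"]).offload_to_cpu)     # True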
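
The output_device argument threaded into inference_streaming / inference_windowed presumably moves each frame's prediction to the CPU as soon as it is produced, so the GPU never holds the full stack of per-frame outputs. A rough sketch of that pattern, assuming a hypothetical per-frame loop (only the output_device parameter mirrors the diff; the model and loop structure are illustrative):

import torch

def run_per_frame(model, images, output_device=None):
    # Hypothetical per-frame inference loop illustrating the offload pattern.
    # Moving each output to `output_device` (e.g. torch.device("cpu")) right
    # away keeps GPU memory bounded by the model plus a single frame's output.
    outputs = []
    with torch.no_grad():
        for frame in images:                  # images: (N, C, H, W) on the GPU
            out = model(frame.unsqueeze(0))
            if output_device is not None:
                out = out.to(output_device)
            outputs.append(out)
    return torch.cat(outputs, dim=0)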