diff --git a/demo.py b/demo.py
index 33d396d..942e8ba 100644
--- a/demo.py
+++ b/demo.py
@@ -139,6 +139,66 @@ def load_model(args, device):
     return model.to(device).eval()
 
 
+# =============================================================================
+# torch.compile (opt-in via --compile)
+# =============================================================================
+
+def compile_model(model):
+    """Compile hot, fixed-shape modules with mode="reduce-overhead".
+
+    Mirrors the targets in gct_profile.py:compile_model. Unlike the profile script,
+    `model.point_head` is **kept**: the demo needs world_points for visualization.
+    """
+    agg = model.aggregator
+    for i, b in enumerate(agg.frame_blocks):
+        agg.frame_blocks[i] = torch.compile(b, mode="reduce-overhead")
+    for i, b in enumerate(agg.patch_embed.blocks):
+        agg.patch_embed.blocks[i] = torch.compile(b, mode="reduce-overhead")
+    for b in agg.global_blocks:
+        if hasattr(b, 'attn_pre'):
+            b.attn_pre = torch.compile(b.attn_pre, mode="reduce-overhead")
+        if hasattr(b, 'ffn_residual'):
+            b.ffn_residual = torch.compile(b.ffn_residual, mode="reduce-overhead")
+        b.attn.proj = torch.compile(b.attn.proj, mode="reduce-overhead")
+
+
+def _warm_streaming(model, images, scale_frames, warm_stream_n, dtype, passes=1):
+    """Drive `clean_kv_cache → Phase 1 → N streaming forwards` `passes` times.
+
+    Warmup inputs are sliced from the already-preprocessed `images` tensor, so their
+    spatial shape matches what real inference will feed; this is what makes the
+    captured CUDA graphs reusable (reduce-overhead mode keys on shape).
+    """
+    # images: [S, 3, H, W] on device already; slice and add batch dim.
+    warm_scale = images[:scale_frames].unsqueeze(0).to(dtype)
+    warm_stream = images[scale_frames:scale_frames + warm_stream_n].unsqueeze(0).to(dtype)
+
+    for _ in range(passes):
+        model.clean_kv_cache()
+        torch.compiler.cudagraph_mark_step_begin()
+        with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
+            model.forward(
+                warm_scale,
+                num_frame_for_scale=scale_frames,
+                num_frame_per_block=scale_frames,
+                causal_inference=True,
+            )
+        for i in range(warm_stream_n):
+            torch.compiler.cudagraph_mark_step_begin()
+            with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
+                model.forward(
+                    warm_stream[:, i:i + 1],
+                    num_frame_for_scale=scale_frames,
+                    num_frame_per_block=1,
+                    causal_inference=True,
+                )
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    # Wipe warmup KV so real inference_streaming starts clean (it also calls
+    # clean_kv_cache internally, but this is defensive + makes intent obvious).
+    model.clean_kv_cache()
+
+
 # =============================================================================
 # Post-processing
 # =============================================================================
@@ -267,6 +327,9 @@ def main():
                         "(skips 3 refinement passes at a small accuracy cost).")
     parser.add_argument("--use_sdpa", action="store_true", default=False,
                         help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
+    parser.add_argument("--compile", action="store_true", default=False,
+                        help="torch.compile hot modules (reduce-overhead) with a CUDA-graph warmup. "
+                             "Streaming mode only; ~5 FPS faster at 518x378. Adds ~30-60 s warmup time.")
     parser.add_argument(
         "--offload_to_cpu",
         action=argparse.BooleanOptionalAction,
@@ -366,6 +429,32 @@ def main():
             f"(after the first {args.num_scale_frames} scale frames)."
         )
 
+    # ── Optional: torch.compile + CUDA-graph warmup (streaming only) ────────
+    if args.compile:
+        if args.mode != "streaming":
+            print(
+                f"--compile only applies to --mode streaming (got {args.mode!r}); "
+                "skipping compile."
+            )
+        else:
+            scale_for_warm = min(args.num_scale_frames, num_frames)
+            warm_stream_n = min(10, max(1, num_frames - scale_for_warm))
+            print(f"Warmup eager (scale + {warm_stream_n} streaming)...")
+            t_warm = time.time()
+            _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=1)
+            print(f" eager warmup: {time.time() - t_warm:.1f}s")
+
+            print("Compiling hot modules...")
+            compile_model(model)
+
+            # 3 passes under compile: 1st captures CUDA graphs, 2nd/3rd replay so
+            # the caching allocator / graph-address map converge on the state the
+            # real inference will see. See gct_profile.py:302-306 for rationale.
+            print("Warmup compiled (3x dress rehearsal)...")
+            t_warm = time.time()
+            _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=3)
+            print(f" compiled warmup: {time.time() - t_warm:.1f}s")
+
     # ── Inference ────────────────────────────────────────────────────────────
     print(f"Running {args.mode} inference (dtype={dtype})...")
     t0 = time.time()
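
For readers unfamiliar with the reduce-overhead / CUDA-graph interaction the warmup comments rely on, here is a minimal standalone sketch (not part of the patch; the Linear block and shapes are placeholders): torch.compile(mode="reduce-overhead") captures a CUDA graph for a given input shape during the first few calls and replays it on later calls with the same shape, and torch.compiler.cudagraph_mark_step_begin() marks each iteration boundary so a replay does not reuse outputs still held from the previous step.

    import torch

    # Placeholder for one hot, fixed-shape module (stands in for a frame block).
    block = torch.nn.Linear(1024, 1024).cuda().eval()
    block_c = torch.compile(block, mode="reduce-overhead")

    x = torch.randn(8, 1024, device="cuda")  # fixed shape, as in the warmup slices
    with torch.no_grad():
        for step in range(3):
            # One boundary per iteration; early steps compile/capture, later steps replay.
            torch.compiler.cudagraph_mark_step_begin()
            y = block_c(x)
    torch.cuda.synchronize()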