add compile

LinZhuoChen
2026-04-21 20:30:18 +08:00
parent 3a37a705a7
commit f2d9711b3d

demo.py (89 additions)

@@ -139,6 +139,66 @@ def load_model(args, device):
    return model.to(device).eval()


# =============================================================================
# torch.compile (opt-in via --compile)
# =============================================================================
def compile_model(model):
    """Compile hot, fixed-shape modules with mode="reduce-overhead".

    Mirrors the targets in gct_profile.py:compile_model. Unlike the profile
    script, `model.point_head` is **kept**: the demo needs world_points for
    visualization.
    """
    agg = model.aggregator
    for i, b in enumerate(agg.frame_blocks):
        agg.frame_blocks[i] = torch.compile(b, mode="reduce-overhead")
    for i, b in enumerate(agg.patch_embed.blocks):
        agg.patch_embed.blocks[i] = torch.compile(b, mode="reduce-overhead")
    for b in agg.global_blocks:
        if hasattr(b, "attn_pre"):
            b.attn_pre = torch.compile(b.attn_pre, mode="reduce-overhead")
        if hasattr(b, "ffn_residual"):
            b.ffn_residual = torch.compile(b.ffn_residual, mode="reduce-overhead")
        b.attn.proj = torch.compile(b.attn.proj, mode="reduce-overhead")

def _warm_streaming(model, images, scale_frames, warm_stream_n, dtype, passes=1):
    """Drive `clean_kv_cache → Phase 1 → N streaming forwards` `passes` times.

    Warmup inputs are sliced from the already-preprocessed `images` tensor, so
    their spatial shape matches what real inference will feed; this is what
    makes the captured CUDA graphs reusable (reduce-overhead mode keys on
    shape).
    """
    # images: [S, 3, H, W] on device already; slice and add batch dim.
    warm_scale = images[:scale_frames].unsqueeze(0).to(dtype)
    warm_stream = images[scale_frames:scale_frames + warm_stream_n].unsqueeze(0).to(dtype)

    for _ in range(passes):
        model.clean_kv_cache()

        torch.compiler.cudagraph_mark_step_begin()
        with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
            model.forward(
                warm_scale,
                num_frame_for_scale=scale_frames,
                num_frame_per_block=scale_frames,
                causal_inference=True,
            )

        for i in range(warm_stream_n):
            torch.compiler.cudagraph_mark_step_begin()
            with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
                model.forward(
                    warm_stream[:, i:i + 1],
                    num_frame_for_scale=scale_frames,
                    num_frame_per_block=1,
                    causal_inference=True,
                )

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # Wipe warmup KV so real inference_streaming starts clean (it also calls
    # clean_kv_cache internally, but this is defensive + makes intent obvious).
    model.clean_kv_cache()

# =============================================================================
# Post-processing
# =============================================================================
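Note: the shape discipline the `_warm_streaming` docstring describes is easy to demonstrate in isolation. A minimal standalone sketch (not part of this commit; `Toy` and the tensor shape are made up for illustration):

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2.0

if torch.cuda.is_available():
    toy = torch.compile(Toy().cuda(), mode="reduce-overhead")
    x = torch.randn(1, 3, 518, 378, device="cuda")  # same shape real inference will use
    for _ in range(3):  # early calls compile + capture a CUDA graph; later calls replay it
        torch.compiler.cudagraph_mark_step_begin()
        _ = toy(x)
    # A different spatial shape here would miss the captured graph and pay the
    # compile/capture cost again; matching shapes during warmup is what makes
    # the graphs reusable by real inference.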
@@ -267,6 +327,9 @@ def main():
"(skips 3 refinement passes at a small accuracy cost).")
parser.add_argument("--use_sdpa", action="store_true", default=False,
help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
parser.add_argument("--compile", action="store_true", default=False,
help="torch.compile hot modules (reduce-overhead) with a CUDA-graph warmup. "
"Streaming mode only; ~5 FPS faster at 518x378. Adds ~30-60 s warmup time.")
parser.add_argument(
"--offload_to_cpu",
action=argparse.BooleanOptionalAction,
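Note: a hypothetical invocation of the new flag, for reference (the input/output arguments are not shown in this diff and are elided):

python demo.py --mode streaming --compile [input/output args...]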
@@ -366,6 +429,32 @@ def main():
f"(after the first {args.num_scale_frames} scale frames)."
)
# ── Optional: torch.compile + CUDA-graph warmup (streaming only) ────────
if args.compile:
if args.mode != "streaming":
print(
f"--compile only applies to --mode streaming (got {args.mode!r}); "
"skipping compile."
)
else:
scale_for_warm = min(args.num_scale_frames, num_frames)
warm_stream_n = min(10, max(1, num_frames - scale_for_warm))
print(f"Warmup eager (scale + {warm_stream_n} streaming)...")
t_warm = time.time()
_warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=1)
print(f" eager warmup: {time.time() - t_warm:.1f}s")
print("Compiling hot modules...")
compile_model(model)
# 3 passes under compile: 1st captures CUDA graphs, 2nd/3rd replay so
# the caching allocator / graph-address map converge on the state the
# real inference will see. See gct_profile.py:302-306 for rationale.
print("Warmup compiled (3x dress rehearsal)...")
t_warm = time.time()
_warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=3)
print(f" compiled warmup: {time.time() - t_warm:.1f}s")
    # ── Inference ────────────────────────────────────────────────────────────
    print(f"Running {args.mode} inference (dtype={dtype})...")
    t0 = time.time()
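Note on the repeated `torch.compiler.cudagraph_mark_step_begin()` calls in the warmup: under mode="reduce-overhead", outputs of a compiled call live in CUDA-graph-owned memory that the next replay overwrites, so each new step must be announced and any output worth keeping must be cloned. A standalone sketch of that protocol (illustrative only, not from this commit):

import torch

if torch.cuda.is_available():
    step = torch.compile(lambda x: x.sin() + 1.0, mode="reduce-overhead")
    x = torch.randn(16, device="cuda")
    kept = []
    for _ in range(3):
        torch.compiler.cudagraph_mark_step_begin()  # announce a new inference step
        kept.append(step(x).clone())  # clone() copies out of graph-owned storage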