add compile

demo.py
@@ -139,6 +139,66 @@ def load_model(args, device):
    return model.to(device).eval()


# =============================================================================
# torch.compile (opt-in via --compile)
# =============================================================================

def compile_model(model):
    """Compile hot, fixed-shape modules with mode="reduce-overhead".

    Mirrors the targets in gct_profile.py:compile_model. Unlike the profile script,
    `model.point_head` is **kept** — the demo needs world_points for visualization.
    """
    agg = model.aggregator
    for i, b in enumerate(agg.frame_blocks):
        agg.frame_blocks[i] = torch.compile(b, mode="reduce-overhead")
    for i, b in enumerate(agg.patch_embed.blocks):
        agg.patch_embed.blocks[i] = torch.compile(b, mode="reduce-overhead")
    for b in agg.global_blocks:
        if hasattr(b, 'attn_pre'):
            b.attn_pre = torch.compile(b.attn_pre, mode="reduce-overhead")
        if hasattr(b, 'ffn_residual'):
            b.ffn_residual = torch.compile(b.ffn_residual, mode="reduce-overhead")
        b.attn.proj = torch.compile(b.attn.proj, mode="reduce-overhead")


def _warm_streaming(model, images, scale_frames, warm_stream_n, dtype, passes=1):
    """Drive `clean_kv_cache → Phase 1 → N streaming forwards` `passes` times.

    Warmup inputs are sliced from the already-preprocessed `images` tensor, so their
    spatial shape matches what real inference will feed — this is what makes the
    captured CUDA graphs reusable (reduce-overhead mode keys on shape).
    """
    # images: [S, 3, H, W] on device already; slice and add batch dim.
    warm_scale = images[:scale_frames].unsqueeze(0).to(dtype)
    warm_stream = images[scale_frames:scale_frames + warm_stream_n].unsqueeze(0).to(dtype)

    for _ in range(passes):
        model.clean_kv_cache()
        torch.compiler.cudagraph_mark_step_begin()
        with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
            model.forward(
                warm_scale,
                num_frame_for_scale=scale_frames,
                num_frame_per_block=scale_frames,
                causal_inference=True,
            )
        for i in range(warm_stream_n):
            torch.compiler.cudagraph_mark_step_begin()
            with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype):
                model.forward(
                    warm_stream[:, i:i + 1],
                    num_frame_for_scale=scale_frames,
                    num_frame_per_block=1,
                    causal_inference=True,
                )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Wipe warmup KV so real inference_streaming starts clean (it also calls
    # clean_kv_cache internally, but this is defensive + makes intent obvious).
    model.clean_kv_cache()


# =============================================================================
# Post-processing
# =============================================================================
@@ -267,6 +327,9 @@ def main():
                             "(skips 3 refinement passes at a small accuracy cost).")
    parser.add_argument("--use_sdpa", action="store_true", default=False,
                        help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
    parser.add_argument("--compile", action="store_true", default=False,
                        help="torch.compile hot modules (reduce-overhead) with a CUDA-graph warmup. "
                             "Streaming mode only; ~5 FPS faster at 518x378. Adds ~30-60 s warmup time.")
    parser.add_argument(
        "--offload_to_cpu",
        action=argparse.BooleanOptionalAction,
@@ -366,6 +429,32 @@ def main():
            f"(after the first {args.num_scale_frames} scale frames)."
        )

    # ── Optional: torch.compile + CUDA-graph warmup (streaming only) ────────
    if args.compile:
        if args.mode != "streaming":
            print(
                f"--compile only applies to --mode streaming (got {args.mode!r}); "
                "skipping compile."
            )
        else:
            scale_for_warm = min(args.num_scale_frames, num_frames)
            warm_stream_n = min(10, max(1, num_frames - scale_for_warm))
            print(f"Warmup eager (scale + {warm_stream_n} streaming)...")
            t_warm = time.time()
            _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=1)
            print(f"  eager warmup: {time.time() - t_warm:.1f}s")

            print("Compiling hot modules...")
            compile_model(model)

            # 3 passes under compile: 1st captures CUDA graphs, 2nd/3rd replay so
            # the caching allocator / graph-address map converge on the state the
            # real inference will see. See gct_profile.py:302-306 for rationale.
            print("Warmup compiled (3x dress rehearsal)...")
            t_warm = time.time()
            _warm_streaming(model, images, scale_for_warm, warm_stream_n, dtype, passes=3)
            print(f"  compiled warmup: {time.time() - t_warm:.1f}s")

    # ── Inference ────────────────────────────────────────────────────────────
    print(f"Running {args.mode} inference (dtype={dtype})...")
    t0 = time.time()
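Not part of the commit, but a standalone sketch of the pattern the diff relies on: compile a small fixed-shape module with mode="reduce-overhead", mark each warmup step so CUDA graphs get captured and then replayed, and keep input shapes identical between warmup and later calls. TinyBlock, its sizes, and the three-iteration loop are made up for illustration; torch.compile and torch.compiler.cudagraph_mark_step_begin are the real PyTorch APIs used by compile_model/_warm_streaming above.

import torch


class TinyBlock(torch.nn.Module):
    """Hypothetical stand-in for one hot, fixed-shape submodule."""

    def __init__(self, dim=64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))


if torch.cuda.is_available():
    block = torch.compile(TinyBlock().cuda().eval(), mode="reduce-overhead")
    x = torch.randn(1, 16, 64, device="cuda")  # shape must stay fixed across calls
    with torch.no_grad():
        for _ in range(3):
            # Early marked steps compile and then capture a CUDA graph for this
            # shape; later steps replay the captured graph.
            torch.compiler.cudagraph_mark_step_begin()
            out = block(x)
    torch.cuda.synchronize()
    print(out.shape)  # torch.Size([1, 16, 64])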