commit 739c7663ab (parent 01c99afc41)
Author: LinZhuoChen
Date:   2026-04-18 03:44:50 +08:00

gct_profile.py (new file, 333 lines)

@@ -0,0 +1,333 @@
"""
Profile GCTStream streaming inference FPS.
Measures only the top-level model.forward() GPU time — no per-module hooks,
no inner breakdown. One CUDA event pair per frame, a single sync at the end.
Usage:
python gct_profile.py --backend both --dtype bf16 --num_frames 500
"""
import argparse
import contextlib
import json
import numpy as np
import torch
from lingbot_map.models.gct_stream import GCTStream
# ============================================================================
# Model loading
# ============================================================================
def load_model(backend, img_size, sliding_window, max_frame_num, device='cuda'):
"""Build GCTStream with random weights (no checkpoint). Eval mode on device."""
model = GCTStream(
img_size=img_size,
patch_size=14,
enable_3d_rope=True,
max_frame_num=max_frame_num,
kv_cache_sliding_window=sliding_window,
kv_cache_scale_frames=8,
use_sdpa=(backend == 'sdpa'),
)
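    # Random weights are fine for throughput profiling: runtime is driven by
    # tensor shapes and dtypes, not parameter values (assuming no
    # data-dependent control flow in the model).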
return model.eval().to(device)
def compile_model(model):
"""
Apply torch.compile(mode="reduce-overhead") to compute-heavy, fixed-shape
modules and drop point_head. Matches the optimizations from the original
--compile path; omits per-block experiments that did not pay off.
"""
agg = model.aggregator
for i, b in enumerate(agg.frame_blocks):
agg.frame_blocks[i] = torch.compile(b, mode="reduce-overhead")
for i, b in enumerate(agg.patch_embed.blocks):
agg.patch_embed.blocks[i] = torch.compile(b, mode="reduce-overhead")
for b in agg.global_blocks:
if hasattr(b, 'attn_pre'):
b.attn_pre = torch.compile(b.attn_pre, mode="reduce-overhead")
if hasattr(b, 'ffn_residual'):
b.ffn_residual = torch.compile(b.ffn_residual, mode="reduce-overhead")
b.attn.proj = torch.compile(b.attn.proj, mode="reduce-overhead")
model.point_head = None # saves ~5.9 ms/frame; not needed for FPS measurement
# ============================================================================
# Profiling — reuse one CUDA event pair, sync after every frame
# ============================================================================
def profile_streaming(model, images, num_frames, dtype):
"""
Run streaming inference. Return (per_frame_ms, scale_frames, phase1_ms).
Reuses a single CUDA event pair across frames and syncs after every frame
(matches the original non-lightweight path — necessary to keep GPU clock /
memory allocator behavior comparable run-to-run).
"""
device = next(model.parameters()).device
if images.ndim == 4:
images = images.unsqueeze(0)
images = images.to(dtype)
S = min(images.shape[1], num_frames)
scale_frames = min(8, S)
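    # 8 matches kv_cache_scale_frames in load_model(); these frames attend
    # bidirectionally among themselves in Phase 1 below.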
autocast_ctx = (
contextlib.nullcontext() if dtype == torch.float32
else torch.amp.autocast('cuda', dtype=dtype)
)
start_ev = torch.cuda.Event(enable_timing=True)
end_ev = torch.cuda.Event(enable_timing=True)
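    # The single event pair is safe to reuse: elapsed_time() is only read
    # after torch.cuda.synchronize(), so both events have completed before
    # the next record() overwrites them.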
model.clean_kv_cache()
# ── Phase 1: scale frames, bidirectional attention among themselves ────
# Move data onto GPU BEFORE the start event so the host→device copy is
# excluded from the measured forward time. `.to(device)` is a no-op if
# `images` already lives on GPU.
scale_batch = images[:, :scale_frames].to(device)
start_ev.record()
torch.compiler.cudagraph_mark_step_begin()
with torch.no_grad(), autocast_ctx:
model.forward(
scale_batch,
num_frame_for_scale=scale_frames,
num_frame_per_block=scale_frames,
causal_inference=True,
)
end_ev.record()
torch.cuda.synchronize()
phase1_ms = start_ev.elapsed_time(end_ev)
print(f" Phase 1: {phase1_ms:.1f} ms for {scale_frames} scale frames")
# ── Phase 2: causal streaming, one frame at a time ─────────────────────
per_frame_ms = []
for i in range(scale_frames, S):
frame = images[:, i:i + 1].to(device) # outside the timed region
start_ev.record()
torch.compiler.cudagraph_mark_step_begin()
with torch.no_grad(), autocast_ctx:
model.forward(
frame,
num_frame_for_scale=scale_frames,
num_frame_per_block=1,
causal_inference=True,
)
end_ev.record()
torch.cuda.synchronize()
per_frame_ms.append(start_ev.elapsed_time(end_ev))
return per_frame_ms, scale_frames, phase1_ms
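# For reference, a minimal sketch of the deferred-sync alternative: one fresh
# event pair per frame and a single synchronize at the end. Deliberately NOT
# used above, since per-frame sync keeps GPU clock / allocator behavior
# comparable run-to-run. Assumes the same forward() signature as
# profile_streaming(); illustrative only.
def profile_streaming_deferred_sync(model, images, scale_frames, autocast_ctx):
    """Illustrative sketch: lower host overhead, different run-to-run dynamics."""
    device = next(model.parameters()).device
    events = []
    for i in range(scale_frames, images.shape[1]):
        frame = images[:, i:i + 1].to(device)  # outside the timed region
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        torch.compiler.cudagraph_mark_step_begin()
        with torch.no_grad(), autocast_ctx:
            model.forward(frame, num_frame_for_scale=scale_frames,
                          num_frame_per_block=1, causal_inference=True)
        end.record()
        events.append((start, end))
    torch.cuda.synchronize()  # single sync: every recorded event is now complete
    return [s.elapsed_time(e) for s, e in events]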
# ============================================================================
# Reporting
# ============================================================================
def summarize(per_frame_ms, scale_frames, phase1_ms, label):
"""Print global FPS (total time / total frames) + 10/50/90% windows + trace."""
n = len(per_frame_ms)
if n == 0:
print(f" [{label}]: no frames")
return {}
def avg_ms(pos, window=30):
lo = max(0, pos - window)
hi = min(n, pos + window + 1)
return float(np.mean(per_frame_ms[lo:hi]))
def fps(ms):
return 1000.0 / ms if ms > 0 else 0.0
# Global throughput: total wall time (Phase 1 + Phase 2) / total frames.
total_frames = scale_frames + n
total_ms = phase1_ms + float(np.sum(per_frame_ms))
global_ms_per_frame = total_ms / total_frames
global_fps = fps(global_ms_per_frame)
# Per-region windowed averages (±30 frames) for how FPS drifts over time.
p_lo = max(10, n // 10)
p_mid = n // 2
p_hi = n - max(1, n // 10)
ms_lo, ms_mid, ms_hi = avg_ms(p_lo), avg_ms(p_mid), avg_ms(p_hi)
print(f"\n [{label}] ({total_frames} total frames: {scale_frames} scale + {n} streaming)")
print(f" ── Global FPS ─────────────────────────────────────")
print(f" total time: {total_ms / 1000:.2f} s "
f"({phase1_ms:.1f} ms phase1 + {total_ms - phase1_ms:.1f} ms phase2)")
print(f" per frame : {global_ms_per_frame:6.2f} ms → {global_fps:6.2f} FPS")
print(f" ── Windowed FPS (±30 streaming frames) ────────────")
print(f" frame {scale_frames + p_lo:>5d} (10%): {ms_lo:6.2f} ms → {fps(ms_lo):6.2f} FPS")
print(f" frame {scale_frames + p_mid:>5d} (50%): {ms_mid:6.2f} ms → {fps(ms_mid):6.2f} FPS")
print(f" frame {scale_frames + p_hi:>5d} (90%): {ms_hi:6.2f} ms → {fps(ms_hi):6.2f} FPS")
# Trace at global frame indices that are multiples of 100, matching the
# original script. This naturally skips the cold first streaming frame
# (global index = scale_frames), whose ms is dominated by one-time CUDA
# graph (re)capture after `clean_kv_cache()` in profile_streaming.
print(f" ── FPS trace (every 100 global frames) ────────────")
first_trace = (100 - scale_frames) % 100 or 100
for i in range(first_trace, n, 100):
ms_i = avg_ms(i, window=3)
print(f" frame {scale_frames + i:>5d}: {fps(ms_i):6.2f} FPS ({ms_i:.2f} ms)")
return {
'global_fps': global_fps, 'global_ms': global_ms_per_frame,
'total_ms': total_ms, 'total_frames': total_frames,
'phase1_ms': phase1_ms,
'ms_lo': ms_lo, 'ms_mid': ms_mid, 'ms_hi': ms_hi,
'fps_lo': fps(ms_lo), 'fps_mid': fps(ms_mid), 'fps_hi': fps(ms_hi),
}
def print_comparison(results):
"""Side-by-side FPS / ms table across all variants."""
if len(results) < 2:
return
keys = sorted(results.keys())
col = 14
width = 18 + col * len(keys)
print(f"\n{'=' * width}\n Comparison\n{'=' * width}")
print(f" {'Metric':<18s}" + "".join(f"{k:>{col}s}" for k in keys))
print(" " + "-" * (width - 2))
rows = [
('Global FPS', 'global_fps'), ('Global ms/frame', 'global_ms'),
('FPS @10%', 'fps_lo'), ('FPS @50%', 'fps_mid'), ('FPS @90%', 'fps_hi'),
('ms @10%', 'ms_lo'), ('ms @50%', 'ms_mid'), ('ms @90%', 'ms_hi'),
]
for label, field in rows:
vals = "".join(f"{results[k].get(field, 0):>{col}.2f}" for k in keys)
print(f" {label:<18s}{vals}")
# ============================================================================
# Main
# ============================================================================
def main():
parser = argparse.ArgumentParser(
description="GCTStream end-to-end FPS profiling (no module breakdown)."
)
parser.add_argument('--img_size', type=int, default=518)
parser.add_argument('--img_h', type=int, default=378, help='Must be divisible by 14')
parser.add_argument('--img_w', type=int, default=504, help='Must be divisible by 14')
parser.add_argument('--num_frames', type=int, default=500)
parser.add_argument('--sliding_window', type=int, default=64)
parser.add_argument('--backend', choices=['sdpa', 'flashinfer', 'both'], default='flashinfer')
parser.add_argument('--dtype', choices=['bf16', 'fp32', 'both'], default='bf16')
    parser.add_argument('--compile', action=argparse.BooleanOptionalAction, default=True,
                        help='torch.compile hot modules (reduce-overhead) and drop '
                             'point_head; disable with --no-compile. Typically ~5 FPS '
                             'faster at the default resolution.')
args = parser.parse_args()
dtype_map = {'bf16': torch.bfloat16, 'fp32': torch.float32}
backends = ['sdpa', 'flashinfer'] if args.backend == 'both' else [args.backend]
dtypes = ['bf16', 'fp32'] if args.dtype == 'both' else [args.dtype]
device = 'cuda'
print("=" * 72)
print(f"GCTStream FPS profiling | {args.img_h}×{args.img_w} | "
f"{args.num_frames} frames | sw={args.sliding_window}")
print(f" backends={backends} dtypes={dtypes}")
print("=" * 72)
# Synthetic images — keep on CPU for long runs to avoid OOM.
img_device = device if args.num_frames <= 500 else 'cpu'
print(f"Generating {args.num_frames} synthetic images on {img_device.upper()}...")
torch.manual_seed(42)
images_master = torch.randn(
1, args.num_frames, 3, args.img_h, args.img_w,
device=img_device, dtype=torch.float32,
)
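    # fp32 footprint is 3 * img_h * img_w * 4 bytes ≈ 2.3 MB per frame at the
    # defaults (378×504), so 500 frames is already ≈ 1.1 GB; hence the CPU
    # fallback above for longer runs.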
results = {}
for backend in backends:
for dtype_str in dtypes:
dtype = dtype_map[dtype_str]
key = f'{backend}_{dtype_str}'
print(f"\n{'=' * 72}\n Run: {key}\n{'=' * 72}")
model = load_model(
backend,
img_size=args.img_size,
sliding_window=args.sliding_window,
max_frame_num=args.num_frames + 100,
device=device,
)
# FlashInfer FA2 only supports fp16/bf16; fall back to gather+SDPA for fp32.
if backend == 'flashinfer' and dtype == torch.float32:
model.aggregator.kv_cache_force_fp32 = True
autocast_ctx = (
contextlib.nullcontext() if dtype == torch.float32
else torch.amp.autocast('cuda', dtype=dtype)
)
            # Warm up with WARMUP_STREAM streaming frames so CUDA graphs /
            # cuDNN autotune / FlashInfer lazy init / allocator growth all
            # complete before the measured profile begins. The profile's
            # `clean_kv_cache → P1 → stream` opening then hits
            # already-captured graphs with stable addresses.
WARMUP_STREAM = 10
warm_scale = images_master[:1, :8].to(device=device, dtype=dtype)
warm_stream = images_master[:1, 8:8 + WARMUP_STREAM].to(device=device, dtype=dtype)
def _warm(m, passes=1):
"""Run `passes` full `clean → Phase-1 → N-stream` sequences."""
for _ in range(passes):
m.clean_kv_cache()
torch.compiler.cudagraph_mark_step_begin()
with torch.no_grad(), autocast_ctx:
m.forward(warm_scale, num_frame_for_scale=8,
num_frame_per_block=8, causal_inference=True)
for i in range(WARMUP_STREAM):
torch.compiler.cudagraph_mark_step_begin()
with torch.no_grad(), autocast_ctx:
m.forward(warm_stream[:, i:i + 1], num_frame_for_scale=8,
num_frame_per_block=1, causal_inference=True)
torch.cuda.synchronize()
# Eager warmup populates RoPE / kernel caches BEFORE torch.compile
# captures CUDA graphs (otherwise capture would bake in a cache-miss
# tensor-allocation path).
print(f" Warmup eager (scale + {WARMUP_STREAM} streaming)...")
_warm(model)
if args.compile:
print(f" Compiling hot modules...")
compile_model(model)
# Three passes under compile: 1st captures CUDA graphs, 2nd/3rd
# replay so the caching allocator and graph-address map converge
# on the exact state the subsequent profile will see.
print(f" Warmup compiled (3× dress rehearsal)...")
_warm(model, passes=3)
else:
# No compile → a single dress-rehearsal pass is enough to
# settle cuDNN / allocator for the first Phase-2 frame.
_warm(model)
images = images_master.to(dtype=dtype)
per_frame_ms, scale_frames, phase1_ms = profile_streaming(
model, images, args.num_frames, dtype,
)
results[key] = summarize(per_frame_ms, scale_frames, phase1_ms, key)
del model
torch.cuda.empty_cache()
print_comparison(results)
out_path = (
f'/tmp/profile_results_{args.img_h}x{args.img_w}_'
f'{args.num_frames}f_{args.dtype}.json'
)
with open(out_path, 'w') as f:
json.dump(results, f, indent=2)
print(f"\n Saved to {out_path}")
if __name__ == "__main__":
main()