update
This commit is contained in:
333
gct_profile.py
Normal file
333
gct_profile.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Profile GCTStream streaming inference FPS.
|
||||
|
||||
Measures only the top-level model.forward() GPU time — no per-module hooks,
|
||||
no inner breakdown. One CUDA event pair per frame, a single sync at the end.
|
||||
|
||||
Usage:
|
||||
python gct_profile.py --backend both --dtype bf16 --num_frames 500
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lingbot_map.models.gct_stream import GCTStream
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Model loading
|
||||
# ============================================================================
|
||||
|
||||
def load_model(backend, img_size, sliding_window, max_frame_num, device='cuda'):
    """Construct a randomly-initialized GCTStream (no checkpoint is loaded).

    The model is put into eval mode and moved onto `device` before being
    returned. `backend` selects the attention implementation: 'sdpa' turns on
    the SDPA path; any other value leaves `use_sdpa` off.
    """
    cfg = dict(
        img_size=img_size,
        patch_size=14,
        enable_3d_rope=True,
        max_frame_num=max_frame_num,
        kv_cache_sliding_window=sliding_window,
        kv_cache_scale_frames=8,
        use_sdpa=backend == 'sdpa',
    )
    net = GCTStream(**cfg)
    net.eval()
    return net.to(device)
|
||||
|
||||
|
||||
def compile_model(model):
    """
    Wrap the compute-heavy, fixed-shape submodules in
    torch.compile(mode="reduce-overhead") and drop point_head. Mirrors the
    optimizations from the original --compile path; per-block experiments
    that did not pay off are intentionally omitted.
    """
    def _rc(mod):
        # Shorthand: reduce-overhead compile (CUDA-graph friendly).
        return torch.compile(mod, mode="reduce-overhead")

    agg = model.aggregator
    for idx in range(len(agg.frame_blocks)):
        agg.frame_blocks[idx] = _rc(agg.frame_blocks[idx])
    for idx in range(len(agg.patch_embed.blocks)):
        agg.patch_embed.blocks[idx] = _rc(agg.patch_embed.blocks[idx])
    for blk in agg.global_blocks:
        # Some global blocks expose extra sub-modules; compile them if present.
        if hasattr(blk, 'attn_pre'):
            blk.attn_pre = _rc(blk.attn_pre)
        if hasattr(blk, 'ffn_residual'):
            blk.ffn_residual = _rc(blk.ffn_residual)
        blk.attn.proj = _rc(blk.attn.proj)
    # Dropping point_head saves ~5.9 ms/frame; it is not needed for FPS
    # measurement.
    model.point_head = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Profiling — reuse one CUDA event pair, sync after every frame
|
||||
# ============================================================================
|
||||
|
||||
def profile_streaming(model, images: torch.Tensor, num_frames: int, dtype: torch.dtype):
    """
    Run streaming inference. Return (per_frame_ms, scale_frames, phase1_ms).

    per_frame_ms : list[float] — GPU ms per streamed frame (Phase 2 only).
    scale_frames : int         — frames consumed by the bidirectional Phase 1.
    phase1_ms    : float       — GPU ms of the Phase-1 forward.

    Reuses a single CUDA event pair across frames and syncs after every frame
    (matches the original non-lightweight path — necessary to keep GPU clock /
    memory allocator behavior comparable run-to-run).
    """
    # Run on whatever device the model's parameters live on. Timing uses CUDA
    # events, so in practice this must be a CUDA device.
    device = next(model.parameters()).device
    # Accept an unbatched (S, C, H, W) clip by promoting it to (1, S, C, H, W).
    if images.ndim == 4:
        images = images.unsqueeze(0)
    images = images.to(dtype)
    # Clamp to the frames actually present in `images`.
    S = min(images.shape[1], num_frames)
    scale_frames = min(8, S)

    # fp32 runs natively; any other dtype goes through CUDA autocast.
    autocast_ctx = (
        contextlib.nullcontext() if dtype == torch.float32
        else torch.amp.autocast('cuda', dtype=dtype)
    )

    # One reusable event pair: elapsed_time() is read after each synchronize,
    # so re-recording the same events next iteration is safe.
    start_ev = torch.cuda.Event(enable_timing=True)
    end_ev = torch.cuda.Event(enable_timing=True)

    # Start from an empty KV cache so the measurement is independent of any
    # warmup state the caller left behind.
    model.clean_kv_cache()

    # ── Phase 1: scale frames, bidirectional attention among themselves ────
    # Move data onto GPU BEFORE the start event so the host→device copy is
    # excluded from the measured forward time. `.to(device)` is a no-op if
    # `images` already lives on GPU.
    scale_batch = images[:, :scale_frames].to(device)
    start_ev.record()
    # Step boundary for torch.compile(reduce-overhead) CUDA-graph replay.
    torch.compiler.cudagraph_mark_step_begin()
    with torch.no_grad(), autocast_ctx:
        model.forward(
            scale_batch,
            num_frame_for_scale=scale_frames,
            num_frame_per_block=scale_frames,
            causal_inference=True,
        )
    end_ev.record()
    torch.cuda.synchronize()
    phase1_ms = start_ev.elapsed_time(end_ev)
    print(f"  Phase 1: {phase1_ms:.1f} ms for {scale_frames} scale frames")

    # ── Phase 2: causal streaming, one frame at a time ─────────────────────
    per_frame_ms = []
    for i in range(scale_frames, S):
        frame = images[:, i:i + 1].to(device)  # outside the timed region
        start_ev.record()
        torch.compiler.cudagraph_mark_step_begin()
        with torch.no_grad(), autocast_ctx:
            model.forward(
                frame,
                num_frame_for_scale=scale_frames,
                num_frame_per_block=1,
                causal_inference=True,
            )
        end_ev.record()
        # Per-frame sync keeps GPU clock / allocator behavior comparable
        # run-to-run (see docstring) and makes elapsed_time valid to read.
        torch.cuda.synchronize()
        per_frame_ms.append(start_ev.elapsed_time(end_ev))

    return per_frame_ms, scale_frames, phase1_ms
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Reporting
|
||||
# ============================================================================
|
||||
|
||||
def summarize(per_frame_ms, scale_frames, phase1_ms, label):
    """Report global FPS (total time / total frames), windowed 10/50/90%
    FPS, and a coarse per-100-frame trace; return the metrics as a dict.

    Returns {} (after printing a notice) when there are no streaming frames.
    """
    n = len(per_frame_ms)
    if n == 0:
        print(f"  [{label}]: no frames")
        return {}

    def window_mean(center, half=30):
        # Mean over per_frame_ms[center-half : center+half+1], clipped to range.
        lo, hi = max(0, center - half), min(n, center + half + 1)
        return float(np.mean(per_frame_ms[lo:hi]))

    def to_fps(ms):
        return 1000.0 / ms if ms > 0 else 0.0

    # Global throughput: total wall time (Phase 1 + Phase 2) / total frames.
    total_frames = scale_frames + n
    total_ms = phase1_ms + float(np.sum(per_frame_ms))
    global_ms_per_frame = total_ms / total_frames
    global_fps = to_fps(global_ms_per_frame)

    # Per-region windowed averages (±30 frames) for how FPS drifts over time.
    anchors = (max(10, n // 10), n // 2, n - max(1, n // 10))
    ms_lo, ms_mid, ms_hi = (window_mean(p) for p in anchors)

    print(f"\n  [{label}] ({total_frames} total frames: {scale_frames} scale + {n} streaming)")
    print(f"  ── Global FPS ─────────────────────────────────────")
    print(f"    total time: {total_ms / 1000:.2f} s "
          f"({phase1_ms:.1f} ms phase1 + {total_ms - phase1_ms:.1f} ms phase2)")
    print(f"    per frame : {global_ms_per_frame:6.2f} ms → {global_fps:6.2f} FPS")
    print(f"  ── Windowed FPS (±30 streaming frames) ────────────")
    for tag, pos, ms_v in zip(('10%', '50%', '90%'), anchors, (ms_lo, ms_mid, ms_hi)):
        print(f"    frame {scale_frames + pos:>5d} ({tag}): {ms_v:6.2f} ms → {to_fps(ms_v):6.2f} FPS")

    # Trace at global frame indices that are multiples of 100, matching the
    # original script. This naturally skips the cold first streaming frame
    # (global index = scale_frames), whose ms is dominated by one-time CUDA
    # graph (re)capture after `clean_kv_cache()` in profile_streaming.
    print(f"  ── FPS trace (every 100 global frames) ────────────")
    first_trace = (100 - scale_frames) % 100 or 100
    for idx in range(first_trace, n, 100):
        ms_i = window_mean(idx, half=3)
        print(f"    frame {scale_frames + idx:>5d}: {to_fps(ms_i):6.2f} FPS ({ms_i:.2f} ms)")

    return {
        'global_fps': global_fps, 'global_ms': global_ms_per_frame,
        'total_ms': total_ms, 'total_frames': total_frames,
        'phase1_ms': phase1_ms,
        'ms_lo': ms_lo, 'ms_mid': ms_mid, 'ms_hi': ms_hi,
        'fps_lo': to_fps(ms_lo), 'fps_mid': to_fps(ms_mid), 'fps_hi': to_fps(ms_hi),
    }
|
||||
|
||||
|
||||
def print_comparison(results):
    """Print a side-by-side FPS / ms table across all run variants.

    Silently does nothing unless `results` holds at least two variants.
    """
    if len(results) < 2:
        return
    keys = sorted(results)
    col = 14
    width = 18 + col * len(keys)
    bar = '=' * width
    print(f"\n{bar}\n Comparison\n{bar}")
    header = "".join(f"{k:>{col}s}" for k in keys)
    print(f" {'Metric':<18s}" + header)
    print(" " + "-" * (width - 2))
    for label, field in (
        ('Global FPS', 'global_fps'), ('Global ms/frame', 'global_ms'),
        ('FPS @10%', 'fps_lo'), ('FPS @50%', 'fps_mid'), ('FPS @90%', 'fps_hi'),
        ('ms @10%', 'ms_lo'), ('ms @50%', 'ms_mid'), ('ms @90%', 'ms_hi'),
    ):
        cells = "".join(f"{results[k].get(field, 0):>{col}.2f}" for k in keys)
        print(f" {label:<18s}{cells}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main
|
||||
# ============================================================================
|
||||
|
||||
def main():
    """CLI entry point: build, warm up, and profile GCTStream end-to-end.

    For every requested (backend, dtype) combination this builds a fresh
    randomly-initialized model, runs eager (and optionally compiled) warmup
    passes, profiles streaming inference, prints a per-run summary plus a
    cross-run comparison, and writes all metrics to a JSON file under /tmp.
    """
    parser = argparse.ArgumentParser(
        description="GCTStream end-to-end FPS profiling (no module breakdown)."
    )
    parser.add_argument('--img_size', type=int, default=518)
    parser.add_argument('--img_h', type=int, default=378, help='Must be divisible by 14')
    parser.add_argument('--img_w', type=int, default=504, help='Must be divisible by 14')
    parser.add_argument('--num_frames', type=int, default=500)
    parser.add_argument('--sliding_window', type=int, default=64)
    parser.add_argument('--backend', choices=['sdpa', 'flashinfer', 'both'], default='flashinfer')
    parser.add_argument('--dtype', choices=['bf16', 'fp32', 'both'], default='bf16')
    # BUGFIX: the original used action='store_true' with default=True, which
    # made the flag impossible to disable (passing --compile was a no-op and
    # there was no way to reach the non-compiled path from the CLI).
    # BooleanOptionalAction keeps --compile accepted and adds --no-compile,
    # with the same default.
    parser.add_argument('--compile', action=argparse.BooleanOptionalAction, default=True,
                        help='torch.compile hot modules (reduce-overhead) and drop point_head. '
                             'Typically ~5 FPS faster at 518×378.')
    args = parser.parse_args()

    dtype_map = {'bf16': torch.bfloat16, 'fp32': torch.float32}
    backends = ['sdpa', 'flashinfer'] if args.backend == 'both' else [args.backend]
    dtypes = ['bf16', 'fp32'] if args.dtype == 'both' else [args.dtype]
    device = 'cuda'

    print("=" * 72)
    print(f"GCTStream FPS profiling | {args.img_h}×{args.img_w} | "
          f"{args.num_frames} frames | sw={args.sliding_window}")
    print(f"  backends={backends} dtypes={dtypes}")
    print("=" * 72)

    # Synthetic images — keep on CPU for long runs to avoid OOM.
    img_device = device if args.num_frames <= 500 else 'cpu'
    print(f"Generating {args.num_frames} synthetic images on {img_device.upper()}...")
    torch.manual_seed(42)
    images_master = torch.randn(
        1, args.num_frames, 3, args.img_h, args.img_w,
        device=img_device, dtype=torch.float32,
    )

    results = {}
    for backend in backends:
        for dtype_str in dtypes:
            dtype = dtype_map[dtype_str]
            key = f'{backend}_{dtype_str}'
            print(f"\n{'=' * 72}\n Run: {key}\n{'=' * 72}")

            model = load_model(
                backend,
                img_size=args.img_size,
                sliding_window=args.sliding_window,
                max_frame_num=args.num_frames + 100,
                device=device,
            )

            # FlashInfer FA2 only supports fp16/bf16; fall back to gather+SDPA for fp32.
            if backend == 'flashinfer' and dtype == torch.float32:
                model.aggregator.kv_cache_force_fp32 = True

            autocast_ctx = (
                contextlib.nullcontext() if dtype == torch.float32
                else torch.amp.autocast('cuda', dtype=dtype)
            )
            # N streaming frames in warmup so CUDA graphs / cuDNN autotune /
            # FlashInfer lazy init / allocator growth all complete before the
            # measured profile begins. Profile's `clean_kv_cache → P1 → stream`
            # opening then hits already-captured graphs with stable addresses.
            WARMUP_STREAM = 10
            warm_scale = images_master[:1, :8].to(device=device, dtype=dtype)
            warm_stream = images_master[:1, 8:8 + WARMUP_STREAM].to(device=device, dtype=dtype)

            def _warm(m, passes=1):
                """Run `passes` full `clean → Phase-1 → N-stream` sequences."""
                for _ in range(passes):
                    m.clean_kv_cache()
                    torch.compiler.cudagraph_mark_step_begin()
                    with torch.no_grad(), autocast_ctx:
                        m.forward(warm_scale, num_frame_for_scale=8,
                                  num_frame_per_block=8, causal_inference=True)
                    for i in range(WARMUP_STREAM):
                        torch.compiler.cudagraph_mark_step_begin()
                        with torch.no_grad(), autocast_ctx:
                            m.forward(warm_stream[:, i:i + 1], num_frame_for_scale=8,
                                      num_frame_per_block=1, causal_inference=True)
                torch.cuda.synchronize()

            # Eager warmup populates RoPE / kernel caches BEFORE torch.compile
            # captures CUDA graphs (otherwise capture would bake in a cache-miss
            # tensor-allocation path).
            print(f"  Warmup eager (scale + {WARMUP_STREAM} streaming)...")
            _warm(model)

            if args.compile:
                print("  Compiling hot modules...")
                compile_model(model)
                # Three passes under compile: 1st captures CUDA graphs, 2nd/3rd
                # replay so the caching allocator and graph-address map converge
                # on the exact state the subsequent profile will see.
                print("  Warmup compiled (3× dress rehearsal)...")
                _warm(model, passes=3)
            else:
                # No compile → a single dress-rehearsal pass is enough to
                # settle cuDNN / allocator for the first Phase-2 frame.
                _warm(model)

            images = images_master.to(dtype=dtype)
            per_frame_ms, scale_frames, phase1_ms = profile_streaming(
                model, images, args.num_frames, dtype,
            )
            results[key] = summarize(per_frame_ms, scale_frames, phase1_ms, key)

            # Free the model before the next variant to keep GPU memory flat.
            del model
            torch.cuda.empty_cache()

    print_comparison(results)

    out_path = (
        f'/tmp/profile_results_{args.img_h}x{args.img_w}_'
        f'{args.num_frames}f_{args.dtype}.json'
    )
    with open(out_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\n  Saved to {out_path}")
|
||||
|
||||
|
||||
# Script entry point — run the full profiling sweep when executed directly.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user