From 783ea8c3dda485c205fc2d9cf371b97767f5d5d8 Mon Sep 17 00:00:00 2001 From: LinZhuoChen Date: Tue, 21 Apr 2026 22:10:44 +0800 Subject: [PATCH] add compile --- gct_profile.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/gct_profile.py b/gct_profile.py index bdbe067..9615145 100644 --- a/gct_profile.py +++ b/gct_profile.py @@ -22,7 +22,7 @@ from lingbot_map.models.gct_stream import GCTStream # Model loading # ============================================================================ -def load_model(backend, img_size, sliding_window, max_frame_num, device='cuda'): +def load_model(backend, img_size, sliding_window, max_frame_num, camera_num_iterations, device='cuda'): """Build GCTStream with random weights (no checkpoint). Eval mode on device.""" model = GCTStream( img_size=img_size, @@ -32,6 +32,7 @@ def load_model(backend, img_size, sliding_window, max_frame_num, device='cuda'): kv_cache_sliding_window=sliding_window, kv_cache_scale_frames=8, use_sdpa=(backend == 'sdpa'), + camera_num_iterations=camera_num_iterations, ) return model.eval().to(device) @@ -217,11 +218,18 @@ def main(): parser.add_argument('--img_w', type=int, default=504, help='Must be divisible by 14') parser.add_argument('--num_frames', type=int, default=500) parser.add_argument('--sliding_window', type=int, default=64) + parser.add_argument('--camera_num_iterations', type=int, default=4, + help='Camera head iterative-refinement steps. Default 4; ' + 'set 1 for faster inference (skips 3 refinement passes ' + 'at a small accuracy cost).') parser.add_argument('--backend', choices=['sdpa', 'flashinfer', 'both'], default='flashinfer') parser.add_argument('--dtype', choices=['bf16', 'fp32', 'both'], default='bf16') - parser.add_argument('--compile', action='store_true', default=True, + parser.add_argument('--compile', action='store_true', default=False, help='torch.compile hot modules (reduce-overhead) and drop point_head. ' 'Typically ~5 FPS faster at 518×378.') + parser.add_argument('--fa3', action='store_true', + help='Use FlashInfer FA3 (SM90) kernel instead of FA2 (requires power-of-2 page_size)') + args = parser.parse_args() dtype_map = {'bf16': torch.bfloat16, 'fp32': torch.float32} @@ -256,12 +264,15 @@ def main(): img_size=args.img_size, sliding_window=args.sliding_window, max_frame_num=args.num_frames + 100, + camera_num_iterations=args.camera_num_iterations, device=device, ) # FlashInfer FA2 only supports fp16/bf16; fall back to gather+SDPA for fp32. if backend == 'flashinfer' and dtype == torch.float32: model.aggregator.kv_cache_force_fp32 = True + if backend == 'flashinfer' and args.fa3: + model.aggregator.kv_cache_fa3 = True autocast_ctx = ( contextlib.nullcontext() if dtype == torch.float32