add compile

This commit is contained in:
LinZhuoChen
2026-04-21 22:10:44 +08:00
parent f2d9711b3d
commit 783ea8c3dd

View File

@@ -22,7 +22,7 @@ from lingbot_map.models.gct_stream import GCTStream
# Model loading # Model loading
# ============================================================================ # ============================================================================
def load_model(backend, img_size, sliding_window, max_frame_num, device='cuda'): def load_model(backend, img_size, sliding_window, max_frame_num, camera_num_iterations, device='cuda'):
"""Build GCTStream with random weights (no checkpoint). Eval mode on device.""" """Build GCTStream with random weights (no checkpoint). Eval mode on device."""
model = GCTStream( model = GCTStream(
img_size=img_size, img_size=img_size,
@@ -32,6 +32,7 @@ def load_model(backend, img_size, sliding_window, max_frame_num, device='cuda'):
kv_cache_sliding_window=sliding_window, kv_cache_sliding_window=sliding_window,
kv_cache_scale_frames=8, kv_cache_scale_frames=8,
use_sdpa=(backend == 'sdpa'), use_sdpa=(backend == 'sdpa'),
camera_num_iterations=camera_num_iterations,
) )
return model.eval().to(device) return model.eval().to(device)
@@ -217,11 +218,18 @@ def main():
parser.add_argument('--img_w', type=int, default=504, help='Must be divisible by 14') parser.add_argument('--img_w', type=int, default=504, help='Must be divisible by 14')
parser.add_argument('--num_frames', type=int, default=500) parser.add_argument('--num_frames', type=int, default=500)
parser.add_argument('--sliding_window', type=int, default=64) parser.add_argument('--sliding_window', type=int, default=64)
parser.add_argument('--camera_num_iterations', type=int, default=4,
help='Camera head iterative-refinement steps. Default 4; '
'set 1 for faster inference (skips 3 refinement passes '
'at a small accuracy cost).')
parser.add_argument('--backend', choices=['sdpa', 'flashinfer', 'both'], default='flashinfer') parser.add_argument('--backend', choices=['sdpa', 'flashinfer', 'both'], default='flashinfer')
parser.add_argument('--dtype', choices=['bf16', 'fp32', 'both'], default='bf16') parser.add_argument('--dtype', choices=['bf16', 'fp32', 'both'], default='bf16')
parser.add_argument('--compile', action='store_true', default=True, parser.add_argument('--compile', action='store_true', default=False,
help='torch.compile hot modules (reduce-overhead) and drop point_head. ' help='torch.compile hot modules (reduce-overhead) and drop point_head. '
'Typically ~5 FPS faster at 518×378.') 'Typically ~5 FPS faster at 518×378.')
parser.add_argument('--fa3', action='store_true',
help='Use FlashInfer FA3 (SM90) kernel instead of FA2 (requires power-of-2 page_size)')
args = parser.parse_args() args = parser.parse_args()
dtype_map = {'bf16': torch.bfloat16, 'fp32': torch.float32} dtype_map = {'bf16': torch.bfloat16, 'fp32': torch.float32}
@@ -256,12 +264,15 @@ def main():
img_size=args.img_size, img_size=args.img_size,
sliding_window=args.sliding_window, sliding_window=args.sliding_window,
max_frame_num=args.num_frames + 100, max_frame_num=args.num_frames + 100,
camera_num_iterations=args.camera_num_iterations,
device=device, device=device,
) )
# FlashInfer FA2 only supports fp16/bf16; fall back to gather+SDPA for fp32. # FlashInfer FA2 only supports fp16/bf16; fall back to gather+SDPA for fp32.
if backend == 'flashinfer' and dtype == torch.float32: if backend == 'flashinfer' and dtype == torch.float32:
model.aggregator.kv_cache_force_fp32 = True model.aggregator.kv_cache_force_fp32 = True
if backend == 'flashinfer' and args.fa3:
model.aggregator.kv_cache_fa3 = True
autocast_ctx = ( autocast_ctx = (
contextlib.nullcontext() if dtype == torch.float32 contextlib.nullcontext() if dtype == torch.float32