Add --camera_num_iterations option and make --compile opt-in (default False)
This commit is contained in:
@@ -22,7 +22,7 @@ from lingbot_map.models.gct_stream import GCTStream
|
||||
# Model loading
|
||||
# ============================================================================
|
||||
|
||||
def load_model(backend, img_size, sliding_window, max_frame_num, device='cuda'):
|
||||
def load_model(backend, img_size, sliding_window, max_frame_num, camera_num_iterations, device='cuda'):
|
||||
"""Build GCTStream with random weights (no checkpoint). Eval mode on device."""
|
||||
model = GCTStream(
|
||||
img_size=img_size,
|
||||
@@ -32,6 +32,7 @@ def load_model(backend, img_size, sliding_window, max_frame_num, device='cuda'):
|
||||
kv_cache_sliding_window=sliding_window,
|
||||
kv_cache_scale_frames=8,
|
||||
use_sdpa=(backend == 'sdpa'),
|
||||
camera_num_iterations=camera_num_iterations,
|
||||
)
|
||||
return model.eval().to(device)
|
||||
|
||||
@@ -217,11 +218,18 @@ def main():
|
||||
parser.add_argument('--img_w', type=int, default=504, help='Must be divisible by 14')
|
||||
parser.add_argument('--num_frames', type=int, default=500)
|
||||
parser.add_argument('--sliding_window', type=int, default=64)
|
||||
parser.add_argument('--camera_num_iterations', type=int, default=4,
|
||||
help='Camera head iterative-refinement steps. Default 4; '
|
||||
'set 1 for faster inference (skips 3 refinement passes '
|
||||
'at a small accuracy cost).')
|
||||
parser.add_argument('--backend', choices=['sdpa', 'flashinfer', 'both'], default='flashinfer')
|
||||
parser.add_argument('--dtype', choices=['bf16', 'fp32', 'both'], default='bf16')
|
||||
parser.add_argument('--compile', action='store_true', default=True,
|
||||
parser.add_argument('--compile', action='store_true', default=False,
|
||||
help='torch.compile hot modules (reduce-overhead) and drop point_head. '
|
||||
'Typically ~5 FPS faster at 518×378.')
|
||||
parser.add_argument('--fa3', action='store_true',
|
||||
help='Use FlashInfer FA3 (SM90) kernel instead of FA2 (requires power-of-2 page_size)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dtype_map = {'bf16': torch.bfloat16, 'fp32': torch.float32}
|
||||
@@ -256,12 +264,15 @@ def main():
|
||||
img_size=args.img_size,
|
||||
sliding_window=args.sliding_window,
|
||||
max_frame_num=args.num_frames + 100,
|
||||
camera_num_iterations=args.camera_num_iterations,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# FlashInfer FA2 only supports fp16/bf16; fall back to gather+SDPA for fp32.
|
||||
if backend == 'flashinfer' and dtype == torch.float32:
|
||||
model.aggregator.kv_cache_force_fp32 = True
|
||||
if backend == 'flashinfer' and args.fa3:
|
||||
model.aggregator.kv_cache_fa3 = True
|
||||
|
||||
autocast_ctx = (
|
||||
contextlib.nullcontext() if dtype == torch.float32
|
||||
|
||||
Reference in New Issue
Block a user