first commit

LinZhuoChen
2026-04-16 09:51:30 +08:00
commit f9b3ae457a
44 changed files with 11994 additions and 0 deletions

lingbot_map/layers/block.py

@@ -0,0 +1,514 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import math
from typing import Callable

import torch
from torch import nn, Tensor
from torch.nn.attention.flex_attention import create_mask

from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp

# scaled_index_add (used in add_residual below) comes from xFormers; guard the
# import so the module still loads when xFormers is not installed.
try:
    from xformers.ops import scaled_index_add
except ImportError:
    scaled_index_add = None


class Block(nn.Module):
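    """Pre-norm transformer block (DINOv2-style): attention and MLP branches,
    each wrapped in LayerScale and DropPath before the residual add."""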
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
rope=None,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
qk_norm=qk_norm,
fused_attn=fused_attn,
rope=rope,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
num_patches=num_patches, num_special=num_special, num_frames=num_frames,
enable_3d_rope=enable_3d_rope))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
)
x = drop_add_residual_stochastic_depth(
x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))
else:
x = x + attn_residual_func(x, pos=pos)
x = x + ffn_residual_func(x)
return x
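# Minimal usage sketch for Block (illustrative only; sizes are assumptions):
#   blk = Block(dim=384, num_heads=6)
#   tokens = torch.randn(2, 197, 384)   # [batch, tokens, dim]
#   out = blk(tokens)                   # same shape; attention + MLP residual branches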
def drop_add_residual_stochastic_depth(
x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
) -> Tensor:
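    """Batch-level stochastic depth: apply residual_func to a random subset of
    samples and rescale the residual by b / subset_size so its expectation is
    unchanged (cheaper than per-sample DropPath at high drop rates)."""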
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
if pos is not None:
# if necessary, apply rope to the subset
pos = pos[brange]
residual = residual_func(x_subset, pos=pos)
else:
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
return x_plus_residual.view_as(x)
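# Illustrative call (values are assumptions): with b=8 and sample_drop_ratio=0.5,
# residual_func runs on 4 random samples and the residual is scaled by 8/4 = 2:
#   x = torch.randn(8, 196, 384)
#   x = drop_add_residual_stochastic_depth(x, residual_func=lambda t: 0.1 * t,
#                                          sample_drop_ratio=0.5)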
def get_branges_scales(x, sample_drop_ratio=0.0):
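    """Sample a random subset of batch indices plus the matching rescale factor,
    so a caller can run a residual branch on the subset (see add_residual)."""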
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
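    """Scatter-add the subset residual back into the full batch. When a
    scaling_vector (e.g. a LayerScale gamma) is given, use the fused xFormers kernel."""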
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        # Fused scale + index_add kernel from xFormers (see guarded import above).
        assert scaled_index_add is not None, "scaling_vector path requires xformers"
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


class FlashInferBlock(nn.Module):
"""
FlashInfer variant of causal block for GCT.
Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels).
Supports optimized token layout and KV cache streaming inference.
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
rope=None,
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = FlashInferAttention(
dim=dim,
num_heads=num_heads,
qk_norm=qk_norm,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
rope=rope,
kv_cache_sliding_window=kv_cache_sliding_window,
kv_cache_scale_frames=kv_cache_scale_frames,
kv_cache_cross_frame_special=kv_cache_cross_frame_special,
kv_cache_include_scale_frames=kv_cache_include_scale_frames,
kv_cache_camera_only=kv_cache_camera_only,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path

    def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
"""Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit.
Extracted as a named method so torch.compile can capture norm1 + qkv-linear +
reshape + q_norm + k_norm + RoPE + format as a single CUDA graph.
Returns:
(q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim],
ready for manager.append_frame + manager.compute_attention.
"""
        return self.attn.prepare_qkv(self.norm1(x), pos=pos, enable_3d_rope=enable_3d_rope)

    def forward(
self,
x: Tensor,
pos=None,
enable_ulysses_cp=False,
num_patches=None,
num_special=None,
num_frames=None,
enable_3d_rope=False,
kv_cache=None,
global_idx=0,
num_frame_per_block=1,
num_frame_for_scale=-1,
num_register_tokens=4,
) -> Tensor:
# Phase 2 (streaming): single-frame FlashInfer paged attention.
# Handle inline so attn_pre (norm1+prepare_qkv) can be compiled as one CUDA graph.
is_streaming = (kv_cache is not None and (num_frames is None or num_frames <= 1))
if is_streaming:
manager = kv_cache
# Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format
q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope)
# Eager: write frame K/V to paged cache
manager.append_frame(global_idx, k_nhd, v_nhd)
# CPU-only: update eviction state (deque ops, no GPU kernel)
manager.evict_frames(
block_idx=global_idx,
scale_frames=self.attn.kv_cache_scale_frames,
sliding_window=self.attn.kv_cache_sliding_window,
cross_frame_special=self.attn.kv_cache_cross_frame_special,
include_scale_frames=self.attn.kv_cache_include_scale_frames,
camera_only=self.attn.kv_cache_camera_only,
num_register_tokens=num_register_tokens,
)
# Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper
attn_x = manager.compute_attention(global_idx, q_nhd)
# [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output)
attn_x = attn_x.reshape(x.shape[0], q_nhd.shape[0],
self.attn.num_heads * self.attn.head_dim)
# Compiled: output projection
attn_x = self.attn.proj(attn_x)
x = x + self.ls1(attn_x)
else:
# Phase 1 (multi-frame scale pass) or non-streaming training path
x = x + self.ls1(self.attn(
self.norm1(x),
pos=pos,
enable_ulysses_cp=enable_ulysses_cp,
num_patches=num_patches,
num_special=num_special,
num_frames=num_frames,
enable_3d_rope=enable_3d_rope,
kv_cache=kv_cache,
global_idx=global_idx,
num_frame_per_block=num_frame_per_block,
num_frame_for_scale=num_frame_for_scale,
num_register_tokens=num_register_tokens,
))
x = self.ffn_residual(x)
        return x

    def ffn_residual(self, x: Tensor) -> Tensor:
"""FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in.
Includes the residual add (x + ...) so torch.compile captures the entire
ffn branch as one CUDA graph.
"""
return x + self.ls2(self.mlp(self.norm2(x)))
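# Streaming usage sketch for FlashInferBlock (illustrative; `manager` stands in
# for the paged kv-cache object whose append_frame / evict_frames /
# compute_attention methods are called in forward above):
#   blk = FlashInferBlock(dim=768, num_heads=12)
#   for frame_idx, frame in enumerate(frames):   # frame: [1, tokens_per_frame, 768]
#       frame = blk(frame, kv_cache=manager, global_idx=frame_idx, num_frames=1)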
class CameraBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
rope=None,
elementwise_attn_output_gate: bool = False,
sliding_window_size: int = -1,
attend_to_scale_frames: bool = False,
num_random_frames: int = 0,
# KV cache parameters
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = CausalAttention(dim=dim, num_heads=num_heads,
qk_norm=qk_norm, qkv_bias=qkv_bias,
rope=rope, elementwise_attn_output_gate=elementwise_attn_output_gate,
kv_cache_sliding_window=kv_cache_sliding_window,
kv_cache_scale_frames=kv_cache_scale_frames,
kv_cache_cross_frame_special=kv_cache_cross_frame_special,
kv_cache_include_scale_frames=kv_cache_include_scale_frames,
kv_cache_camera_only=kv_cache_camera_only)
self.sliding_window_size = sliding_window_size
self.attend_to_scale_frames = attend_to_scale_frames
self.num_random_frames = num_random_frames
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path
        self.masks = {}  # reserved for caching precomputed masks (currently unpopulated)

    @torch.no_grad()
    def _prepare_blockwise_causal_attn_mask(
        self, device: torch.device | str, num_frames: int = 21,
        frame_seqlen: int = 1560, num_frame_per_block: int = 1,
    ) -> Tensor:
"""
we will divide the token sequence into the following format
[1 latent frame] [1 latent frame] ... [1 latent frame]
We use flexattention to construct the attention mask
"""
total_length = num_frames * frame_seqlen
        # padded_length is the amount of right padding needed to reach a multiple of 128
        padded_length = math.ceil(total_length / 128) * 128 - total_length
ends = torch.zeros(total_length + padded_length,
device=device, dtype=torch.long)
# Block-wise causal mask will attend to all elements that are before the end of the current chunk
frame_indices = torch.arange(
start=0,
end=total_length,
step=frame_seqlen * num_frame_per_block,
device=device
)
        for block_start in frame_indices:
            block_end = block_start + frame_seqlen * num_frame_per_block
            ends[block_start:block_end] = block_end
def attention_mask(b, h, q_idx, kv_idx):
return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
# return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
block_mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
KV_LEN=total_length + padded_length, device=device)
return block_mask
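    # Mask structure sketch (illustrative): with frame_seqlen=2 and
    # num_frame_per_block=2, queries in block k see every kv token up to the
    # end of block k:
    #   q 0..3 -> kv 0..3   (block 0)
    #   q 4..7 -> kv 0..7   (block 1)
    # The extra (q_idx == kv_idx) term keeps right-padded rows from being fully masked.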
    def forward(
        self, x: Tensor, pos=None, video_mask=None, num_frames=0, frame_seqlen=0,
        kv_cache=None, current_start=0, current_end=0, global_idx=0,
        num_frame_per_block=8, num_frame_for_scale=-1, sliding_window_size=None,
        enable_ulysses_cp=False, full_attention=False, enable_3d_rope=False,
        is_scale_frames=False,
    ) -> Tensor:
# Use passed sliding_window_size if provided, otherwise use self.sliding_window_size
effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
# Fast path for full attention (camera head) - skip mask computation
if full_attention:
            def attn_residual_func(x: Tensor, pos=None) -> Tensor:
                return self.ls1(self.attn(
                    self.norm1(x), pos=pos, full_attention=True,
                    enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope,
                ))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
                x = x + self.drop_path2(ffn_residual_func(x))
else:
x = x + attn_residual_func(x, pos=pos)
x = x + ffn_residual_func(x)
return x
mask_block = self._prepare_blockwise_causal_attn_mask(
device=x.device, num_frames=num_frames, frame_seqlen=frame_seqlen, num_frame_per_block=num_frame_per_block)
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(
                self.norm1(x), pos=pos, block_mask=mask_block, frame_seqlen=frame_seqlen,
                video_mask=video_mask, current_start=current_start, current_end=current_end,
                kv_cache=kv_cache, global_idx=global_idx,
                num_frame_per_block=num_frame_per_block,
                num_frame_for_scale=num_frame_for_scale,
                sliding_window_size=effective_sliding_window_size,
                attend_to_scale_frames=self.attend_to_scale_frames,
                num_random_frames=self.num_random_frames,
                enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope,
                is_scale_frames=is_scale_frames,
            ))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))
else:
x = x + attn_residual_func(x, pos=pos)
x = x + ffn_residual_func(x)
return x
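# Minimal usage sketch for CameraBlock (illustrative; sizes are assumptions):
#   blk = CameraBlock(dim=768, num_heads=12)
#   x = torch.randn(1, 4 * 1560, 768)   # 4 frames of 1560 tokens each
#   x = blk(x, num_frames=4, frame_seqlen=1560, num_frame_per_block=1)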
class SDPABlock(nn.Module):
"""
SDPA variant for streaming inference. Uses F.scaled_dot_product_attention
with dict-based KV cache. No FlashInfer dependency required.
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
rope=None,
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = SDPAAttention(
dim=dim, num_heads=num_heads, qk_norm=qk_norm, qkv_bias=qkv_bias,
proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, rope=rope,
kv_cache_sliding_window=kv_cache_sliding_window,
kv_cache_scale_frames=kv_cache_scale_frames,
kv_cache_cross_frame_special=kv_cache_cross_frame_special,
kv_cache_include_scale_frames=kv_cache_include_scale_frames,
kv_cache_camera_only=kv_cache_camera_only,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = ffn_layer(in_features=dim, hidden_features=int(dim * mlp_ratio),
act_layer=act_layer, drop=drop, bias=ffn_bias)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
kv_cache=None, global_idx=0, num_frame_per_block=1,
num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
def attn_residual_func(x, pos=None):
return self.ls1(self.attn(
self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
num_patches=num_patches, num_special=num_special, num_frames=num_frames,
enable_3d_rope=enable_3d_rope, kv_cache=kv_cache, global_idx=global_idx,
num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale,
num_register_tokens=num_register_tokens,
))
def ffn_residual_func(x):
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))
else:
x = x + attn_residual_func(x, pos=pos)
x = x + ffn_residual_func(x)
return x
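# Minimal usage sketch for SDPABlock (illustrative; `cache` stands in for the
# dict-based kv cache consumed by SDPAAttention):
#   blk = SDPABlock(dim=768, num_heads=12)
#   for frame_idx, frame in enumerate(frames):   # frame: [1, tokens_per_frame, 768]
#       frame = blk(frame, kv_cache=cache, global_idx=frame_idx)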