first commit

LinZhuoChen
2026-04-16 09:51:30 +08:00
commit f9b3ae457a
44 changed files with 11994 additions and 0 deletions

lingbot_map/layers/__init__.py

@@ -0,0 +1,5 @@
from lingbot_map.layers.mlp import Mlp
from lingbot_map.layers.patch_embed import PatchEmbed
from lingbot_map.layers.block import Block
from lingbot_map.layers.swiglu_ffn import SwiGLUFFN as SwiGLUFFNFused
from lingbot_map.layers.attention import Attention as MemEffAttention

lingbot_map/layers/attention.py

@@ -0,0 +1,766 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
import os
import math
import warnings
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from lingbot_map.layers.rope import apply_rotary_emb
from einops import rearrange
# FlashInfer imports (optional - for paged attention)
try:
import flashinfer
FLASHINFER_AVAILABLE = True
except ImportError:
FLASHINFER_AVAILABLE = False
print("flashinfer not available")
try:
from torchtitan.distributed.sequence_parallel import (
gather_seq_scatter_heads,
gather_heads_scatter_seq,
pad_tensor,
slice_input_tensor_scale_grad,
gather_outputs,
)
except ImportError:
print("torchtitan not available for ulysses cp")
def gather_seq_scatter_heads_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_dim: int, head_dim: int):
"""Gather sequence dimension and scatter head dimension for Q, K, V tensors."""
q = gather_seq_scatter_heads(q, seq_dim, head_dim)
k = gather_seq_scatter_heads(k, seq_dim, head_dim)
v = gather_seq_scatter_heads(v, seq_dim, head_dim)
return q, k, v
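# Note (illustrative, on the assumed torchtitan semantics): under Ulysses context
# parallelism with world size W, the all-to-all inside gather_seq_scatter_heads
# turns local [B, H, N/W, D] shards into [B, H/W, N, D], so each rank attends over
# the full sequence with a subset of heads; gather_heads_scatter_seq inverts this
# after attention.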
from typing import List, Optional, Tuple
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
qk_norm: bool = False,
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
rope=None,
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.fused_attn = fused_attn
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = rope
    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, num_patches=None,
                num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if enable_ulysses_cp:
q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
if self.rope is not None:
q = self.rope(q, pos)
k = self.rope(k, pos)
if self.fused_attn:
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
if enable_ulysses_cp:
x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
x = self.proj(x)
x = self.proj_drop(x)
return x
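# A minimal usage sketch of the Attention layer above (illustrative, not part of
# the original commit; assumes no RoPE and no Ulysses context parallelism):
def _attention_usage_example() -> None:
    attn = Attention(dim=768, num_heads=12)
    x = torch.randn(2, 196, 768)  # [B, N, C]
    out = attn(x)  # fused SDPA path (fused_attn=True by default)
    assert out.shape == x.shape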
class CausalAttention(nn.Module):
"""
Causal self-attention module with KV cache support for streaming inference.
    Used by CameraBlock in block.py and CasualBlockCamera in camera_head.py.
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
qk_norm: bool = False,
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
rope=None,
elementwise_attn_output_gate=False,
# KV cache eviction parameters (matching build_attn_mask)
        kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False, # If True, only cache camera token (no scale token)
) -> None:
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5
self.fused_attn = fused_attn
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = rope
self.gate_proj = nn.Linear(dim, dim, bias=True) if elementwise_attn_output_gate else None
# Store KV cache eviction parameters
self.kv_cache_sliding_window = kv_cache_sliding_window
self.kv_cache_scale_frames = kv_cache_scale_frames
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
self.kv_cache_camera_only = kv_cache_camera_only
    def forward(self, x: Tensor, block_mask=None, pos=None, pos_kv=None, frame_seqlen=None,
                video_mask=None, kv_cache=None, current_start=0, current_end=0, global_idx=0,
                num_frame_per_block=1, num_frame_for_scale=-1, enable_3d_rope=False,
                sliding_window_size=-1, attend_to_scale_frames=False, num_random_frames=0,
                attend_to_special_tokens=False, num_register_tokens=4, enable_ulysses_cp=False,
                is_scale_frames=False, full_attention=False) -> Tensor:
B, N, C = x.shape
# Calculate special token indices
camera_token_idx = 0
scale_token_idx = camera_token_idx + num_register_tokens + 1 # camera + register tokens + scale
# [3, B, num_heads, N, head_dim]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
if self.gate_proj is not None:
gate_score = self.gate_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
if kv_cache is None:
q, k = self.q_norm(q), self.k_norm(k)
if enable_ulysses_cp:
q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
N = q.shape[2] # Update N after gather
if self.rope is not None and not enable_3d_rope:
q = self.rope(q, pos)
k = self.rope(k, pos)
elif enable_3d_rope and pos is not None:
q = apply_rotary_emb(q, pos)
k = apply_rotary_emb(k, pos)
            with torch.no_grad():
                if full_attention:
                    # Fast path (camera head): skip mask construction; SDPA runs unmasked.
                    # CausalAttention previously had no full_attention flag, so the
                    # full-attention call in CameraBlock.forward would fail.
                    mask = None
                else:
                    block_mask = block_mask.squeeze()[:q.shape[2], :k.shape[2]]
                    if block_mask.dim() == 2:
                        block_mask = block_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, N, N]
                    block_mask = block_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])
                    video_mask = video_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if video_mask is not None else torch.ones_like(block_mask, device=block_mask.device)  # [1, 1, N, N]
                    video_mask = video_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])
                    mask = block_mask | ~video_mask
# Apply sliding window mask if sliding_window_size > 0
# sliding_window_size is in units of num_frame_per_block
if sliding_window_size > 0 and frame_seqlen is not None:
# Create sliding window mask: each frame can only attend to frames within the window
num_frames = N // frame_seqlen
sliding_mask = torch.zeros_like(mask, dtype=torch.bool)
for i in range(num_frames):
q_start = i * frame_seqlen
q_end = (i + 1) * frame_seqlen
# Calculate the window start: sliding_window_size is in units of num_frame_per_block
# So the actual window size in frames is sliding_window_size * num_frame_per_block
window_size_in_frames = sliding_window_size * num_frame_per_block
window_start_frame = max(0, i - window_size_in_frames + 1)
k_start = window_start_frame * frame_seqlen
k_end = (i + 1) * frame_seqlen # Can attend up to current frame (causal)
sliding_mask[:, :, q_start:q_end, k_start:k_end] = True
# Combine with existing mask: both masks need to allow attention
mask = mask & sliding_mask
# If attend_to_scale_frames is True, also allow attention to first num_frame_for_scale frames
if num_frame_for_scale > 0:
for i in range(num_frames):
q_start = i * frame_seqlen
q_end = (i + 1) * frame_seqlen
# Allow attending to first num_frame_for_scale frames (directly set to True, not depending on block_mask)
mask[:, :, q_start:q_end, :num_frame_for_scale * frame_seqlen] = True
## global attention for the first num_frame_for_scale frames
if num_frame_for_scale > 0:
mask[:, :, :num_frame_for_scale * frame_seqlen, :num_frame_for_scale * frame_seqlen] = True
if self.fused_attn:
x = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
attn_mask=mask
)
else:
# Apply RoPE to current k before caching
q, k = self.q_norm(q), self.k_norm(k)
if self.rope is not None and not enable_3d_rope:
q = self.rope(q, pos)
k = self.rope(k, pos)
elif enable_3d_rope and pos is not None:
q = apply_rotary_emb(q, pos)
k = apply_rotary_emb(k, pos)
# Check if we should skip appending to cache (non-keyframe in keyframe mode)
skip_append = kv_cache.get("_skip_append", False)
k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
if not skip_append:
# KEYFRAME: store in cache (original behavior)
if kv_cache[f"k_{global_idx}"] is None:
kv_cache[f"k_{global_idx}"] = k_reshaped
kv_cache[f"v_{global_idx}"] = v_reshaped
else:
num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
kv_cache[f"k_{global_idx}"] = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
kv_cache[f"v_{global_idx}"] = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
# Apply sliding window eviction BEFORE attention to match causal_3drope behavior
# This ensures current frame only attends to frames within the sliding window
self._apply_kv_cache_eviction_causal(kv_cache, global_idx, camera_token_idx, scale_token_idx)
# Retrieve full k, v from cache (already RoPE-applied, already evicted)
k = kv_cache[f"k_{global_idx}"].clone()
v = kv_cache[f"v_{global_idx}"].clone()
else:
# NON-KEYFRAME: attend to [cached + current] without storing in cache
if kv_cache[f"k_{global_idx}"] is not None:
k = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
v = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
else:
k = k_reshaped
v = v_reshaped
a, b, c, d, e = k.shape
k = k.reshape(a, b, c*d, e)
v = v.reshape(a, b, c*d, e)
# Prepend special tokens (camera + scale) from evicted frames if they exist
if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
special_k = kv_cache[f"k_{global_idx}_special"] # [B, H, num_evicted_frames, 2, D]
special_v = kv_cache[f"v_{global_idx}_special"]
sa, sb, sc, sd, se = special_k.shape
special_k = special_k.reshape(sa, sb, sc * sd, se) # [B, H, num_evicted*2, D]
special_v = special_v.reshape(sa, sb, sc * sd, se)
# Prepend special tokens (older frames first)
k = torch.cat([special_k, k], dim=2)
v = torch.cat([special_v, v], dim=2)
# Note: k from cache is already RoPE-applied, no need to apply again
if self.fused_attn:
# Use mask-based SDPA to ensure same kernel as batch mode
# The causal constraint is enforced by KV cache contents, not by mask
mask = torch.ones(B, 1, q.shape[2], k.shape[2], dtype=torch.bool, device=q.device)
x = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
attn_mask=mask,
)
if self.gate_proj is not None:
x = x * torch.sigmoid(gate_score)
if enable_ulysses_cp:
x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
# Use actual dimensions from attention output, not original input C
# x shape: [B, H, seq_len, head_dim] -> [B, seq_len, H*head_dim]
x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _apply_kv_cache_eviction_causal(self, kv_cache, global_idx, camera_token_idx, scale_token_idx):
"""
Apply sliding window eviction to KV cache BEFORE attention.
This ensures current frame only attends to frames within the sliding window,
matching the behavior of causal_3drope's attention mask.
"""
sliding_window_frames = self.kv_cache_sliding_window
scale_frames = self.kv_cache_scale_frames
if kv_cache[f"k_{global_idx}"].shape[3] > 1:
num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
if num_cached_frames > sliding_window_frames + scale_frames:
evict_start = scale_frames
evict_end = num_cached_frames - sliding_window_frames
if evict_end > evict_start:
evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]
if self.kv_cache_cross_frame_special:
if self.kv_cache_camera_only:
# Only keep camera token
new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
else:
# Keep ALL special tokens (camera + register + scale) to match attention_mask behavior
# Special tokens are in range [camera_token_idx, scale_token_idx+1)
new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
kv_cache[f"k_{global_idx}_special"] = new_special_k
kv_cache[f"v_{global_idx}_special"] = new_special_v
else:
kv_cache[f"k_{global_idx}_special"] = torch.cat(
[kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
kv_cache[f"v_{global_idx}_special"] = torch.cat(
[kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)
if self.kv_cache_include_scale_frames:
kv_cache[f"k_{global_idx}"] = torch.cat([
kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
], dim=2)
kv_cache[f"v_{global_idx}"] = torch.cat([
kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
], dim=2)
else:
kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
class FlashInferAttention(Attention):
"""
FlashInfer variant of the GCT attention layer.
Uses FlashInferKVCacheManager for paged KV cache storage and
FlashInfer attention kernels (BatchPrefillWithPagedKVCacheWrapper).
Supports the same optimized token layout and KV cache streaming inference.
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
qk_norm: bool = False,
fused_attn: bool = True,
rope=None,
# KV cache eviction parameters
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
) -> None:
if not FLASHINFER_AVAILABLE:
raise RuntimeError("FlashInfer is not available. Please install flashinfer.")
super().__init__(
dim=dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
qk_norm=qk_norm,
fused_attn=fused_attn,
rope=rope,
)
# Store KV cache eviction parameters
self.kv_cache_sliding_window = kv_cache_sliding_window
self.kv_cache_scale_frames = kv_cache_scale_frames
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
self.kv_cache_camera_only = kv_cache_camera_only
def prepare_qkv(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
"""Fused pre-attention ops for single-frame streaming (Phase 2).
Computes q/k/v from x, applies q_norm/k_norm/RoPE, and converts to
[tpf, H, D] format ready for append_frame + compute_attention.
Extracted as a method so torch.compile can capture all pre-attn ops as one
CUDA graph (qkv linear -> reshape -> unbind -> q_norm -> k_norm -> RoPE x2 ->
squeeze/permute/contiguous x3).
"""
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim]
q, k = self.q_norm(q), self.k_norm(k)
if self.rope is not None and not enable_3d_rope:
q = self.rope(q, pos)
k = self.rope(k, pos)
elif self.rope is not None: # enable_3d_rope=True
q = apply_rotary_emb(q, pos)
k = apply_rotary_emb(k, pos)
# Convert to [tpf, H, D] format for FlashInfer (B=1 in streaming mode)
q_nhd = q.squeeze(0).permute(1, 0, 2).contiguous()
k_nhd = k.squeeze(0).permute(1, 0, 2).contiguous()
v_nhd = v.squeeze(0).permute(1, 0, 2).contiguous()
return q_nhd, k_nhd, v_nhd
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
# KV cache parameters (kv_cache is a FlashInferKVCacheManager or None)
kv_cache=None, global_idx=0, num_frame_per_block=1,
num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
"""
Forward pass with FlashInfer paged KV cache and attention.
Args:
x: Input tensor [B, N, C]
kv_cache: FlashInferKVCacheManager instance or None (batch mode)
global_idx: Block index for per-block cache access
"""
from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager
B, N, C = x.shape
# Detect if using optimized layout
using_optimized_layout = (num_patches is not None and num_special is not None
and num_frames is not None)
# ========== Batch Mode (no KV cache manager) ==========
if not isinstance(kv_cache, FlashInferKVCacheManager):
# [3, B, num_heads, N, head_dim]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim]
q, k = self.q_norm(q), self.k_norm(k)
if enable_ulysses_cp:
if using_optimized_layout:
boundary = num_frames * num_patches
q_patch, k_patch, v_patch = q[:, :, :boundary, :], k[:, :, :boundary, :], v[:, :, :boundary, :]
q_special, k_special, v_special = q[:, :, boundary:, :], k[:, :, boundary:, :], v[:, :, boundary:, :]
q_patch, k_patch, v_patch = gather_seq_scatter_heads_qkv(
q_patch, k_patch, v_patch, seq_dim=2, head_dim=1
)
q_special, k_special, v_special = gather_seq_scatter_heads_qkv(
q_special, k_special, v_special, seq_dim=2, head_dim=1
)
q = torch.cat([q_patch, q_special], dim=2)
k = torch.cat([k_patch, k_special], dim=2)
v = torch.cat([v_patch, v_special], dim=2)
else:
q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
if self.rope is not None and not enable_3d_rope:
q = self.rope(q, pos)
k = self.rope(k, pos)
elif self.rope is not None and enable_3d_rope:
q = apply_rotary_emb(q, pos)
k = apply_rotary_emb(k, pos)
# Batch mode: use SDPA for numerical consistency with SDPA variant
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
if enable_ulysses_cp:
if using_optimized_layout:
seq_global = x.shape[2]
seq_local = num_frames * (num_patches + num_special)
cp_size = seq_global // seq_local
boundary_global = num_frames * cp_size * num_patches
x_patch = x[:, :, :boundary_global, :]
x_special = x[:, :, boundary_global:, :]
x_patch = gather_heads_scatter_seq(x_patch, seq_dim=2, head_dim=1)
x_special = gather_heads_scatter_seq(x_special, seq_dim=2, head_dim=1)
x = torch.cat([x_patch, x_special], dim=2)
else:
x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
# ========== Streaming Mode (with FlashInferKVCacheManager) ==========
else:
manager = kv_cache # FlashInferKVCacheManager
# Phase 1 (scale frames): num_frames > 1 — multi-frame batch
# Phase 2 (streaming): num_frames == 1 — single frame
is_multi_frame = (num_frames is not None and num_frames > 1)
if is_multi_frame:
# Phase 1: compute full self-attention via SDPA (all frames attend to each other),
# then append each frame's K/V to the paged cache one at a time.
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
# Apply RoPE before caching (RoPE baked into K before append)
if self.rope is not None and not enable_3d_rope:
q = self.rope(q, pos)
k = self.rope(k, pos)
elif self.rope is not None and enable_3d_rope:
q = apply_rotary_emb(q, pos)
k = apply_rotary_emb(k, pos)
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
# Append each frame's K/V to the paged cache individually.
tpf = manager.tokens_per_frame
k_all = k.squeeze(0).permute(1, 0, 2) # [num_frames*tpf, H, D]
v_all = v.squeeze(0).permute(1, 0, 2)
for f_idx in range(num_frames):
s = f_idx * tpf
manager.append_frame(global_idx, k_all[s:s+tpf].contiguous(), v_all[s:s+tpf].contiguous())
manager.evict_frames(
block_idx=global_idx,
scale_frames=self.kv_cache_scale_frames,
sliding_window=self.kv_cache_sliding_window,
cross_frame_special=self.kv_cache_cross_frame_special,
include_scale_frames=self.kv_cache_include_scale_frames,
camera_only=self.kv_cache_camera_only,
num_register_tokens=num_register_tokens,
)
else:
# Phase 2: single-frame streaming via FlashInfer paged attention.
q_nhd, k_nhd, v_nhd = self.prepare_qkv(x, pos=pos, enable_3d_rope=enable_3d_rope)
# 1. Append to paged cache
manager.append_frame(global_idx, k_nhd, v_nhd)
# 2. Apply sliding window eviction
manager.evict_frames(
block_idx=global_idx,
scale_frames=self.kv_cache_scale_frames,
sliding_window=self.kv_cache_sliding_window,
cross_frame_special=self.kv_cache_cross_frame_special,
include_scale_frames=self.kv_cache_include_scale_frames,
camera_only=self.kv_cache_camera_only,
num_register_tokens=num_register_tokens,
)
# 3. Compute attention via FlashInfer BatchPrefillWithPagedKVCacheWrapper
x = manager.compute_attention(global_idx, q_nhd)
# Convert back: [tpf, H, D] -> [B, tpf, C].
x = x.reshape(B, q_nhd.shape[0], self.num_heads * self.head_dim)
x = self.proj(x)
x = self.proj_drop(x)
return x
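# A per-frame streaming sketch of the Phase 2 protocol above (illustrative, not part
# of the original commit; assumes a configured FlashInferKVCacheManager, block
# index 0, and the remaining evict_frames arguments left at their defaults):
def _flashinfer_streaming_step(attn: "FlashInferAttention", manager, x: Tensor, pos=None) -> Tensor:
    q_nhd, k_nhd, v_nhd = attn.prepare_qkv(x, pos=pos)  # each [tpf, H, D]; RoPE baked into k
    manager.append_frame(0, k_nhd, v_nhd)               # 1. write this frame's K/V
    manager.evict_frames(block_idx=0,                   # 2. recycle old window pages
                         scale_frames=attn.kv_cache_scale_frames,
                         sliding_window=attn.kv_cache_sliding_window)
    out = manager.compute_attention(0, q_nhd)           # 3. paged attention -> [tpf, H, D]
    return attn.proj(out.reshape(1, q_nhd.shape[0], attn.num_heads * attn.head_dim))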
class SDPAAttention(Attention):
"""
SDPA variant for streaming inference.
Uses F.scaled_dot_product_attention with dict-based KV cache.
No FlashInfer dependency required — works on any CUDA GPU.
"""
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = True,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: nn.Module = nn.LayerNorm,
qk_norm: bool = False,
fused_attn: bool = True,
rope=None,
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
) -> None:
super().__init__(
dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer,
qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
)
self.kv_cache_sliding_window = kv_cache_sliding_window
self.kv_cache_scale_frames = kv_cache_scale_frames
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
self.kv_cache_camera_only = kv_cache_camera_only
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
kv_cache=None, global_idx=0, num_frame_per_block=1,
num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
B, N, C = x.shape
using_optimized_layout = (num_patches is not None and num_special is not None
and num_frames is not None)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
# ========== Batch Mode (no KV cache) ==========
if kv_cache is None:
if self.rope is not None and not enable_3d_rope:
q = self.rope(q, pos)
k = self.rope(k, pos)
elif self.rope is not None and enable_3d_rope:
q = apply_rotary_emb(q, pos)
k = apply_rotary_emb(k, pos)
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
# ========== Streaming Mode (with KV cache dict) ==========
else:
if self.rope is not None and not enable_3d_rope:
q = self.rope(q, pos)
k = self.rope(k, pos)
elif self.rope is not None and enable_3d_rope:
q = apply_rotary_emb(q, pos)
k = apply_rotary_emb(k, pos)
camera_token_idx = 0
scale_token_idx = camera_token_idx + num_register_tokens + 1
if kv_cache[f"k_{global_idx}"] is None:
kv_cache[f"k_{global_idx}"] = k.view(B, self.num_heads, num_frame_per_block,
N // num_frame_per_block, self.head_dim)
kv_cache[f"v_{global_idx}"] = v.view(B, self.num_heads, num_frame_per_block,
N // num_frame_per_block, self.head_dim)
else:
num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
kv_cache[f"k_{global_idx}"] = torch.cat((
kv_cache[f"k_{global_idx}"],
k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
), dim=2)
kv_cache[f"v_{global_idx}"] = torch.cat((
kv_cache[f"v_{global_idx}"],
v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
), dim=2)
self._apply_kv_cache_eviction(
kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens
)
k_cached = kv_cache[f"k_{global_idx}"].clone()
v_cached = kv_cache[f"v_{global_idx}"].clone()
a, b, c, d, e = k_cached.shape
k_full = k_cached.reshape(a, b, c * d, e)
v_full = v_cached.reshape(a, b, c * d, e)
if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
special_k = kv_cache[f"k_{global_idx}_special"]
special_v = kv_cache[f"v_{global_idx}_special"]
sa, sb, sc, sd, se = special_k.shape
k_full = torch.cat([special_k.reshape(sa, sb, sc * sd, se), k_full], dim=2)
v_full = torch.cat([special_v.reshape(sa, sb, sc * sd, se), v_full], dim=2)
q_seq_len = q.shape[2]
x = F.scaled_dot_product_attention(
q, k_full, v_full,
dropout_p=self.attn_drop.p if self.training else 0.0,
)
x = x.transpose(1, 2).reshape(B, q_seq_len, self.num_heads * self.head_dim)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _apply_kv_cache_eviction(self, kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens):
"""Apply sliding window eviction to KV cache."""
sliding_window_frames = self.kv_cache_sliding_window
scale_frames = self.kv_cache_scale_frames
if kv_cache[f"k_{global_idx}"].shape[3] > 1:
num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
if num_cached_frames > sliding_window_frames + scale_frames:
evict_start = scale_frames
evict_end = num_cached_frames - sliding_window_frames
if evict_end > evict_start:
evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]
if self.kv_cache_cross_frame_special:
if self.kv_cache_camera_only:
new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
else:
new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
kv_cache[f"k_{global_idx}_special"] = new_special_k
kv_cache[f"v_{global_idx}_special"] = new_special_v
else:
kv_cache[f"k_{global_idx}_special"] = torch.cat(
[kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
kv_cache[f"v_{global_idx}_special"] = torch.cat(
[kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)
if self.kv_cache_include_scale_frames:
kv_cache[f"k_{global_idx}"] = torch.cat([
kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
], dim=2)
kv_cache[f"v_{global_idx}"] = torch.cat([
kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
], dim=2)
else:
kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]

lingbot_map/layers/block.py

@@ -0,0 +1,514 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings
import math
import torch
from torch import nn, Tensor
from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention
from functools import lru_cache, partial
from torch.nn.attention.flex_attention import BlockMask, create_mask
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
rope=None,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
qk_norm=qk_norm,
fused_attn=fused_attn,
rope=rope,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
num_patches=num_patches, num_special=num_special, num_frames=num_frames,
enable_3d_rope=enable_3d_rope))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
)
x = drop_add_residual_stochastic_depth(
x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
else:
x = x + attn_residual_func(x, pos=pos)
x = x + ffn_residual_func(x)
return x
def drop_add_residual_stochastic_depth(
x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
) -> Tensor:
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
if pos is not None:
# if necessary, apply rope to the subset
pos = pos[brange]
residual = residual_func(x_subset, pos=pos)
else:
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
return x_plus_residual.view_as(x)
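# A worked example of the rescaling above (illustrative, not part of the original
# commit): with b=8 samples and sample_drop_ratio=0.25, the residual branch runs on
# a random subset of 6 samples, and index_add scales it by b / subset = 8/6 so the
# residual keeps its expected magnitude.
def _stochastic_depth_scale_example() -> None:
    x = torch.zeros(8, 4, 16)
    out = drop_add_residual_stochastic_depth(
        x, residual_func=lambda t: torch.ones_like(t), sample_drop_ratio=0.25)
    # 6 of 8 rows receive a residual of 8/6; the other rows stay zero
    assert torch.isclose(out.sum(), torch.tensor(6 * 4 * 16 * (8 / 6)))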
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
else:
x_plus_residual = scaled_index_add(
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
)
return x_plus_residual
class FlashInferBlock(nn.Module):
"""
FlashInfer variant of causal block for GCT.
Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels).
Supports optimized token layout and KV cache streaming inference.
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
rope=None,
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = FlashInferAttention(
dim=dim,
num_heads=num_heads,
qk_norm=qk_norm,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
rope=rope,
kv_cache_sliding_window=kv_cache_sliding_window,
kv_cache_scale_frames=kv_cache_scale_frames,
kv_cache_cross_frame_special=kv_cache_cross_frame_special,
kv_cache_include_scale_frames=kv_cache_include_scale_frames,
kv_cache_camera_only=kv_cache_camera_only,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
"""Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit.
Extracted as a named method so torch.compile can capture norm1 + qkv-linear +
reshape + q_norm + k_norm + RoPE + format as a single CUDA graph.
Returns:
(q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim],
ready for manager.append_frame + manager.compute_attention.
"""
return self.attn.prepare_qkv(self.norm1(x), pos=pos, enable_3d_rope=enable_3d_rope)
def forward(
self,
x: Tensor,
pos=None,
enable_ulysses_cp=False,
num_patches=None,
num_special=None,
num_frames=None,
enable_3d_rope=False,
kv_cache=None,
global_idx=0,
num_frame_per_block=1,
num_frame_for_scale=-1,
num_register_tokens=4,
) -> Tensor:
# Phase 2 (streaming): single-frame FlashInfer paged attention.
# Handle inline so attn_pre (norm1+prepare_qkv) can be compiled as one CUDA graph.
is_streaming = (kv_cache is not None and (num_frames is None or num_frames <= 1))
if is_streaming:
manager = kv_cache
# Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format
q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope)
# Eager: write frame K/V to paged cache
manager.append_frame(global_idx, k_nhd, v_nhd)
# CPU-only: update eviction state (deque ops, no GPU kernel)
manager.evict_frames(
block_idx=global_idx,
scale_frames=self.attn.kv_cache_scale_frames,
sliding_window=self.attn.kv_cache_sliding_window,
cross_frame_special=self.attn.kv_cache_cross_frame_special,
include_scale_frames=self.attn.kv_cache_include_scale_frames,
camera_only=self.attn.kv_cache_camera_only,
num_register_tokens=num_register_tokens,
)
# Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper
attn_x = manager.compute_attention(global_idx, q_nhd)
# [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output)
attn_x = attn_x.reshape(x.shape[0], q_nhd.shape[0],
self.attn.num_heads * self.attn.head_dim)
# Compiled: output projection
attn_x = self.attn.proj(attn_x)
x = x + self.ls1(attn_x)
else:
# Phase 1 (multi-frame scale pass) or non-streaming training path
x = x + self.ls1(self.attn(
self.norm1(x),
pos=pos,
enable_ulysses_cp=enable_ulysses_cp,
num_patches=num_patches,
num_special=num_special,
num_frames=num_frames,
enable_3d_rope=enable_3d_rope,
kv_cache=kv_cache,
global_idx=global_idx,
num_frame_per_block=num_frame_per_block,
num_frame_for_scale=num_frame_for_scale,
num_register_tokens=num_register_tokens,
))
x = self.ffn_residual(x)
return x
def ffn_residual(self, x: Tensor) -> Tensor:
"""FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in.
Includes the residual add (x + ...) so torch.compile captures the entire
ffn branch as one CUDA graph.
"""
return x + self.ls2(self.mlp(self.norm2(x)))
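# A sketch of how the compile-friendly split above can be used (illustrative; the
# torch.compile wrapping here is an assumption, not wiring from the original commit):
def _compile_streaming_block(block: "FlashInferBlock") -> None:
    # attn_pre (norm1 + QKV + RoPE + layout) and ffn_residual (norm2 + MLP + add)
    # are pure tensor ops, so each compiles into a single graph, while the
    # paged-cache calls inside forward() stay eager.
    block.attn_pre = torch.compile(block.attn_pre)
    block.ffn_residual = torch.compile(block.ffn_residual)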
class CameraBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
rope=None,
elementwise_attn_output_gate: bool = False,
sliding_window_size: int = -1,
attend_to_scale_frames: bool = False,
num_random_frames: int = 0,
# KV cache parameters
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = CausalAttention(dim=dim, num_heads=num_heads,
qk_norm=qk_norm, qkv_bias=qkv_bias,
rope=rope, elementwise_attn_output_gate=elementwise_attn_output_gate,
kv_cache_sliding_window=kv_cache_sliding_window,
kv_cache_scale_frames=kv_cache_scale_frames,
kv_cache_cross_frame_special=kv_cache_cross_frame_special,
kv_cache_include_scale_frames=kv_cache_include_scale_frames,
kv_cache_camera_only=kv_cache_camera_only)
self.sliding_window_size = sliding_window_size
self.attend_to_scale_frames = attend_to_scale_frames
self.num_random_frames = num_random_frames
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
self.masks = {}
@torch.no_grad()
def _prepare_blockwise_causal_attn_mask(self,
device: torch.device | str, num_frames: int = 21,
frame_seqlen: int = 1560, num_frame_per_block=1
    ) -> Tensor:
        """
        Divide the token sequence into blocks of `num_frame_per_block` latent frames:
        [frame block] [frame block] ... [frame block]
        and use flex_attention's create_mask to build a block-wise causal mask.
        Note: create_mask returns a dense boolean tensor, not a BlockMask.
        """
total_length = num_frames * frame_seqlen
# we do right padding to get to a multiple of 128
padded_length = math.ceil(total_length / 128) * 128 - total_length
ends = torch.zeros(total_length + padded_length,
device=device, dtype=torch.long)
# Block-wise causal mask will attend to all elements that are before the end of the current chunk
frame_indices = torch.arange(
start=0,
end=total_length,
step=frame_seqlen * num_frame_per_block,
device=device
)
for tmp in frame_indices:
ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
frame_seqlen * num_frame_per_block
def attention_mask(b, h, q_idx, kv_idx):
return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
# return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
block_mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
KV_LEN=total_length + padded_length, device=device)
return block_mask
    def forward(self, x: Tensor, pos=None, video_mask=None, num_frames=0, frame_seqlen=0,
                kv_cache=None, current_start=0, current_end=0, global_idx=0,
                num_frame_per_block=8, num_frame_for_scale=-1, sliding_window_size=None,
                enable_ulysses_cp=False, full_attention=False, enable_3d_rope=False,
                is_scale_frames=False) -> Tensor:
# Use passed sliding_window_size if provided, otherwise use self.sliding_window_size
effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
# Fast path for full attention (camera head) - skip mask computation
if full_attention:
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), pos=pos, full_attention=True, enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
x = x + self.drop_path1(ffn_residual_func(x))
else:
x = x + attn_residual_func(x, pos=pos)
x = x + ffn_residual_func(x)
return x
mask_block = self._prepare_blockwise_causal_attn_mask(
device=x.device, num_frames=num_frames, frame_seqlen=frame_seqlen, num_frame_per_block=num_frame_per_block)
def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(
                self.norm1(x), pos=pos, block_mask=mask_block, frame_seqlen=frame_seqlen,
                video_mask=video_mask, current_start=current_start, current_end=current_end,
                kv_cache=kv_cache, global_idx=global_idx,
                num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale,
                sliding_window_size=effective_sliding_window_size,
                attend_to_scale_frames=self.attend_to_scale_frames,
                num_random_frames=self.num_random_frames,
                enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope,
                is_scale_frames=is_scale_frames))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
else:
x = x + attn_residual_func(x, pos=pos)
x = x + ffn_residual_func(x)
return x
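# A tiny worked example of the block-wise causal mask above (illustrative, not part
# of the original commit): with 4 frames of 2 tokens and num_frame_per_block=2,
# tokens of the first chunk attend only to indices 0..3, tokens of the second chunk
# to indices 0..7.
def _blockwise_mask_example() -> None:
    blk = CameraBlock(dim=16, num_heads=2)
    mask = blk._prepare_blockwise_causal_attn_mask(
        device="cpu", num_frames=4, frame_seqlen=2, num_frame_per_block=2)
    assert mask[0, 0, 0, :8].tolist() == [True] * 4 + [False] * 4
    assert mask[0, 0, 4, :8].tolist() == [True] * 8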
class SDPABlock(nn.Module):
"""
SDPA variant for streaming inference. Uses F.scaled_dot_product_attention
with dict-based KV cache. No FlashInfer dependency required.
"""
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
ffn_layer: Callable[..., nn.Module] = Mlp,
qk_norm: bool = False,
rope=None,
kv_cache_sliding_window: int = 64,
kv_cache_scale_frames: int = 8,
kv_cache_cross_frame_special: bool = True,
kv_cache_include_scale_frames: bool = True,
kv_cache_camera_only: bool = False,
) -> None:
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = SDPAAttention(
dim=dim, num_heads=num_heads, qk_norm=qk_norm, qkv_bias=qkv_bias,
proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, rope=rope,
kv_cache_sliding_window=kv_cache_sliding_window,
kv_cache_scale_frames=kv_cache_scale_frames,
kv_cache_cross_frame_special=kv_cache_cross_frame_special,
kv_cache_include_scale_frames=kv_cache_include_scale_frames,
kv_cache_camera_only=kv_cache_camera_only,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = ffn_layer(in_features=dim, hidden_features=int(dim * mlp_ratio),
act_layer=act_layer, drop=drop, bias=ffn_bias)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
kv_cache=None, global_idx=0, num_frame_per_block=1,
num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
def attn_residual_func(x, pos=None):
return self.ls1(self.attn(
self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
num_patches=num_patches, num_special=num_special, num_frames=num_frames,
enable_3d_rope=enable_3d_rope, kv_cache=kv_cache, global_idx=global_idx,
num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale,
num_register_tokens=num_register_tokens,
))
def ffn_residual_func(x):
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x, pos=pos))
x = x + self.drop_path1(ffn_residual_func(x))
else:
x = x + attn_residual_func(x, pos=pos)
x = x + ffn_residual_func(x)
return x
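# A streaming-driver sketch for SDPABlock (illustrative, not part of the original
# commit; assumes one cache entry per block index, initialized to None as
# SDPAAttention expects on the first frame):
def _sdpa_streaming_loop(blocks: List["SDPABlock"], frames: List[Tensor]) -> List[Tensor]:
    kv_cache = {}
    for i in range(len(blocks)):
        kv_cache[f"k_{i}"] = None
        kv_cache[f"v_{i}"] = None
    outs = []
    for x in frames:  # each x: [1, tokens_per_frame, C]
        for i, blk in enumerate(blocks):
            x = blk(x, kv_cache=kv_cache, global_idx=i, num_frame_per_block=1)
        outs.append(x)
    return outs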

lingbot_map/layers/drop_path.py

@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
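# A quick check of the expectation-preserving rescaling above (illustrative, not
# part of the original commit): with drop_prob=0.25, surviving samples are scaled
# by 1/0.75, so the output mean stays ~1 for an all-ones input.
def _drop_path_expectation_example() -> None:
    import torch  # this module only imports nn at the top
    x = torch.ones(10000, 1)
    out = drop_path(x, drop_prob=0.25, training=True)
    assert abs(out.mean().item() - 1.0) < 0.05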

lingbot_map/layers/flashinfer_cache.py

@@ -0,0 +1,582 @@
"""
FlashInfer KV Cache Manager — Two-Stream Paged Design.
Two logical streams sharing one physical page pool per layer:
Patch stream (recyclable):
- page_size = patches_per_frame (256 for 224×224; 972 for 504×378)
- Exactly 1 patch page per frame
- Scale frames → scale_patch_pages (never evicted, maxlen=scale_frames)
- Recent frames → live_window_patch_pages (evicted when > sliding_window)
Special stream (append-only, never recycled):
- num_special_tokens (6) special tokens per frame
- Packed contiguously: one special page holds floor(page_size/6) frames,
e.g. page_size=256 → 42 frames per special page (4 slots wasted)
- Specials written for EVERY frame (including scale + window), not just evicted ones.
Physical layout per block:
kv_caches[block_idx]: [max_num_pages, 2, page_size, H, D]
Pages 0 .. max_patch_pages-1 : patch page pool (recyclable)
Pages max_patch_pages .. max_pages-1: special page pool (append-only)
dim 1: 0=K 1=V
Attention computation:
visible = scale_patch_pages + live_window_patch_pages + all_special_pages
Special pages placed LAST → paged_kv_last_page_len naturally describes
the partial special-tail without a custom mask.
plan() is called ONCE per frame step (when block_idx == 0).
run() is called per layer, reusing the same plan. All layers at the
same frame step have identical page structures (same page IDs in same
positions), so reusing the plan across layers is correct.
Public API is drop-in compatible with the previous FlashInferKVCacheManager:
append_frame(block_idx, k, v)
evict_frames(block_idx, scale_frames, sliding_window, ...)
compute_attention(block_idx, q) -> out
reset()
"""
import collections
import math
from typing import List
import torch
from torch import Tensor
try:
import flashinfer
FLASHINFER_AVAILABLE = True
except ImportError:
FLASHINFER_AVAILABLE = False
class FlashInferKVCacheManager:
"""
Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only).
Args:
num_blocks: Number of Transformer blocks (one cache per block).
max_num_frames: Maximum frames held in the KV window at once
(scale_frames + sliding_window + headroom).
tokens_per_frame: Total tokens per frame = patches + specials (e.g. 262).
num_heads: Number of KV heads (= QO heads; MHA assumed).
head_dim: Head dimension (64 for ViT-L).
dtype: Storage dtype (bfloat16 / float16).
device: CUDA device.
num_special_tokens: Special tokens per frame: camera + register×N + scale (6).
scale_frames: Number of always-resident scale frames (8).
sliding_window: Sliding window size (64).
max_total_frames: Upper bound on total frames ever processed; used to
pre-allocate the special page pool (default 2048).
"""
def __init__(
self,
num_blocks: int,
max_num_frames: int,
tokens_per_frame: int,
num_heads: int,
head_dim: int,
dtype: torch.dtype,
device: torch.device,
num_special_tokens: int = 6,
scale_frames: int = 8,
sliding_window: int = 64,
max_total_frames: int = 2048,
force_fp32: bool = False,
fa3: bool = False,
):
if not FLASHINFER_AVAILABLE:
raise RuntimeError("FlashInfer is not available. Please install flashinfer.")
self.num_blocks = num_blocks
self.num_special_tokens = num_special_tokens # 6
self.patches_per_frame = tokens_per_frame - num_special_tokens # 256 / 999 / ...
# Use exact page_size = patches_per_frame to eliminate zero-padded slots.
# FA2 (backend="fa2") supports non-power-of-2 page sizes.
# FA3 (sm90) requires power-of-2 page sizes; use next_power_of_2 when fa3=True.
p = self.patches_per_frame
if fa3:
# Round up to next power-of-2 for FA3 SM90 kernel requirement.
# e.g. 999 → 1024 (25 zero-padded slots per patch page)
self.page_size = 1 << (p - 1).bit_length()
else:
self.page_size = p # exact: no zero padding in patch pages
self.scale_frames = scale_frames # 8
self.sliding_window = sliding_window # 64
self.num_heads = num_heads
self.head_dim = head_dim
self.tokens_per_frame = tokens_per_frame
assert self.patches_per_frame > 0, (
f"tokens_per_frame={tokens_per_frame} <= num_special_tokens={num_special_tokens}"
)
assert self.page_size > 0
# force_fp32: bypass FlashInfer FA2 kernel (which only supports fp16/bf16) and
# instead gather paged K/V into a dense tensor and use F.scaled_dot_product_attention
# in fp32 for accuracy comparison. Storage dtype is also kept as fp32 in this mode.
self.force_fp32 = force_fp32
if force_fp32:
self.dtype = torch.float32
else:
if dtype == torch.float32:
dtype = torch.bfloat16
self.dtype = dtype
self.device = device
# ── Page pool sizing ─────────────────────────────────────────────────
# Patch: scale + window + 16 headroom (pages recycled → fixed count)
max_patch_pages = scale_frames + sliding_window + 16 # e.g. 88
# Special: enough for max_total_frames × 6 tokens, plus 16 headroom
max_special_pages = (
math.ceil(max_total_frames * num_special_tokens / self.page_size) + 16
)
self.max_patch_pages = max_patch_pages
self.max_num_pages = max_patch_pages + max_special_pages
# ── Physical paged KV caches ─────────────────────────────────────────
# Shape per block: [max_num_pages, 2, page_size, H, D] (NHD, K=dim0, V=dim1)
        self.kv_caches: List[Tensor] = [
            torch.zeros(
                self.max_num_pages, 2, self.page_size, num_heads, head_dim,
                dtype=self.dtype, device=device,  # self.dtype, so force_fp32 really stores fp32
            )
            for _ in range(num_blocks)
        ]
# ── Per-block state ──────────────────────────────────────────────────
# Patch pages (IDs 0 .. max_patch_pages-1)
self.scale_patch_pages: List[collections.deque] = [
collections.deque() for _ in range(num_blocks)
]
self.live_window_patch_pages: List[collections.deque] = [
collections.deque() for _ in range(num_blocks)
]
self.free_patch_pages: List[List[int]] = [
list(range(max_patch_pages)) for _ in range(num_blocks)
]
# Special pages (IDs max_patch_pages .. max_num_pages-1)
self.all_special_pages: List[List[int]] = [[] for _ in range(num_blocks)]
self.free_special_pages: List[List[int]] = [
list(range(max_patch_pages, self.max_num_pages)) for _ in range(num_blocks)
]
self.special_token_count: List[int] = [0] * num_blocks
# Frame counter per block (determines scale vs window routing)
self.frame_count: List[int] = [0] * num_blocks
# ── FlashInfer wrapper ───────────────────────────────────────────────
# plan() is called once per frame step (block_idx == 0).
# run() is called per layer, reusing the same aux structures.
# backend: "fa2" (default) or "fa3" (SM90/H100, requires power-of-2 page_size).
# FA2 supports non-power-of-2 page sizes and avoids a FA3 NaN bug seen in
# FlashInfer 0.2.5 at 518×378 resolution.
_fi_backend = "fa3" if fa3 else "fa2"
self.workspace_buffer = torch.zeros(
128 * 1024 * 1024, dtype=torch.uint8, device=device
)
self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
kv_layout="NHD",
backend=_fi_backend,
)
# plan() inputs (indices/indptr built fresh each step; qo_indptr is fixed)
self._qo_indptr = torch.tensor(
[0, tokens_per_frame], dtype=torch.int32, device=device
)
# =========================================================================
# Public API (drop-in compatible with previous FlashInferKVCacheManager)
# =========================================================================
def append_frame(self, block_idx: int, k: Tensor, v: Tensor) -> None:
"""
Append one frame's K/V tensors to the two-stream cache.
Token layout must be: [camera, reg0, ..., regN, scale, patch0, ..., patchP-1]
i.e. specials come first (matching stream.py's patch_start_idx convention).
Args:
block_idx: Block/layer index (0 … num_blocks-1).
k: [tokens_per_frame, H, D] NHD layout.
v: [tokens_per_frame, H, D] NHD layout.
"""
n = self.num_special_tokens # 6
sp_k = k[:n].to(self.dtype) # [6, H, D]
patch_k = k[n:].to(self.dtype) # [256, H, D]
sp_v = v[:n].to(self.dtype)
patch_v = v[n:].to(self.dtype)
assert patch_k.shape[0] == self.patches_per_frame, (
f"block {block_idx}: expected {self.patches_per_frame} patch tokens, "
f"got {patch_k.shape[0]} (tokens_per_frame={k.shape[0]})"
)
self._write_patch_page(block_idx, patch_k, patch_v)
self._write_special_tokens(block_idx, sp_k, sp_v)
self.frame_count[block_idx] += 1
def evict_frames(
self,
block_idx: int,
scale_frames: int,
sliding_window: int,
cross_frame_special: bool = True,
include_scale_frames: bool = True,
camera_only: bool = False,
num_register_tokens: int = 4,
) -> None:
"""
Evict old window patch pages (recycle to free list).
Special pages are NEVER evicted.
Scale pages are NEVER evicted.
Only live_window_patch_pages beyond `sliding_window` are recycled.
"""
while len(self.live_window_patch_pages[block_idx]) > sliding_window:
old_page = self.live_window_patch_pages[block_idx].popleft()
self.free_patch_pages[block_idx].append(old_page)
def _gather_kv(self, block_idx: int):
"""
Gather all visible K and V tokens from the paged cache into dense tensors.
Used by force_fp32 mode to bypass the FlashInfer FA2 kernel (which only
supports fp16/bf16) and instead run F.scaled_dot_product_attention in fp32.
Returns:
k_flat: [kv_len, H, D] — all visible K tokens concatenated
v_flat: [kv_len, H, D] — all visible V tokens concatenated
"""
visible = self.build_visible_page_table(block_idx)
last_len = self.compute_last_page_len(block_idx)
P = self.page_size
parts_k, parts_v = [], []
for i, pid in enumerate(visible):
n = last_len if (i == len(visible) - 1) else P
parts_k.append(self.kv_caches[block_idx][pid, 0, :n]) # [n, H, D]
parts_v.append(self.kv_caches[block_idx][pid, 1, :n])
k_flat = torch.cat(parts_k, dim=0) # [kv_len, H, D]
v_flat = torch.cat(parts_v, dim=0)
return k_flat, v_flat
def compute_attention(self, block_idx: int, q: Tensor) -> Tensor:
"""
Compute cross-frame attention using FlashInfer BatchPrefillWithPagedKVCacheWrapper.
When self.force_fp32 is True, gathers all visible K/V into dense tensors
and uses F.scaled_dot_product_attention in fp32 instead of the FA2 kernel.
This is used for accuracy comparison since FlashInfer FA2 only supports fp16/bf16.
plan() is called once per frame step (when block_idx == 0).
All layers at the same step share the same visible page structure,
so the plan is reused by calling run() with each layer's kv_cache.
Args:
block_idx: Block/layer index.
q: [q_len, H, D] NHD layout (q_len = tokens_per_frame = 262).
Returns:
out: [q_len, H, D]
"""
if self.frame_count[block_idx] == 0:
# No KV present yet (should not occur in normal usage after append_frame)
return torch.zeros_like(q)
if self.force_fp32:
# ── fp32 gather+SDPA path ─────────────────────────────────────────
# Gather visible K/V from paged cache and run SDPA in fp32.
# This bypasses the FlashInfer FA2 kernel (fp16/bf16 only) for accuracy.
# q_len, H, D → 1, H, q_len, D (SDPA expects BHsD layout)
import torch.nn.functional as F_nn
k_flat, v_flat = self._gather_kv(block_idx)
q_b = q.float().permute(1, 0, 2).unsqueeze(0) # [1, H, q_len, D]
k_b = k_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D]
v_b = v_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D]
out = F_nn.scaled_dot_product_attention(q_b, k_b, v_b)
return out.squeeze(0).permute(1, 0, 2).to(q.dtype) # [q_len, H, D]
if block_idx == 0:
# ── Plan once per frame step ──────────────────────────────────────
# Build visible page table from block 0's state.
# All blocks have identical page structures, so this plan is valid
# for all subsequent run() calls (block_idx = 1, 2, ...).
visible = self.build_visible_page_table(0)
last_len = self.compute_last_page_len(0)
assert visible, "visible page table is empty after append_frame"
assert 1 <= last_len <= self.page_size, (
f"block 0: last_page_len={last_len} out of [1, {self.page_size}]"
)
paged_kv_indices = torch.tensor(visible, dtype=torch.int32, device=self.device)
paged_kv_indptr = torch.tensor([0, len(visible)], dtype=torch.int32, device=self.device)
paged_kv_last_page_len = torch.tensor([last_len], dtype=torch.int32, device=self.device)
self.prefill_wrapper.plan(
self._qo_indptr,
paged_kv_indptr,
paged_kv_indices,
paged_kv_last_page_len,
num_qo_heads = self.num_heads,
num_kv_heads = self.num_heads,
head_dim_qk = self.head_dim,
page_size = self.page_size,
causal = False, # custom page ordering; no causal mask
pos_encoding_mode = "NONE", # RoPE applied externally before append
q_data_type = self.dtype,
)
# ── Run attention for this layer ──────────────────────────────────────
# Cast q to storage dtype (LayerNorm may upcast to float32 under autocast).
return self.prefill_wrapper.run(
q = q.to(self.dtype).contiguous(),
paged_kv_cache = self.kv_caches[block_idx],
) # → [q_len, H, D]
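    # Calling-pattern sketch (illustrative): plan() is amortized across layers.
    # Block 0 plans for the current frame step; later blocks reuse the plan with
    # their own kv_cache:
    #   out0 = mgr.compute_attention(0, q0) # plans, then runs layer 0
    #   out1 = mgr.compute_attention(1, q1) # reuses the plan, runs layer 1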
def reset(self) -> None:
"""Reset all per-block state for a new sequence."""
for i in range(self.num_blocks):
self.scale_patch_pages[i].clear()
self.live_window_patch_pages[i].clear()
self.all_special_pages[i].clear()
self.free_patch_pages[i] = list(range(self.max_patch_pages))
self.free_special_pages[i] = list(range(self.max_patch_pages, self.max_num_pages))
self.special_token_count[i] = 0
self.frame_count[i] = 0
# =========================================================================
# Helper methods
# =========================================================================
def build_visible_page_table(self, block_idx: int) -> List[int]:
"""
Return page IDs in strict order: scale → window → special.
Placing special pages last means only the final page may be partially
full, so paged_kv_last_page_len = compute_last_page_len() is sufficient
without a custom attention mask.
"""
return (
list(self.scale_patch_pages[block_idx]) +
list(self.live_window_patch_pages[block_idx]) +
list(self.all_special_pages[block_idx])
)
def compute_last_page_len(self, block_idx: int) -> int:
"""
Valid token count in the last page of the visible sequence.
- No special pages → last page is a patch page.
Returns patches_per_frame (real tokens written),
which may be < page_size when page_size was rounded
up to a power of 2.
- Special tail partial → special_token_count % page_size.
- Special tail exactly full → page_size.
"""
if not self.all_special_pages[block_idx]:
# Last page is a patch page. We wrote patches_per_frame tokens (0..P-1);
# positions P..page_size-1 are zero padding. Tell FlashInfer the true
# valid count so it doesn't read beyond the real tokens.
return self.patches_per_frame
tail = self.special_token_count[block_idx] % self.page_size
return self.page_size if tail == 0 else tail
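    # Worked example (hypothetical numbers): with page_size=256 and 600 special
    # tokens written so far, 600 % 256 == 88, so the special tail page holds 88
    # valid tokens and this returns 88. With exactly 512 special tokens the tail
    # page is full and this returns page_size (256).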
# ── Internal write helpers ────────────────────────────────────────────────
def _write_patch_page(self, block_idx: int, patch_k: Tensor, patch_v: Tensor) -> int:
"""
Allocate one free patch page and write patches_per_frame patch tokens.
Direct tensor assignment to kv_caches[block_idx][page_id, 0/1] avoids
the Python→C++/CUDA dispatch overhead of flashinfer.page.append_paged_kv_cache.
kv_caches layout: [max_num_pages, 2, page_size, H, D] (NHD, K=0, V=1).
patch_k/v fill exactly one full page (patches_per_frame == page_size).
Routes to scale_patch_pages if still filling scale quota,
otherwise to live_window_patch_pages.
Returns:
page_id: Physical page index used.
"""
assert self.free_patch_pages[block_idx], (
f"block {block_idx}: patch page pool exhausted — "
f"scale={len(self.scale_patch_pages[block_idx])}, "
f"window={len(self.live_window_patch_pages[block_idx])}, "
f"free={len(self.free_patch_pages[block_idx])}"
)
page_id = self.free_patch_pages[block_idx].pop()
# Direct slice write: positions 0..patches_per_frame-1.
# When page_size == patches_per_frame (power-of-2 aligned, e.g. 256 for 224×224),
# this is equivalent to a full-page write. When page_size > patches_per_frame
# (rounded up for FA3 alignment, e.g. page_size=1024 for patches_per_frame=999),
# positions patches_per_frame..page_size-1 remain zero (kv_caches is zero-init).
P = self.patches_per_frame
self.kv_caches[block_idx][page_id, 0, :P] = patch_k # K
self.kv_caches[block_idx][page_id, 1, :P] = patch_v # V
if len(self.scale_patch_pages[block_idx]) < self.scale_frames:
self.scale_patch_pages[block_idx].append(page_id)
else:
self.live_window_patch_pages[block_idx].append(page_id)
return page_id
def _write_special_tokens(self, block_idx: int, sp_k: Tensor, sp_v: Tensor) -> None:
"""
Append num_special_tokens (6) special tokens to the special stream.
Direct tensor slice assignment to kv_caches[block_idx][tail_page, 0/1,
tail_offset : tail_offset+write_n] avoids the Python→C++/CUDA dispatch
overhead of flashinfer.page.append_paged_kv_cache.
Handles page-boundary crossing: if 6 tokens straddle two pages, performs
two slice writes (rare — page_size=256 >> 6).
"""
remaining = self.num_special_tokens # 6
written = 0
while remaining > 0:
tail_offset = self.special_token_count[block_idx] % self.page_size
if tail_offset == 0:
# Current tail page is full (or no page exists) — allocate a new one
assert self.free_special_pages[block_idx], (
f"block {block_idx}: special page pool exhausted at "
f"special_token_count={self.special_token_count[block_idx]}. "
f"Increase max_total_frames."
)
new_page = self.free_special_pages[block_idx].pop()
self.all_special_pages[block_idx].append(new_page)
tail_page = self.all_special_pages[block_idx][-1]
space = self.page_size - tail_offset # free slots in tail page
write_n = min(remaining, space)
# Direct slice write: kv_caches[block_idx][tail_page, 0/1, offset:offset+n]
# shape: [page_size, H, D]; slice [tail_offset:tail_offset+write_n, :, :]
end = tail_offset + write_n
self.kv_caches[block_idx][tail_page, 0, tail_offset:end] = sp_k[written:written + write_n]
self.kv_caches[block_idx][tail_page, 1, tail_offset:end] = sp_v[written:written + write_n]
self.special_token_count[block_idx] += write_n
written += write_n
remaining -= write_n
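        # Boundary-crossing example (hypothetical numbers): with page_size=256 and
        # special_token_count=254, writing 6 tokens performs two slice writes:
        # 2 tokens into the current tail page, then 4 into a freshly allocated page.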
# ── Legacy property (used by stream.py) ──────────────────────────────────
@property
def num_frames(self) -> int:
"""Number of frames appended to block 0 (representative)."""
return self.frame_count[0] if self.frame_count else 0
# =============================================================================
# Sanity check
# =============================================================================
def _sanity_check():
"""
Minimal smoke test.
Run with: python -c "from lingbot_map.layers.flashinfer_cache import _sanity_check; _sanity_check()"
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
print("[sanity_check] CUDA not available — skipping.")
return
tokens_per_frame = 262 # 256 patch + 6 special (224×224)
num_special = 6
patches_per_frame = tokens_per_frame - num_special # 256
page_size = patches_per_frame # 256
mgr = FlashInferKVCacheManager(
num_blocks = 2,
max_num_frames = 88,
tokens_per_frame = tokens_per_frame,
num_heads = 16,
head_dim = 64,
dtype = torch.bfloat16,
device = device,
num_special_tokens = num_special,
scale_frames = 8,
sliding_window = 64,
max_total_frames = 200,
)
def make_kv():
k = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
v = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
return k, v
def make_q():
return torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
for block in range(2):
for t in range(100):
k, v = make_kv()
mgr.append_frame(block, k, v)
mgr.evict_frames(block, scale_frames=8, sliding_window=64)
# ── Page count checks ───────────────────────────────────────────────
n_scale = len(mgr.scale_patch_pages[block])
n_window = len(mgr.live_window_patch_pages[block])
n_spec = len(mgr.all_special_pages[block])
sp_count = mgr.special_token_count[block]
assert n_scale == 8, f"block {block}: scale pages = {n_scale}, expected 8"
assert n_window == 64, f"block {block}: window pages = {n_window}, expected 64"
# 100 frames × 6 specials = 600 tokens; ceil(600/256) = 3 pages
expected_spec_pages = math.ceil(100 * num_special / page_size)
assert n_spec == expected_spec_pages, (
f"block {block}: special pages = {n_spec}, expected {expected_spec_pages}"
)
assert sp_count == 100 * num_special, (
f"block {block}: special_token_count = {sp_count}, expected {100*num_special}"
)
# ── last_page_len ────────────────────────────────────────────────────
last_len = mgr.compute_last_page_len(block)
tail = sp_count % page_size
expected_len = page_size if tail == 0 else tail
assert last_len == expected_len, f"block {block}: last_len={last_len}, expected={expected_len}"
# ── visible page table order ─────────────────────────────────────────
visible = mgr.build_visible_page_table(block)
assert len(visible) == n_scale + n_window + n_spec, "visible page count mismatch"
for pid in visible[:n_scale + n_window]:
assert pid < mgr.max_patch_pages, f"patch page {pid} out of patch range"
for pid in visible[n_scale + n_window:]:
assert pid >= mgr.max_patch_pages, f"special page {pid} not in special range"
# ── forward pass: plan() once for block 0, run() for both blocks ─────
if block == 1:
# Simulate the actual calling pattern: plan on block 0, run on both
q0 = make_q()
out0 = mgr.compute_attention(0, q0) # triggers plan()
q1 = make_q()
out1 = mgr.compute_attention(1, q1) # reuses plan, different kv_cache
assert out0.shape == (tokens_per_frame, 16, 64)
assert out1.shape == (tokens_per_frame, 16, 64)
print(f"[block {block}] PASS: scale={n_scale}, window={n_window}, "
f"special_pages={n_spec}, special_tokens={sp_count}, "
f"last_page_len={last_len}")
mgr.reset()
assert mgr.frame_count[0] == 0
print("\n[sanity_check] All assertions passed.")
if __name__ == "__main__":
_sanity_check()


@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
from typing import Union
import torch
from torch import Tensor
from torch import nn
class LayerScale(nn.Module):
def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
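# Usage sketch (illustrative): per-channel scaling of a [B, N, D] token tensor.
#   ls = LayerScale(dim=768, init_values=1e-5)
#   y = ls(torch.randn(2, 196, 768)) # same shape; each channel scaled by gamma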

lingbot_map/layers/mlp.py Normal file

@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
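# Usage sketch (illustrative): a standard ViT FFN with a 4x hidden expansion.
#   mlp = Mlp(in_features=768, hidden_features=3072)
#   y = mlp(torch.randn(2, 196, 768)) # -> [2, 196, 768]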


@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
from torch import Tensor
import torch.nn as nn
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if not isinstance(self.norm, nn.Identity):
flops += Ho * Wo * self.embed_dim
return flops
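# Usage sketch (illustrative): a 224×224 RGB image with 16×16 patches yields 196 tokens.
#   embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
#   tokens = embed(torch.randn(1, 3, 224, 224)) # -> [1, 196, 768]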

lingbot_map/layers/rope.py Normal file

@@ -0,0 +1,474 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Implementation of 2D Rotary Position Embeddings (RoPE).
# This module provides a clean implementation of 2D Rotary Position Embeddings,
# which extends the original RoPE concept to handle 2D spatial positions.
# Inspired by:
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
# https://github.com/naver-ai/rope-vit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple
from typing import List, Optional, Tuple, Union
class PositionGetter:
"""Generates and caches 2D spatial positions for patches in a grid.
This class efficiently manages the generation of spatial coordinates for patches
in a 2D grid, caching results to avoid redundant computations.
Attributes:
position_cache: Dictionary storing precomputed position tensors for different
grid dimensions.
"""
def __init__(self):
"""Initializes the position generator with an empty cache."""
self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
"""Generates spatial positions for a batch of patches.
Args:
batch_size: Number of samples in the batch.
height: Height of the grid in patches.
width: Width of the grid in patches.
device: Target device for the position tensor.
Returns:
Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
for each position in the grid, repeated for each batch item.
"""
if (height, width) not in self.position_cache:
y_coords = torch.arange(height, device=device)
x_coords = torch.arange(width, device=device)
positions = torch.cartesian_prod(y_coords, x_coords)
self.position_cache[height, width] = positions
cached_positions = self.position_cache[height, width]
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
"""2D Rotary Position Embedding implementation.
This module applies rotary position embeddings to input tokens based on their
2D spatial positions. It handles the position-dependent rotation of features
separately for vertical and horizontal dimensions.
Args:
frequency: Base frequency for the position embeddings. Default: 100.0
scaling_factor: Scaling factor for frequency computation. Default: 1.0
Attributes:
base_frequency: Base frequency for computing position embeddings.
scaling_factor: Factor to scale the computed frequencies.
frequency_cache: Cache for storing precomputed frequency components.
"""
def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
"""Initializes the 2D RoPE module."""
super().__init__()
self.base_frequency = frequency
self.scaling_factor = scaling_factor
self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
def _compute_frequency_components(
self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes frequency components for rotary embeddings.
Args:
dim: Feature dimension (must be even).
seq_len: Maximum sequence length.
device: Target device for computations.
dtype: Data type for the computed tensors.
Returns:
Tuple of (cosine, sine) tensors for frequency components.
"""
cache_key = (dim, seq_len, device, dtype)
if cache_key not in self.frequency_cache:
# Compute frequency bands
exponents = torch.arange(0, dim, 2, device=device).float() / dim
inv_freq = 1.0 / (self.base_frequency**exponents)
# Generate position-dependent frequencies
positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
angles = torch.einsum("i,j->ij", positions, inv_freq)
# Compute and cache frequency components
angles = angles.to(dtype)
angles = torch.cat((angles, angles), dim=-1)
cos_components = angles.cos().to(dtype)
sin_components = angles.sin().to(dtype)
self.frequency_cache[cache_key] = (cos_components, sin_components)
return self.frequency_cache[cache_key]
@staticmethod
def _rotate_features(x: torch.Tensor) -> torch.Tensor:
"""Performs feature rotation by splitting and recombining feature dimensions.
Args:
x: Input tensor to rotate.
Returns:
Rotated feature tensor.
"""
feature_dim = x.shape[-1]
x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
return torch.cat((-x2, x1), dim=-1)
def _apply_1d_rope(
self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
) -> torch.Tensor:
"""Applies 1D rotary position embeddings along one dimension.
Args:
tokens: Input token features.
positions: Position indices.
cos_comp: Cosine components for rotation.
sin_comp: Sine components for rotation.
Returns:
Tokens with applied rotary position embeddings.
"""
# Embed positions with frequency components
cos = F.embedding(positions, cos_comp)[:, None, :, :]
sin = F.embedding(positions, sin_comp)[:, None, :, :]
# Apply rotation
return (tokens * cos) + (self._rotate_features(tokens) * sin)
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
"""Applies 2D rotary position embeddings to input tokens.
Args:
tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
The feature dimension (dim) must be divisible by 4.
positions: Position tensor of shape (batch_size, n_tokens, 2) containing
the y and x coordinates for each token.
Returns:
Tensor of same shape as input with applied 2D rotary position embeddings.
Raises:
AssertionError: If input dimensions are invalid or positions are malformed.
"""
# Validate inputs
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
# Compute feature dimension for each spatial direction
feature_dim = tokens.size(-1) // 2
# Get frequency components
max_position = int(positions.max()) + 1
cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
# Split features for vertical and horizontal processing
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
# Apply RoPE separately for each dimension
vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
# Combine processed features
return torch.cat((vertical_features, horizontal_features), dim=-1)
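# Usage sketch (illustrative; `q` stands for any [batch, n_heads, n_tokens, dim]
# tensor whose feature dim is divisible by 4):
#   pos_getter = PositionGetter()
#   rope2d = RotaryPositionEmbedding2D(frequency=100.0)
#   pos = pos_getter(batch_size=2, height=14, width=14, device=q.device) # [2, 196, 2]
#   q = rope2d(q, pos) # same shape, rotated by 2D position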
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
):
"""
计算1D旋转位置编码RoPE的频率张量。
RoPE的核心思想使用旋转矩阵来编码位置信息使得相对位置关系保持不变。
公式对于位置m和维度i频率为 θ_i = θ^(-2i/d)其中θ是基础频率默认10000
Args:
dim: 特征维度,必须是偶数(因为要成对处理)
pos: 位置索引可以是整数自动生成0到pos-1的序列或位置数组 [S]
theta: 基础频率控制位置编码的周期性默认10000
use_real: 是否返回实数形式cos和sin分开还是复数形式
linear_factor: 线性缩放因子,用于上下文扩展
ntk_factor: NTK-Aware缩放因子用于处理更长的序列
repeat_interleave_real: 当use_real=True时是否交错重复用于某些模型架构
freqs_dtype: 频率张量的数据类型
Returns:
复数形式:[S, D/2] 的复数张量,表示 e^(i*m*θ_j)
实数形式:两个 [S, D] 的张量cos和sin
"""
    # The dimension must be even: RoPE processes feature dimensions in pairs
    assert dim % 2 == 0
    # Convert positions to a torch tensor
    if isinstance(pos, int):
        pos = torch.arange(pos) # generates [0, 1, 2, ..., pos-1]
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos) # [S]
    # Apply NTK (Neural Tangent Kernel) scaling to handle sequences longer than those seen in training
    theta = theta * ntk_factor
    # Step 1: compute the frequencies θ_i = 1 / (θ^(2i/d))
    # for i ∈ {0, 2, 4, ..., dim-2} (even indices only, since dimensions are paired)
    # Formula: freq_i = 1 / (theta^(2i/d) * linear_factor)
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
    ) # [D/2], one frequency per dimension pair
    # Step 2: compute the position-frequency matrix via the outer product:
    # pos[m] * freqs[i] = m * θ_i, i.e. every combination of position m and frequency i
    freqs = torch.outer(pos, freqs) # [S, D/2]
    # Step 3: convert to the requested output format
    if use_real and repeat_interleave_real:
        # Variant 1: interleaved repeats (used by flux, hunyuan-dit, cogvideox, ...)
        # cos and sin of each frequency are interleaved: [cos_0, cos_0, cos_1, cos_1, ...]
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # Variant 2: block-wise repeats (used by stable audio, allegro, ...)
        # The half-dim cos (resp. sin) sequence is tiled twice:
        # [cos_0, cos_1, ..., cos_n, cos_0, cos_1, ..., cos_n]
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
        return freqs_cos, freqs_sin
    else:
        # Variant 3: complex form (used by lumina, ...)
        # Euler's formula: e^(iθ) = cos(θ) + i*sin(θ).
        # torch.polar(r, θ) returns r * e^(iθ); with r = 1 this is exactly e^(i*freqs)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64: [S, D/2]
        return freqs_cis
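# Shape sketch (illustrative):
#   get_1d_rotary_pos_embed(64, 100, use_real=False) # complex tensor [100, 32]
#   cos, sin = get_1d_rotary_pos_embed(64, 100, use_real=True) # two [100, 64] tensors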
class WanRotaryPosEmbed(nn.Module):
"""
3D旋转位置编码3D RoPE模块
核心思想将RoPE扩展到3D空间时间、高度、宽度为视频或3D数据提供位置编码。
每个维度t, h, w独立使用RoPE然后拼接起来。
公式:
对于3D位置 (f, h, w)(帧、高度、宽度):
- 帧维度使用 dim_f 个特征维度
- 高度维度使用 dim_h 个特征维度
- 宽度维度使用 dim_w 个特征维度
其中 dim_f + dim_h + dim_w = attention_head_dim
"""
def __init__(
self,
attention_head_dim: int,
patch_size: Tuple[int, int, int],
max_seq_len: int = 1024,
theta: float = 10000.0,
        fhw_dim: Optional[Tuple[int, int, int]] = (20, 22, 22),
):
super().__init__()
        self.attention_head_dim = attention_head_dim # total dimension per attention head
        self.patch_size = patch_size # patch size (patch_f, patch_h, patch_w)
        self.max_seq_len = max_seq_len # maximum sequence length (for precomputing frequencies)
        # Step 1: allocate feature dimensions to the three axes
        if fhw_dim is not None:
            # Use the explicit allocation if one is given
            assert attention_head_dim == sum(
                fhw_dim
            ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
            t_dim, h_dim, w_dim = fhw_dim
        else:
            # Otherwise allocate automatically: h and w each get roughly 1/3, t gets the rest.
            # E.g. with attention_head_dim=64: h_dim = w_dim = 20, t_dim = 24.
            h_dim = w_dim = 2 * (attention_head_dim // 6)
            t_dim = attention_head_dim - h_dim - w_dim
        # Save the allocation for use in forward()
        self.fhw_dim = (t_dim, h_dim, w_dim)
        # Step 2: precompute the frequencies for each axis
        # (time, height, and width each get their own 1D RoPE frequencies)
        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            # Each axis independently calls the 1D RoPE helper.
            # Returns complex frequencies of shape [max_seq_len, dim // 2].
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        # Concatenate the three axes along the last dim: [max_seq_len, (t_dim + h_dim + w_dim) // 2]
        self.freqs = torch.cat(freqs, dim=1)
def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
"""
前向传播为3D输入视频帧+patch生成旋转位置编码
参数:
- ppf (int): 帧数patches per frame当f_end为None时使用
- pph (int): 每帧的patch高度数量
- ppw (int): 每帧的patch宽度数量
- patch_start_idx (int): 每帧的特殊token数量在patches之前
- device: 计算设备CPU/GPU
- f_start (int): 起始帧索引用于causal模式默认为0
- f_end (Optional[int]): 结束帧索引用于causal模式如果为None则使用ppf作为帧数
返回:
- freqs: [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim//2] 复数频率tensor
Token排列顺序
[frame0_special_token_0, ..., frame0_special_token_N,
frame0_patch_0, ..., frame0_patch_M,
frame1_special_token_0, ..., frame1_special_token_N,
frame1_patch_0, ..., frame1_patch_M,
...]
模式:
- 非causal模式f_end=None使用ppf作为帧数从位置0开始
- Causal模式f_end不为None使用[f_start, f_end)范围的帧ppf会被重新计算
"""
        # Step 1: move the precomputed frequencies to the target device and split them by axis
        self.freqs = self.freqs.to(device)
        # Recover the per-axis dimension allocation
        if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
            t_dim, h_dim, w_dim = self.fhw_dim
        else:
            # Automatic-allocation case
            h_dim = w_dim = 2 * (self.attention_head_dim // 6)
            t_dim = self.attention_head_dim - h_dim - w_dim
        # Split with the correct sizes: half of each axis's dimension
        freqs = self.freqs.split_with_sizes(
            [
                t_dim // 2, # time axis
                h_dim // 2, # height axis
                w_dim // 2, # width axis
            ],
            dim=1,
        )
        # Causal mode: if f_end is given, recompute ppf and the frame range
        if f_end is not None:
            ppf = f_end - f_start
            frame_slice = slice(f_start, f_end)
        else:
            # Non-causal mode: use ppf frames starting at frame 0
            frame_slice = slice(0, ppf)
        # Step 2: handle the special tokens (if any)
        ## For other tokens
        if patch_start_idx > 0:
            # 2.1 Positions for the special tokens.
            # Special tokens sit on the diagonal (f, i, i), so each one gets a unique position:
            # camera: (f, 0, 0), register_0: (f, 1, 1), ..., scale: (f, 5, 5)
            # Shape: (ppf, patch_start_idx, dim)
            freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_f) varies along the frame axis
            freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_h) height = 0, 1, 2, ...
            freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_w) width = 0, 1, 2, ...
            freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1) # (ppf, patch_start_idx, dim) concat the three axes
            freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim)
            # 2.2 Positions for the image patches.
            # Patches sit at (f, patch_start_idx + h, patch_start_idx + w): h and w are offset by
            # patch_start_idx so patches never collide with special-token positions, and h and w
            # are treated symmetrically.
            # Shape: (ppf, pph, ppw, dim)
            freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_f) frame axis
            freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_h) height starts at patch_start_idx
            freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_w) width starts at patch_start_idx
            freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1) # (ppf, pph, ppw, dim) concat the three axes
            freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1) # (ppf, pph * ppw, dim) flatten the spatial dims
            # Step 3: combine special tokens and patches in the correct order.
            # Within each frame the order is [special tokens, patches].
            # Concatenate special tokens and patches for each frame along the second dimension
            # Shape: (ppf, patch_start_idx + pph * ppw, dim)
            freqs = torch.cat([freqs_special, freqs_patches], dim=1) # (ppf, patch_start_idx + pph * ppw, dim)
            # Step 4: flatten to the final shape and add batch and head dimensions
            # Flatten to get final shape: (ppf * (patch_start_idx + pph * ppw), dim)
            freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
            freqs = freqs.unsqueeze(0).unsqueeze(0) # (1, 1, ppf * (patch_start_idx + pph * ppw), dim) add batch and head dims
            return freqs
        # No special tokens (patch_start_idx == 0): only image patches.
        # All patches sit at (f, 0:pph, 0:ppw).
        freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_f) frame axis
        freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_h) height starts at 0
        freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_w) width starts at 0
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) # (1, 1, ppf * pph * ppw, dim)
        return freqs
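# Usage sketch (illustrative, matching the 6-special-token layout used elsewhere in this repo):
#   rope3d = WanRotaryPosEmbed(attention_head_dim=64, patch_size=(1, 2, 2), fhw_dim=(20, 22, 22))
#   freqs = rope3d(ppf=4, pph=16, ppw=16, patch_start_idx=6, device=torch.device("cuda"))
#   # freqs: complex tensor of shape [1, 1, 4 * (6 + 256), 32]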
def apply_rotary_emb(x, freqs):
"""
应用旋转位置编码到输入特征
核心思想:使用复数乘法实现特征旋转,保持相对位置信息
数学原理:
对于2D向量 [x1, x2],旋转θ角度可以表示为复数乘法:
(x1 + ix2) * e^(iθ) = (x1 + ix2) * (cos(θ) + i*sin(θ))
= (x1*cos(θ) - x2*sin(θ)) + i*(x1*sin(θ) + x2*cos(θ))
这等价于旋转矩阵:
[cos(θ) -sin(θ)] [x1]
[sin(θ) cos(θ)] [x2]
参数:
- x: 输入特征 [batch, heads, seq_len, head_dim]
- freqs: 旋转频率(复数) [1, 1, seq_len, head_dim//2]
返回:
- x_out: 旋转后的特征 [batch, heads, seq_len, head_dim]
实现步骤:
1. 将x的每两个连续特征看作一个复数 (real, imag)
2. 与预计算的复数频率 e^(iθ) 相乘
3. 转回实数表示
"""
    # Step 1: reshape to [..., head_dim//2, 2]; the last dim holds (real, imag)
    # e.g. [b, h, seq, 64] -> [b, h, seq, 32, 2]
    x_reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
    # Step 2: view as complex numbers [b, h, seq, 32]; each element is real + imag*i
    x_complex = torch.view_as_complex(x_reshaped)
    # Step 3: the complex multiplication implements the rotation.
    # freqs is already in the form e^(iθ) = cos(θ) + i*sin(θ), so
    # x_complex * freqs rotates each feature pair by θ.
    x_rotated = x_complex * freqs
    # Step 4: back to the real representation [b, h, seq, 32, 2]
    x_real = torch.view_as_real(x_rotated)
    # Step 5: flatten the last two dims [b, h, seq, 64]
    x_out = x_real.flatten(3)
    # Step 6: cast back to the original dtype
    return x_out.to(x.dtype)
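# Usage sketch (illustrative): rotate a [B, H, S, D] tensor with matching frequencies.
#   x = torch.randn(2, 16, 262, 64)
#   out = apply_rotary_emb(x, freqs) # freqs: [1, 1, 262, 32] complex; out: [2, 16, 262, 64]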


@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import os
from typing import Callable, Optional
import warnings
from torch import Tensor, nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
# if XFORMERS_ENABLED:
# from xformers.ops import SwiGLU
# XFORMERS_AVAILABLE = True
# warnings.warn("xFormers is available (SwiGLU)")
# else:
# warnings.warn("xFormers is disabled (SwiGLU)")
# raise ImportError
# except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
# warnings.warn("xFormers is not available (SwiGLU)")
class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
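# Hidden-width example (illustrative): for hidden_features=3072,
# int(3072 * 2 / 3) = 2048, already 8-aligned, so the fused hidden width is 2048;
# for hidden_features=1024, int(1024 * 2 / 3) = 682, rounded up to 688.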


@@ -0,0 +1,411 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.nn.init import trunc_normal_
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention  # , NestedTensorBlock as Block
# TODO: Check this
# We replace NestedTensorBlock with Block
from .block import Block
logger = logging.getLogger("dinov2")
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = ".".join((name, child_name)) if name else child_name
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
class BlockChunk(nn.ModuleList):
def forward(self, x):
for b in self:
x = b(x)
return x
class DinoVisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
ffn_bias=True,
proj_bias=True,
drop_path_rate=0.0,
drop_path_uniform=False,
init_values=None, # for layerscale: None or 0 => no layerscale
embed_layer=PatchEmbed,
act_layer=nn.GELU,
block_fn=Block,
ffn_layer="mlp",
block_chunks=1,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1,
drop_cls_token=False,
qk_norm=False,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
proj_bias (bool): enable bias for proj in attn if True
ffn_bias (bool): enable bias for ffn if True
drop_path_rate (float): stochastic depth rate
drop_path_uniform (bool): apply uniform drop rate across blocks
init_values (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
act_layer (nn.Module): MLP activation layer
block_fn (nn.Module): transformer block class
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
"""
super().__init__()
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1 if not drop_cls_token else 0
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.num_register_tokens = num_register_tokens
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
self.use_reentrant = False # hardcoded to False
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.drop_cls_token = drop_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if not drop_cls_token else None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
assert num_register_tokens >= 0
self.register_tokens = (
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
)
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
if ffn_layer == "mlp":
logger.info("using MLP layer as FFN")
ffn_layer = Mlp
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
logger.info("using SwiGLU layer as FFN")
ffn_layer = SwiGLUFFNFused
elif ffn_layer == "identity":
logger.info("using Identity layer as FFN")
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
else:
raise NotImplementedError
blocks_list = [
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
qk_norm=qk_norm,
)
for i in range(depth)
]
if block_chunks > 0:
self.chunked_blocks = True
chunked_blocks = []
chunksize = depth // block_chunks
for i in range(0, depth, chunksize):
# this is to keep the block index consistent if we chunk the block list
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
else:
self.chunked_blocks = False
self.blocks = nn.ModuleList(blocks_list)
self.norm = norm_layer(embed_dim)
self.head = nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
self.init_weights()
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
if self.register_tokens is not None:
nn.init.normal_(self.register_tokens, std=1e-6)
named_apply(init_weights_vit_timm, self)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
        npatch = x.shape[1] - 1 if not self.drop_cls_token else x.shape[1]
N = self.pos_embed.shape[1] - 1 if not self.drop_cls_token else self.pos_embed.shape[1]
if npatch == N and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
if not self.drop_cls_token:
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
else:
patch_pos_embed = pos_embed
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
assert N == M * M
kwargs = {}
if self.interpolate_offset:
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
sx = float(w0 + self.interpolate_offset) / M
sy = float(h0 + self.interpolate_offset) / M
kwargs["scale_factor"] = (sx, sy)
else:
# Simply specify an output size instead of a scale factor
kwargs["size"] = (w0, h0)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
mode="bicubic",
antialias=self.interpolate_antialias,
**kwargs,
)
assert (w0, h0) == patch_pos_embed.shape[-2:]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
if not self.drop_cls_token:
x = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
else:
x = patch_pos_embed
return x.to(previous_dtype)
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.cls_token is not None else x
x = x + self.interpolate_pos_encoding(x, w, h)
if self.register_tokens is not None:
x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
return x
def forward_features_list(self, x_list, masks_list):
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
for blk in self.blocks:
if self.training:
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
else:
x = blk(x)
all_x = x
output = []
for x, masks in zip(all_x, masks_list):
x_norm = self.norm(x)
output.append(
{
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
)
return output
def forward_features(self, x, masks=None):
if isinstance(x, list):
return self.forward_features_list(x, masks)
x = self.prepare_tokens_with_masks(x, masks)
for blk in self.blocks:
if self.training:
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
else:
x = blk(x)
x_norm = self.norm(x)
return {
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
def _get_intermediate_layers_not_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def _get_intermediate_layers_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
output, i, total_block_len = [], 0, len(self.blocks[-1])
# If n is an int, take the n last blocks. If it's a list, take them
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for block_chunk in self.blocks:
for blk in block_chunk[i:]: # Passing the nn.Identity()
x = blk(x)
if i in blocks_to_take:
output.append(x)
i += 1
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1, # Layers or n last layers to take
reshape: bool = False,
return_class_token: bool = False,
norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
if self.chunked_blocks:
outputs = self._get_intermediate_layers_chunked(x, n)
else:
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
if reshape:
B, _, w, h = x.shape
outputs = [
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
def forward(self, *args, is_training=True, **kwargs):
ret = self.forward_features(*args, **kwargs)
if is_training:
return ret
else:
return self.head(ret["x_norm_clstoken"])
def init_weights_vit_timm(module: nn.Module, name: str = ""):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
"""
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
"""
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1536,
depth=40,
num_heads=24,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
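# Usage sketch (illustrative):
#   model = vit_large(patch_size=16, num_register_tokens=4, qk_norm=True)
#   feats = model.forward_features(torch.randn(1, 3, 224, 224))
#   patches = feats["x_norm_patchtokens"] # [1, 196, 1024]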