first commit
This commit is contained in:
5
lingbot_map/layers/__init__.py
Normal file
5
lingbot_map/layers/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from lingbot_map.layers.mlp import Mlp
|
||||
from lingbot_map.layers.patch_embed import PatchEmbed
|
||||
from lingbot_map.layers.block import Block
|
||||
from lingbot_map.layers.swiglu_ffn import SwiGLUFFN as SwiGLUFFNFused
|
||||
from lingbot_map.layers.attention import Attention as MemEffAttention
|
||||
766
lingbot_map/layers/attention.py
Normal file
766
lingbot_map/layers/attention.py
Normal file
@@ -0,0 +1,766 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
import logging
|
||||
import os
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from lingbot_map.layers.rope import apply_rotary_emb
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
# FlashInfer imports (optional - for paged attention)
|
||||
try:
|
||||
import flashinfer
|
||||
FLASHINFER_AVAILABLE = True
|
||||
except ImportError:
|
||||
FLASHINFER_AVAILABLE = False
|
||||
print("flashinfer not available")
|
||||
|
||||
try:
|
||||
from torchtitan.distributed.sequence_parallel import (
|
||||
gather_seq_scatter_heads,
|
||||
gather_heads_scatter_seq,
|
||||
pad_tensor,
|
||||
slice_input_tensor_scale_grad,
|
||||
gather_outputs,
|
||||
)
|
||||
except ImportError:
|
||||
print("torchtitan not available for ulysses cp")
|
||||
|
||||
def gather_seq_scatter_heads_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_dim: int, head_dim: int):
    """Apply the Ulysses all-to-all (gather sequence, scatter heads) to Q, K and V.

    Each tensor is passed through ``gather_seq_scatter_heads`` with the same
    ``seq_dim``/``head_dim``; the three transformed tensors are returned in
    (q, k, v) order.
    """
    return tuple(gather_seq_scatter_heads(t, seq_dim, head_dim) for t in (q, k, v))
|
||||
|
||||
from typing_extensions import List
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Multi-head self-attention with optional QK-norm, RoPE and Ulysses context parallelism.

    The fused path delegates to ``F.scaled_dot_product_attention``; the manual
    path computes softmax attention explicitly (with dropout on the weights).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        # Submodule creation order is kept stable so RNG-dependent init is reproducible.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
        batch, seq_len, _ = x.shape

        # Fused QKV projection, then split to three [B, heads, N, head_dim] tensors.
        fused = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if enable_ulysses_cp:
            # All-to-all: full sequence per rank, heads sharded across ranks.
            q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)

        if self.rope is not None:
            q, k = self.rope(q, pos), self.rope(k, pos)

        if self.fused_attn:
            drop_p = self.attn_drop.p if self.training else 0.0
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        else:
            # Manual attention: scaled dot-product, softmax, dropout on the weights.
            scores = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            out = weights @ v

        if enable_ulysses_cp:
            # Inverse all-to-all: back to sequence-sharded with all heads local.
            out = gather_heads_scatter_seq(out, seq_dim=2, head_dim=1)

        out = out.transpose(1, 2).reshape(batch, -1, self.num_heads * self.head_dim)
        return self.proj_drop(self.proj(out))
|
||||
|
||||
|
||||
class CausalAttention(nn.Module):
    """
    Causal self-attention module with KV cache support for streaming inference.
    Used by CasualBlockCamera in camera_head.py.

    Two execution modes:
      * ``kv_cache is None``: masked batch attention — the caller-supplied
        ``block_mask`` (optionally combined with ``video_mask`` and a sliding
        window) restricts which positions may attend to which.
      * ``kv_cache`` is a dict: streaming inference — current K/V are appended
        to the ``k_{global_idx}`` / ``v_{global_idx}`` entries, a sliding-window
        eviction trims old frames, and evicted frames' special tokens are kept
        in side entries ``k_{global_idx}_special`` / ``v_{global_idx}_special``.

    NOTE(review): in both modes ``x`` is only assigned inside
    ``if self.fused_attn:``; calling forward with ``fused_attn=False`` would
    raise NameError — confirm fused attention is always enabled here.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
        elementwise_attn_output_gate=False,
        # KV cache eviction parameters (matching build_attn_mask)
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,  # If True, only cache camera token (no scale token)
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

        # Optional elementwise sigmoid gate applied to the attention output.
        self.gate_proj = nn.Linear(dim, dim, bias=True) if elementwise_attn_output_gate else None

        # Store KV cache eviction parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only

    def forward(self, x: Tensor, block_mask=None, pos=None, pos_kv=None, frame_seqlen=None, video_mask=None, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=1, num_frame_for_scale=-1, enable_3d_rope=False, sliding_window_size=-1, attend_to_scale_frames=False, num_random_frames=0, attend_to_special_tokens=False, num_register_tokens=4, enable_ulysses_cp=False, is_scale_frames=False) -> Tensor:
        # NOTE(review): pos_kv, current_start, current_end, attend_to_scale_frames,
        # num_random_frames, attend_to_special_tokens and is_scale_frames are
        # accepted but never read in this body — presumably kept for interface
        # compatibility with sibling attention variants; confirm with callers.
        B, N, C = x.shape

        # Calculate special token indices
        camera_token_idx = 0
        scale_token_idx = camera_token_idx + num_register_tokens + 1  # camera + register tokens + scale

        # [3, B, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.gate_proj is not None:
            # Per-head gate scores, same layout as the attention output [B, H, N, D].
            gate_score = self.gate_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        if kv_cache is None:
            # ---- Batch mode: mask-driven attention over the full sequence ----
            q, k = self.q_norm(q), self.k_norm(k)
            if enable_ulysses_cp:
                q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
                N = q.shape[2]  # Update N after gather
            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif enable_3d_rope and pos is not None:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)

            # Mask construction never needs gradients.
            with torch.no_grad():
                block_mask = block_mask.squeeze()[:q.shape[2], :k.shape[2]]
                if block_mask.dim() == 2:
                    block_mask = block_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, N, N]
                block_mask = block_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])

                video_mask = video_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if video_mask is not None else torch.ones_like(block_mask, device=block_mask.device)  # [1, 1, N, N]
                video_mask = video_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])

                # NOTE(review): True entries allow attention (SDPA bool-mask
                # semantics); combining as block_mask | ~video_mask opens all
                # positions of invalid videos — confirm that is intentional.
                mask = block_mask | ~video_mask

                # Apply sliding window mask if sliding_window_size > 0
                # sliding_window_size is in units of num_frame_per_block
                if sliding_window_size > 0 and frame_seqlen is not None:
                    # Create sliding window mask: each frame can only attend to frames within the window
                    num_frames = N // frame_seqlen
                    sliding_mask = torch.zeros_like(mask, dtype=torch.bool)

                    for i in range(num_frames):
                        q_start = i * frame_seqlen
                        q_end = (i + 1) * frame_seqlen
                        # Calculate the window start: sliding_window_size is in units of num_frame_per_block
                        # So the actual window size in frames is sliding_window_size * num_frame_per_block
                        window_size_in_frames = sliding_window_size * num_frame_per_block
                        window_start_frame = max(0, i - window_size_in_frames + 1)
                        k_start = window_start_frame * frame_seqlen
                        k_end = (i + 1) * frame_seqlen  # Can attend up to current frame (causal)
                        sliding_mask[:, :, q_start:q_end, k_start:k_end] = True

                    # Combine with existing mask: both masks need to allow attention
                    mask = mask & sliding_mask

                    # If attend_to_scale_frames is True, also allow attention to first num_frame_for_scale frames
                    if num_frame_for_scale > 0:
                        for i in range(num_frames):
                            q_start = i * frame_seqlen
                            q_end = (i + 1) * frame_seqlen
                            # Allow attending to first num_frame_for_scale frames (directly set to True, not depending on block_mask)
                            mask[:, :, q_start:q_end, :num_frame_for_scale * frame_seqlen] = True

                    ## global attention for the first num_frame_for_scale frames
                    if num_frame_for_scale > 0:
                        mask[:, :, :num_frame_for_scale * frame_seqlen, :num_frame_for_scale * frame_seqlen] = True

            if self.fused_attn:
                x = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                    attn_mask=mask
                )
        else:
            # ---- Streaming mode: dict-based KV cache ----
            # Apply RoPE to current k before caching
            q, k = self.q_norm(q), self.k_norm(k)

            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif enable_3d_rope and pos is not None:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)

            # Check if we should skip appending to cache (non-keyframe in keyframe mode)
            skip_append = kv_cache.get("_skip_append", False)

            # Reshape K/V to [B, H, frames, tokens_per_frame, D] for frame-wise caching.
            k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
            v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)

            if not skip_append:
                # KEYFRAME: store in cache (original behavior)
                if kv_cache[f"k_{global_idx}"] is None:
                    kv_cache[f"k_{global_idx}"] = k_reshaped
                    kv_cache[f"v_{global_idx}"] = v_reshaped
                else:
                    # Re-derive frames-per-block from the cached tokens-per-frame.
                    num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
                    k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
                    v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
                    kv_cache[f"k_{global_idx}"] = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
                    kv_cache[f"v_{global_idx}"] = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)

                # Apply sliding window eviction BEFORE attention to match causal_3drope behavior
                # This ensures current frame only attends to frames within the sliding window
                self._apply_kv_cache_eviction_causal(kv_cache, global_idx, camera_token_idx, scale_token_idx)

                # Retrieve full k, v from cache (already RoPE-applied, already evicted)
                k = kv_cache[f"k_{global_idx}"].clone()
                v = kv_cache[f"v_{global_idx}"].clone()
            else:
                # NON-KEYFRAME: attend to [cached + current] without storing in cache
                if kv_cache[f"k_{global_idx}"] is not None:
                    k = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
                    v = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
                else:
                    k = k_reshaped
                    v = v_reshaped

            # Flatten [B, H, frames, tokens_per_frame, D] -> [B, H, frames*tpf, D].
            a, b, c, d, e = k.shape
            k = k.reshape(a, b, c*d, e)
            v = v.reshape(a, b, c*d, e)

            # Prepend special tokens (camera + scale) from evicted frames if they exist
            if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
                special_k = kv_cache[f"k_{global_idx}_special"]  # [B, H, num_evicted_frames, 2, D]
                special_v = kv_cache[f"v_{global_idx}_special"]
                sa, sb, sc, sd, se = special_k.shape
                special_k = special_k.reshape(sa, sb, sc * sd, se)  # [B, H, num_evicted*2, D]
                special_v = special_v.reshape(sa, sb, sc * sd, se)

                # Prepend special tokens (older frames first)
                k = torch.cat([special_k, k], dim=2)
                v = torch.cat([special_v, v], dim=2)

            # Note: k from cache is already RoPE-applied, no need to apply again

            if self.fused_attn:
                # Use mask-based SDPA to ensure same kernel as batch mode
                # The causal constraint is enforced by KV cache contents, not by mask
                mask = torch.ones(B, 1, q.shape[2], k.shape[2], dtype=torch.bool, device=q.device)
                x = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                    attn_mask=mask,
                )

        if self.gate_proj is not None:
            # Elementwise output gate (sigmoid of the projected input).
            x = x * torch.sigmoid(gate_score)
        if enable_ulysses_cp:
            x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
        # Use actual dimensions from attention output, not original input C
        # x shape: [B, H, seq_len, head_dim] -> [B, seq_len, H*head_dim]
        x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _apply_kv_cache_eviction_causal(self, kv_cache, global_idx, camera_token_idx, scale_token_idx):
        """
        Apply sliding window eviction to KV cache BEFORE attention.

        This ensures current frame only attends to frames within the sliding window,
        matching the behavior of causal_3drope's attention mask.

        Evicted frames' special tokens (camera, or camera+register+scale) may be
        preserved in ``k_{global_idx}_special`` / ``v_{global_idx}_special``.
        """
        sliding_window_frames = self.kv_cache_sliding_window
        scale_frames = self.kv_cache_scale_frames

        # shape[3] is tokens-per-frame; eviction only applies to multi-token frames.
        if kv_cache[f"k_{global_idx}"].shape[3] > 1:
            num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]

            if num_cached_frames > sliding_window_frames + scale_frames:
                # Frames between the protected scale prefix and the sliding window are evicted.
                evict_start = scale_frames
                evict_end = num_cached_frames - sliding_window_frames

                if evict_end > evict_start:
                    evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
                    evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]

                    if self.kv_cache_cross_frame_special:
                        if self.kv_cache_camera_only:
                            # Only keep camera token
                            new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                        else:
                            # Keep ALL special tokens (camera + register + scale) to match attention_mask behavior
                            # Special tokens are in range [camera_token_idx, scale_token_idx+1)
                            new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()

                        if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
                            kv_cache[f"k_{global_idx}_special"] = new_special_k
                            kv_cache[f"v_{global_idx}_special"] = new_special_v
                        else:
                            kv_cache[f"k_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
                            kv_cache[f"v_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)

                    if self.kv_cache_include_scale_frames:
                        # Keep scale-frame prefix plus the trailing sliding window.
                        kv_cache[f"k_{global_idx}"] = torch.cat([
                            kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
                            kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                        ], dim=2)
                        kv_cache[f"v_{global_idx}"] = torch.cat([
                            kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
                            kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                        ], dim=2)
                    else:
                        # Keep only the trailing sliding window.
                        kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                        kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
||||
|
||||
|
||||
class FlashInferAttention(Attention):
    """
    FlashInfer variant of the GCT attention layer.

    Uses FlashInferKVCacheManager for paged KV cache storage and
    FlashInfer attention kernels (BatchPrefillWithPagedKVCacheWrapper).
    Supports the same optimized token layout and KV cache streaming inference.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,
        rope=None,
        # KV cache eviction parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        # Fail fast: the streaming path requires the flashinfer kernels.
        if not FLASHINFER_AVAILABLE:
            raise RuntimeError("FlashInfer is not available. Please install flashinfer.")

        super().__init__(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        # Store KV cache eviction parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only

    def prepare_qkv(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
        """Fused pre-attention ops for single-frame streaming (Phase 2).

        Computes q/k/v from x, applies q_norm/k_norm/RoPE, and converts to
        [tpf, H, D] format ready for append_frame + compute_attention.

        Extracted as a method so torch.compile can capture all pre-attn ops as one
        CUDA graph (qkv linear -> reshape -> unbind -> q_norm -> k_norm -> RoPE x2 ->
        squeeze/permute/contiguous x3).
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # Each: [B, num_heads, N, head_dim]
        q, k = self.q_norm(q), self.k_norm(k)
        if self.rope is not None and not enable_3d_rope:
            q = self.rope(q, pos)
            k = self.rope(k, pos)
        elif self.rope is not None:  # enable_3d_rope=True
            q = apply_rotary_emb(q, pos)
            k = apply_rotary_emb(k, pos)
        # Convert to [tpf, H, D] format for FlashInfer (B=1 in streaming mode)
        q_nhd = q.squeeze(0).permute(1, 0, 2).contiguous()
        k_nhd = k.squeeze(0).permute(1, 0, 2).contiguous()
        v_nhd = v.squeeze(0).permute(1, 0, 2).contiguous()
        return q_nhd, k_nhd, v_nhd

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
                # KV cache parameters (kv_cache is a FlashInferKVCacheManager or None)
                kv_cache=None, global_idx=0, num_frame_per_block=1,
                num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
        """
        Forward pass with FlashInfer paged KV cache and attention.

        Args:
            x: Input tensor [B, N, C]
            kv_cache: FlashInferKVCacheManager instance or None (batch mode)
            global_idx: Block index for per-block cache access
        """
        # Imported lazily so batch mode works without the flashinfer_cache module loaded eagerly.
        from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager

        B, N, C = x.shape

        # Detect if using optimized layout
        using_optimized_layout = (num_patches is not None and num_special is not None
                                  and num_frames is not None)

        # ========== Batch Mode (no KV cache manager) ==========
        if not isinstance(kv_cache, FlashInferKVCacheManager):
            # [3, B, num_heads, N, head_dim]
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # Each: [B, num_heads, N, head_dim]
            q, k = self.q_norm(q), self.k_norm(k)

            if enable_ulysses_cp:
                if using_optimized_layout:
                    # Patch tokens and special tokens are gathered/scattered separately
                    # so each group stays contiguous after the all-to-all.
                    boundary = num_frames * num_patches
                    q_patch, k_patch, v_patch = q[:, :, :boundary, :], k[:, :, :boundary, :], v[:, :, :boundary, :]
                    q_special, k_special, v_special = q[:, :, boundary:, :], k[:, :, boundary:, :], v[:, :, boundary:, :]
                    q_patch, k_patch, v_patch = gather_seq_scatter_heads_qkv(
                        q_patch, k_patch, v_patch, seq_dim=2, head_dim=1
                    )
                    q_special, k_special, v_special = gather_seq_scatter_heads_qkv(
                        q_special, k_special, v_special, seq_dim=2, head_dim=1
                    )
                    q = torch.cat([q_patch, q_special], dim=2)
                    k = torch.cat([k_patch, k_special], dim=2)
                    v = torch.cat([v_patch, v_special], dim=2)
                else:
                    q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)

            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif self.rope is not None and enable_3d_rope:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)

            # Batch mode: use SDPA for numerical consistency with SDPA variant
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )

            if enable_ulysses_cp:
                if using_optimized_layout:
                    # Undo the split gather: recompute the patch/special boundary in
                    # the gathered (global) sequence before scattering back.
                    seq_global = x.shape[2]
                    seq_local = num_frames * (num_patches + num_special)
                    cp_size = seq_global // seq_local
                    boundary_global = num_frames * cp_size * num_patches
                    x_patch = x[:, :, :boundary_global, :]
                    x_special = x[:, :, boundary_global:, :]
                    x_patch = gather_heads_scatter_seq(x_patch, seq_dim=2, head_dim=1)
                    x_special = gather_heads_scatter_seq(x_special, seq_dim=2, head_dim=1)
                    x = torch.cat([x_patch, x_special], dim=2)
                else:
                    x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)

            x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)

        # ========== Streaming Mode (with FlashInferKVCacheManager) ==========
        else:
            manager = kv_cache  # FlashInferKVCacheManager

            # Phase 1 (scale frames): num_frames > 1 — multi-frame batch
            # Phase 2 (streaming): num_frames == 1 — single frame
            is_multi_frame = (num_frames is not None and num_frames > 1)

            if is_multi_frame:
                # Phase 1: compute full self-attention via SDPA (all frames attend to each other),
                # then append each frame's K/V to the paged cache one at a time.
                qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
                q, k, v = qkv.unbind(0)
                q, k = self.q_norm(q), self.k_norm(k)

                # Apply RoPE before caching (RoPE baked into K before append)
                if self.rope is not None and not enable_3d_rope:
                    q = self.rope(q, pos)
                    k = self.rope(k, pos)
                elif self.rope is not None and enable_3d_rope:
                    q = apply_rotary_emb(q, pos)
                    k = apply_rotary_emb(k, pos)

                x = F.scaled_dot_product_attention(
                    q, k, v,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                )
                x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)

                # Append each frame's K/V to the paged cache individually.
                tpf = manager.tokens_per_frame
                k_all = k.squeeze(0).permute(1, 0, 2)  # [num_frames*tpf, H, D]
                v_all = v.squeeze(0).permute(1, 0, 2)
                for f_idx in range(num_frames):
                    s = f_idx * tpf
                    manager.append_frame(global_idx, k_all[s:s+tpf].contiguous(), v_all[s:s+tpf].contiguous())
                    # NOTE(review): eviction runs after every appended frame,
                    # mirroring the per-frame append+evict of Phase 2 — confirm
                    # this matches the intended Phase 1 cache policy.
                    manager.evict_frames(
                        block_idx=global_idx,
                        scale_frames=self.kv_cache_scale_frames,
                        sliding_window=self.kv_cache_sliding_window,
                        cross_frame_special=self.kv_cache_cross_frame_special,
                        include_scale_frames=self.kv_cache_include_scale_frames,
                        camera_only=self.kv_cache_camera_only,
                        num_register_tokens=num_register_tokens,
                    )
            else:
                # Phase 2: single-frame streaming via FlashInfer paged attention.
                q_nhd, k_nhd, v_nhd = self.prepare_qkv(x, pos=pos, enable_3d_rope=enable_3d_rope)

                # 1. Append to paged cache
                manager.append_frame(global_idx, k_nhd, v_nhd)

                # 2. Apply sliding window eviction
                manager.evict_frames(
                    block_idx=global_idx,
                    scale_frames=self.kv_cache_scale_frames,
                    sliding_window=self.kv_cache_sliding_window,
                    cross_frame_special=self.kv_cache_cross_frame_special,
                    include_scale_frames=self.kv_cache_include_scale_frames,
                    camera_only=self.kv_cache_camera_only,
                    num_register_tokens=num_register_tokens,
                )

                # 3. Compute attention via FlashInfer BatchPrefillWithPagedKVCacheWrapper
                x = manager.compute_attention(global_idx, q_nhd)

                # Convert back: [tpf, H, D] -> [B, tpf, C].
                x = x.reshape(B, q_nhd.shape[0], self.num_heads * self.head_dim)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
||||
|
||||
|
||||
class SDPAAttention(Attention):
|
||||
"""
|
||||
SDPA variant for streaming inference.
|
||||
Uses F.scaled_dot_product_attention with dict-based KV cache.
|
||||
No FlashInfer dependency required — works on any CUDA GPU.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
proj_bias: bool = True,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: nn.Module = nn.LayerNorm,
|
||||
qk_norm: bool = False,
|
||||
fused_attn: bool = True,
|
||||
rope=None,
|
||||
kv_cache_sliding_window: int = 64,
|
||||
kv_cache_scale_frames: int = 8,
|
||||
kv_cache_cross_frame_special: bool = True,
|
||||
kv_cache_include_scale_frames: bool = True,
|
||||
kv_cache_camera_only: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
|
||||
attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer,
|
||||
qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
|
||||
)
|
||||
self.kv_cache_sliding_window = kv_cache_sliding_window
|
||||
self.kv_cache_scale_frames = kv_cache_scale_frames
|
||||
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
|
||||
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
||||
self.kv_cache_camera_only = kv_cache_camera_only
|
||||
|
||||
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
|
||||
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
|
||||
kv_cache=None, global_idx=0, num_frame_per_block=1,
|
||||
num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
|
||||
B, N, C = x.shape
|
||||
using_optimized_layout = (num_patches is not None and num_special is not None
|
||||
and num_frames is not None)
|
||||
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
# ========== Batch Mode (no KV cache) ==========
|
||||
if kv_cache is None:
|
||||
if self.rope is not None and not enable_3d_rope:
|
||||
q = self.rope(q, pos)
|
||||
k = self.rope(k, pos)
|
||||
elif self.rope is not None and enable_3d_rope:
|
||||
q = apply_rotary_emb(q, pos)
|
||||
k = apply_rotary_emb(k, pos)
|
||||
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
)
|
||||
x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
|
||||
|
||||
# ========== Streaming Mode (with KV cache dict) ==========
|
||||
else:
|
||||
if self.rope is not None and not enable_3d_rope:
|
||||
q = self.rope(q, pos)
|
||||
k = self.rope(k, pos)
|
||||
elif self.rope is not None and enable_3d_rope:
|
||||
q = apply_rotary_emb(q, pos)
|
||||
k = apply_rotary_emb(k, pos)
|
||||
|
||||
camera_token_idx = 0
|
||||
scale_token_idx = camera_token_idx + num_register_tokens + 1
|
||||
|
||||
if kv_cache[f"k_{global_idx}"] is None:
|
||||
kv_cache[f"k_{global_idx}"] = k.view(B, self.num_heads, num_frame_per_block,
|
||||
N // num_frame_per_block, self.head_dim)
|
||||
kv_cache[f"v_{global_idx}"] = v.view(B, self.num_heads, num_frame_per_block,
|
||||
N // num_frame_per_block, self.head_dim)
|
||||
else:
|
||||
num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
|
||||
kv_cache[f"k_{global_idx}"] = torch.cat((
|
||||
kv_cache[f"k_{global_idx}"],
|
||||
k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
||||
), dim=2)
|
||||
kv_cache[f"v_{global_idx}"] = torch.cat((
|
||||
kv_cache[f"v_{global_idx}"],
|
||||
v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
||||
), dim=2)
|
||||
|
||||
self._apply_kv_cache_eviction(
|
||||
kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens
|
||||
)
|
||||
|
||||
k_cached = kv_cache[f"k_{global_idx}"].clone()
|
||||
v_cached = kv_cache[f"v_{global_idx}"].clone()
|
||||
a, b, c, d, e = k_cached.shape
|
||||
k_full = k_cached.reshape(a, b, c * d, e)
|
||||
v_full = v_cached.reshape(a, b, c * d, e)
|
||||
|
||||
if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
|
||||
special_k = kv_cache[f"k_{global_idx}_special"]
|
||||
special_v = kv_cache[f"v_{global_idx}_special"]
|
||||
sa, sb, sc, sd, se = special_k.shape
|
||||
k_full = torch.cat([special_k.reshape(sa, sb, sc * sd, se), k_full], dim=2)
|
||||
v_full = torch.cat([special_v.reshape(sa, sb, sc * sd, se), v_full], dim=2)
|
||||
|
||||
q_seq_len = q.shape[2]
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k_full, v_full,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
)
|
||||
x = x.transpose(1, 2).reshape(B, q_seq_len, self.num_heads * self.head_dim)
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
    def _apply_kv_cache_eviction(self, kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens):
        """Apply sliding window eviction to KV cache.

        Layout assumption (NOTE(review): confirm against caller): cache entries
        ``kv_cache[f"k_{global_idx}"]`` / ``kv_cache[f"v_{global_idx}"]`` are
        5-D tensors [B, H, num_frames, tokens_per_frame, head_dim].

        Policy:
          * The first ``kv_cache_scale_frames`` frames (the "scale" prefix) and
            the most recent ``kv_cache_sliding_window`` frames are kept.
          * Frames between those two regions are evicted.
          * If ``kv_cache_cross_frame_special`` is set, the evicted frames'
            "special" tokens (camera token only, or camera..scale span) are
            preserved by appending them to a separate ``*_special`` cache entry.

        Args:
            kv_cache: dict mapping "k_{idx}"/"v_{idx}" (and optionally
                "k_{idx}_special"/"v_{idx}_special") to cached tensors; mutated
                in place.
            global_idx: block index selecting which cache entries to evict.
            camera_token_idx: per-frame index of the camera token.
            scale_token_idx: per-frame index of the scale token (end of the
                special-token span when not camera-only).
            num_register_tokens: unused here — NOTE(review): kept for API
                symmetry with callers; the special span is derived from
                camera_token_idx/scale_token_idx instead.
        """
        sliding_window_frames = self.kv_cache_sliding_window
        scale_frames = self.kv_cache_scale_frames

        # shape[3] is tokens-per-frame; a size-1 token dim means there is
        # nothing frame-structured to evict.
        if kv_cache[f"k_{global_idx}"].shape[3] > 1:
            num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
            if num_cached_frames > sliding_window_frames + scale_frames:
                # Evict the middle region: after the scale prefix, before the
                # most recent sliding window.
                evict_start = scale_frames
                evict_end = num_cached_frames - sliding_window_frames
                if evict_end > evict_start:
                    evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
                    evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]

                    if self.kv_cache_cross_frame_special:
                        # Preserve the evicted frames' special tokens so later
                        # queries can still attend to them.
                        if self.kv_cache_camera_only:
                            # Keep only the camera token per evicted frame.
                            new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                        else:
                            # Keep the full camera..scale special-token span.
                            new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()

                        if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
                            kv_cache[f"k_{global_idx}_special"] = new_special_k
                            kv_cache[f"v_{global_idx}_special"] = new_special_v
                        else:
                            # Append along the frame dim (dim=2).
                            kv_cache[f"k_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
                            kv_cache[f"v_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)

                    if self.kv_cache_include_scale_frames:
                        # Keep scale prefix + most recent window.
                        kv_cache[f"k_{global_idx}"] = torch.cat([
                            kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
                            kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                        ], dim=2)
                        kv_cache[f"v_{global_idx}"] = torch.cat([
                            kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
                            kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                        ], dim=2)
                    else:
                        # Keep only the most recent window.
                        kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                        kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
||||
514
lingbot_map/layers/block.py
Normal file
514
lingbot_map/layers/block.py
Normal file
@@ -0,0 +1,514 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, List, Any, Tuple, Dict
|
||||
import warnings
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention
|
||||
from functools import lru_cache, partial
|
||||
from torch.nn.attention.flex_attention import BlockMask, create_mask
|
||||
from .drop_path import DropPath
|
||||
from .layer_scale import LayerScale
|
||||
from .mlp import Mlp
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Pre-norm Transformer block.

    Computes ``x + DropPath(LS(Attn(LN(x))))`` followed by
    ``x + DropPath(LS(MLP(LN(x))))``. During training with
    ``drop_path > 0.1`` the DropPath modules are bypassed in favour of
    batch-level stochastic depth (``drop_add_residual_stochastic_depth``),
    which runs the residual branch on a random subset of the batch and
    rescales — cheaper at high drop rates.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        # LayerScale is only enabled when init_values is truthy.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Also used as the batch-level stochastic-depth rate (see forward).
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
        """Run the block; extra kwargs are forwarded verbatim to the attention
        module (RoPE positions, Ulysses context-parallel flags, frame layout)."""
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
                                      num_patches=num_patches, num_special=num_special, num_frames=num_frames,
                                      enable_3d_rope=enable_3d_rope))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
            x = drop_add_residual_stochastic_depth(
                x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            # Fixed: the FFN residual previously reused drop_path1 (old FIXME).
            # drop_path1 and drop_path2 share the same drop rate, so this is
            # distribution-identical, but drop_path2 is the intended module.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth(
    x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
) -> Tensor:
    """Batch-level stochastic depth.

    Runs ``residual_func`` on a random subset of the batch and adds the
    resulting residual back (scaled by batch/subset) so the expected update
    matches running the branch on every sample.
    """
    # 1) pick a random subset of samples via a permutation
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = (torch.randperm(batch, device=x.device))[:keep]
    subset = x[kept_idx]

    # 2) evaluate the residual branch on the subset (slicing pos to match)
    if pos is None:
        delta = residual_func(subset)
    else:
        delta = residual_func(subset, pos=pos[kept_idx])

    # 3) scatter-add the scaled residual back into the kept rows
    compensation = batch / keep
    flat = torch.index_add(
        x.flatten(1), 0, kept_idx, delta.flatten(1).to(dtype=x.dtype), alpha=compensation
    )
    return flat.view_as(x)
|
||||
|
||||
|
||||
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Sample the batch indices kept under stochastic depth, plus the
    compensating residual scale factor (batch_size / kept_size)."""
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_idx = (torch.randperm(batch, device=x.device))[:keep]
    return kept_idx, batch / keep
|
||||
|
||||
|
||||
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a residual back into the rows of ``x`` selected by ``brange``,
    scaled by ``residual_scale_factor`` (stochastic-depth compensation).

    Args:
        x: [B, N, D] input.
        brange: 1-D LongTensor of kept batch indices.
        residual: [len(brange), N, D] residual branch output.
        residual_scale_factor: scalar alpha compensating for dropped samples.
        scaling_vector: optional per-channel [D] scale (LayerScale gamma).

    Returns:
        Without scaling_vector: the flattened [B, N*D] result (callers view it
        back). With scaling_vector: an x-shaped tensor.

    Note: the original code called ``scaled_index_add`` in the
    scaling_vector branch — an xformers fused kernel that is never imported
    in this module, so that path raised NameError. It is replaced here by
    the equivalent plain-torch computation:
    ``out[brange] += alpha * (residual * scaling_vector)``.
    """
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        # Per-channel scale broadcasts over the last dim; out-of-place, like
        # the fused kernel it replaces.
        scaled = residual.to(dtype=x.dtype) * scaling_vector
        x_plus_residual = torch.index_add(x, 0, brange, scaled, alpha=residual_scale_factor)
    return x_plus_residual
|
||||
|
||||
|
||||
class FlashInferBlock(nn.Module):
    """
    FlashInfer variant of causal block for GCT.
    Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels).
    Supports optimized token layout and KV cache streaming inference.

    Two execution modes in forward():
      * streaming (kv_cache is a FlashInfer cache manager, <=1 frame): the
        attention is decomposed into compilable pre/post pieces around eager
        paged-cache calls;
      * non-streaming: plain pre-norm residual block delegating everything
        to FlashInferAttention.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        """Build norm/attention/FFN submodules; kv_cache_* parameters are
        forwarded to FlashInferAttention and re-read from it at eviction time."""
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = FlashInferAttention(
            dim=dim,
            num_heads=num_heads,
            qk_norm=qk_norm,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )

        # LayerScale/DropPath only when enabled; otherwise identity.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
        """Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit.

        Extracted as a named method so torch.compile can capture norm1 + qkv-linear +
        reshape + q_norm + k_norm + RoPE + format as a single CUDA graph.

        Returns:
            (q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim],
            ready for manager.append_frame + manager.compute_attention.
        """
        return self.attn.prepare_qkv(self.norm1(x), pos=pos, enable_3d_rope=enable_3d_rope)

    def forward(
        self,
        x: Tensor,
        pos=None,
        enable_ulysses_cp=False,
        num_patches=None,
        num_special=None,
        num_frames=None,
        enable_3d_rope=False,
        kv_cache=None,
        global_idx=0,
        num_frame_per_block=1,
        num_frame_for_scale=-1,
        num_register_tokens=4,
    ) -> Tensor:
        """Run the block.

        When ``kv_cache`` is a cache manager and at most one frame is being
        processed, takes the streaming fast path; otherwise delegates to
        FlashInferAttention with all KV-cache kwargs passed through.
        """
        # Phase 2 (streaming): single-frame FlashInfer paged attention.
        # Handle inline so attn_pre (norm1+prepare_qkv) can be compiled as one CUDA graph.
        is_streaming = (kv_cache is not None and (num_frames is None or num_frames <= 1))
        if is_streaming:
            manager = kv_cache
            # Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format
            q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope)
            # Eager: write frame K/V to paged cache
            manager.append_frame(global_idx, k_nhd, v_nhd)
            # CPU-only: update eviction state (deque ops, no GPU kernel)
            manager.evict_frames(
                block_idx=global_idx,
                scale_frames=self.attn.kv_cache_scale_frames,
                sliding_window=self.attn.kv_cache_sliding_window,
                cross_frame_special=self.attn.kv_cache_cross_frame_special,
                include_scale_frames=self.attn.kv_cache_include_scale_frames,
                camera_only=self.attn.kv_cache_camera_only,
                num_register_tokens=num_register_tokens,
            )
            # Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper
            attn_x = manager.compute_attention(global_idx, q_nhd)
            # [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output)
            attn_x = attn_x.reshape(x.shape[0], q_nhd.shape[0],
                                    self.attn.num_heads * self.attn.head_dim)
            # Compiled: output projection
            # NOTE(review): the streaming path applies self.attn.proj but not
            # self.attn.proj_drop — presumably because streaming is inference
            # only; confirm if this block is ever trained in streaming mode.
            attn_x = self.attn.proj(attn_x)
            x = x + self.ls1(attn_x)
        else:
            # Phase 1 (multi-frame scale pass) or non-streaming training path
            x = x + self.ls1(self.attn(
                self.norm1(x),
                pos=pos,
                enable_ulysses_cp=enable_ulysses_cp,
                num_patches=num_patches,
                num_special=num_special,
                num_frames=num_frames,
                enable_3d_rope=enable_3d_rope,
                kv_cache=kv_cache,
                global_idx=global_idx,
                num_frame_per_block=num_frame_per_block,
                num_frame_for_scale=num_frame_for_scale,
                num_register_tokens=num_register_tokens,
            ))
        x = self.ffn_residual(x)
        return x

    def ffn_residual(self, x: Tensor) -> Tensor:
        """FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in.

        Includes the residual add (x + ...) so torch.compile captures the entire
        ffn branch as one CUDA graph.
        """
        return x + self.ls2(self.mlp(self.norm2(x)))
|
||||
|
||||
|
||||
class CameraBlock(nn.Module):
    """Causal Transformer block built around CausalAttention.

    Uses a block-wise causal attention mask over latent frames (built with
    torch flex-attention's ``create_mask``), with optional sliding window,
    scale-frame attendance, and KV-cache streaming support.

    Improvement over the original: the dense block-causal mask is
    deterministic for a given (num_frames, frame_seqlen, num_frame_per_block,
    device) tuple, yet was rebuilt on every forward pass while the
    pre-declared ``self.masks`` dict sat unused. The mask is now cached in
    ``self.masks`` keyed by that tuple, so each distinct configuration is
    built once.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = CausalAttention(dim=dim, num_heads=num_heads,
                                    qk_norm=qk_norm, qkv_bias=qkv_bias,
                                    rope=rope, elementwise_attn_output_gate=elementwise_attn_output_gate,
                                    kv_cache_sliding_window=kv_cache_sliding_window,
                                    kv_cache_scale_frames=kv_cache_scale_frames,
                                    kv_cache_cross_frame_special=kv_cache_cross_frame_special,
                                    kv_cache_include_scale_frames=kv_cache_include_scale_frames,
                                    kv_cache_camera_only=kv_cache_camera_only)

        self.sliding_window_size = sliding_window_size
        self.attend_to_scale_frames = attend_to_scale_frames
        self.num_random_frames = num_random_frames

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path
        # Cache of block-causal masks keyed by
        # (num_frames, frame_seqlen, num_frame_per_block, device_str).
        self.masks = {}

    @torch.no_grad()
    def _prepare_blockwise_causal_attn_mask(self,
                                            device: torch.device | str, num_frames: int = 21,
                                            frame_seqlen: int = 1560, num_frame_per_block=1
                                            ) -> BlockMask:
        """
        we will divide the token sequence into the following format
        [1 latent frame] [1 latent frame] ... [1 latent frame]
        We use flexattention to construct the attention mask
        """
        total_length = num_frames * frame_seqlen

        # we do right padding to get to a multiple of 128
        padded_length = math.ceil(total_length / 128) * 128 - total_length

        ends = torch.zeros(total_length + padded_length,
                           device=device, dtype=torch.long)

        # Block-wise causal mask will attend to all elements that are before the end of the current chunk
        frame_indices = torch.arange(
            start=0,
            end=total_length,
            step=frame_seqlen * num_frame_per_block,
            device=device
        )

        for tmp in frame_indices:
            ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
                frame_seqlen * num_frame_per_block

        def attention_mask(b, h, q_idx, kv_idx):
            return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
            # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask

        block_mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
                                 KV_LEN=total_length + padded_length, device=device)

        return block_mask

    def forward(self, x: Tensor, pos=None, video_mask=None, num_frames=0, frame_seqlen=0, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=8, num_frame_for_scale=-1, sliding_window_size=None, enable_ulysses_cp=False, full_attention=False, enable_3d_rope=False, is_scale_frames=False) -> Tensor:
        """Run the block; ``full_attention=True`` skips mask construction and
        attends bidirectionally (camera-head path)."""
        # Use passed sliding_window_size if provided, otherwise use self.sliding_window_size
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size

        # Fast path for full attention (camera head) - skip mask computation
        if full_attention:
            def attn_residual_func(x: Tensor, pos=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), pos=pos, full_attention=True, enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope))

            def ffn_residual_func(x: Tensor) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            if self.training and self.sample_drop_ratio > 0.0:
                x = x + self.drop_path1(attn_residual_func(x, pos=pos))
                x = x + self.drop_path1(ffn_residual_func(x))
            else:
                x = x + attn_residual_func(x, pos=pos)
                x = x + ffn_residual_func(x)
            return x

        # The mask depends only on (num_frames, frame_seqlen,
        # num_frame_per_block, device); build it once per configuration and
        # reuse, instead of rebuilding an O(T^2) dense mask every call.
        mask_key = (num_frames, frame_seqlen, num_frame_per_block, str(x.device))
        if mask_key not in self.masks:
            self.masks[mask_key] = self._prepare_blockwise_causal_attn_mask(
                device=x.device, num_frames=num_frames, frame_seqlen=frame_seqlen,
                num_frame_per_block=num_frame_per_block)
        mask_block = self.masks[mask_key]

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos, block_mask=mask_block, frame_seqlen=frame_seqlen, video_mask=video_mask, current_start=current_start, current_end=current_end, kv_cache=kv_cache, global_idx=global_idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size, attend_to_scale_frames=self.attend_to_scale_frames, num_random_frames=self.num_random_frames,
                                      enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope, is_scale_frames=is_scale_frames))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
|
||||
|
||||
|
||||
class SDPABlock(nn.Module):
    """
    SDPA variant for streaming inference. Uses F.scaled_dot_product_attention
    with dict-based KV cache. No FlashInfer dependency required.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()
        # Attention branch: pre-norm -> SDPA attention -> LayerScale -> DropPath.
        self.norm1 = norm_layer(dim)
        self.attn = SDPAAttention(
            dim=dim, num_heads=num_heads, qk_norm=qk_norm, qkv_bias=qkv_bias,
            proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        # FFN branch: pre-norm -> MLP -> LayerScale -> DropPath.
        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = ffn_layer(in_features=dim, hidden_features=hidden_features,
                             act_layer=act_layer, drop=drop, bias=ffn_bias)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
                kv_cache=None, global_idx=0, num_frame_per_block=1,
                num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
        """Residual block: attention then FFN, each pre-normed; all KV-cache
        and layout kwargs are forwarded verbatim to SDPAAttention."""
        def _attn_branch(inp, pos=None):
            return self.ls1(self.attn(
                self.norm1(inp), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
                num_patches=num_patches, num_special=num_special, num_frames=num_frames,
                enable_3d_rope=enable_3d_rope, kv_cache=kv_cache, global_idx=global_idx,
                num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale,
                num_register_tokens=num_register_tokens,
            ))

        def _ffn_branch(inp):
            return self.ls2(self.mlp(self.norm2(inp)))

        # DropPath applies only in training with a positive drop-path rate.
        stochastic = self.training and self.sample_drop_ratio > 0.0

        attn_out = _attn_branch(x, pos=pos)
        x = x + (self.drop_path1(attn_out) if stochastic else attn_out)
        ffn_out = _ffn_branch(x)
        x = x + (self.drop_path1(ffn_out) if stochastic else ffn_out)
        return x
|
||||
34
lingbot_map/layers/drop_path.py
Normal file
34
lingbot_map/layers/drop_path.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
||||
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Stochastic depth: randomly zero whole samples of ``x`` with probability
    ``drop_prob``, rescaling survivors by 1/keep_prob so the expectation is
    unchanged. A no-op at eval time or when drop_prob is 0."""
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        keep_mask.div_(keep_prob)
    return x * keep_mask
|
||||
|
||||
|
||||
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        # nn.Module bookkeeping first; the only state is the drop probability.
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional form; self.training gates the behavior.
        return drop_path(x, self.drop_prob, self.training)
|
||||
582
lingbot_map/layers/flashinfer_cache.py
Normal file
582
lingbot_map/layers/flashinfer_cache.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""
|
||||
FlashInfer KV Cache Manager — Two-Stream Paged Design.
|
||||
|
||||
Two logical streams sharing one physical page pool per layer:
|
||||
|
||||
Patch stream (recyclable):
|
||||
- page_size = patches_per_frame (256 for 224×224; 972 for 504×378)
|
||||
- Exactly 1 patch page per frame
|
||||
- Scale frames → scale_patch_pages (never evicted, maxlen=scale_frames)
|
||||
- Recent frames → live_window_patch_pages (evicted when > sliding_window)
|
||||
|
||||
Special stream (append-only, never recycled):
|
||||
- num_special_tokens (6) special tokens per frame
|
||||
- Packed continuously: one special page holds floor(page_size/6) frames
|
||||
e.g. page_size=256 → 42 frames per special page, 4 slots wasted
|
||||
- Specials written for EVERY frame (including scale + window), not just evicted ones.
|
||||
|
||||
Physical layout per block:
|
||||
kv_caches[block_idx]: [max_num_pages, 2, page_size, H, D]
|
||||
Pages 0 .. max_patch_pages-1 : patch page pool (recyclable)
|
||||
Pages max_patch_pages .. max_pages-1: special page pool (append-only)
|
||||
dim 1: 0=K 1=V
|
||||
|
||||
Attention computation:
|
||||
visible = scale_patch_pages + live_window_patch_pages + all_special_pages
|
||||
Special pages placed LAST → paged_kv_last_page_len naturally describes
|
||||
the partial special-tail without a custom mask.
|
||||
|
||||
plan() is called ONCE per frame step (when block_idx == 0).
|
||||
run() is called per layer, reusing the same plan. All layers at the
|
||||
same frame step have identical page structures (same page IDs in same
|
||||
positions), so reusing the plan across layers is correct.
|
||||
|
||||
Public API is drop-in compatible with the previous FlashInferKVCacheManager:
|
||||
append_frame(block_idx, k, v)
|
||||
evict_frames(block_idx, scale_frames, sliding_window, ...)
|
||||
compute_attention(block_idx, q) -> out
|
||||
reset()
|
||||
"""
|
||||
|
||||
import collections
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
try:
|
||||
import flashinfer
|
||||
FLASHINFER_AVAILABLE = True
|
||||
except ImportError:
|
||||
FLASHINFER_AVAILABLE = False
|
||||
|
||||
|
||||
class FlashInferKVCacheManager:
|
||||
"""
|
||||
Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only).
|
||||
|
||||
Args:
|
||||
num_blocks: Number of Transformer blocks (one cache per block).
|
||||
max_num_frames: Maximum frames held in the KV window at once
|
||||
(scale_frames + sliding_window + headroom).
|
||||
tokens_per_frame: Total tokens per frame = patches + specials (e.g. 262).
|
||||
num_heads: Number of KV heads (= QO heads; MHA assumed).
|
||||
head_dim: Head dimension (64 for ViT-L).
|
||||
dtype: Storage dtype (bfloat16 / float16).
|
||||
device: CUDA device.
|
||||
num_special_tokens: Special tokens per frame: camera + register×N + scale (6).
|
||||
scale_frames: Number of always-resident scale frames (8).
|
||||
sliding_window: Sliding window size (64).
|
||||
max_total_frames: Upper bound on total frames ever processed; used to
|
||||
pre-allocate the special page pool (default 2048).
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        num_blocks: int,
        max_num_frames: int,
        tokens_per_frame: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        num_special_tokens: int = 6,
        scale_frames: int = 8,
        sliding_window: int = 64,
        max_total_frames: int = 2048,
        force_fp32: bool = False,
        fa3: bool = False,
    ):
        """Allocate per-block paged KV pools, free lists, and the FlashInfer
        prefill wrapper. See the class docstring for parameter meanings.

        Raises:
            RuntimeError: if the flashinfer package is not importable.
        """
        if not FLASHINFER_AVAILABLE:
            raise RuntimeError("FlashInfer is not available. Please install flashinfer.")

        self.num_blocks = num_blocks
        self.num_special_tokens = num_special_tokens  # 6
        self.patches_per_frame = tokens_per_frame - num_special_tokens  # 256 / 999 / ...
        # Use exact page_size = patches_per_frame to eliminate zero-padded slots.
        # FA2 (backend="fa2") supports non-power-of-2 page sizes.
        # FA3 (sm90) requires power-of-2 page sizes; use next_power_of_2 when fa3=True.
        p = self.patches_per_frame
        if fa3:
            # Round up to next power-of-2 for FA3 SM90 kernel requirement.
            # e.g. 999 → 1024 (25 zero-padded slots per patch page)
            self.page_size = 1 << (p - 1).bit_length()
        else:
            self.page_size = p  # exact: no zero padding in patch pages
        self.scale_frames = scale_frames  # 8
        self.sliding_window = sliding_window  # 64
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.tokens_per_frame = tokens_per_frame

        assert self.patches_per_frame > 0, (
            f"tokens_per_frame={tokens_per_frame} <= num_special_tokens={num_special_tokens}"
        )
        assert self.page_size > 0

        # force_fp32: bypass FlashInfer FA2 kernel (which only supports fp16/bf16) and
        # instead gather paged K/V into a dense tensor and use F.scaled_dot_product_attention
        # in fp32 for accuracy comparison. Storage dtype is also kept as fp32 in this mode.
        self.force_fp32 = force_fp32
        if force_fp32:
            self.dtype = torch.float32
        else:
            if dtype == torch.float32:
                dtype = torch.bfloat16
            self.dtype = dtype
        self.device = device

        # ── Page pool sizing ─────────────────────────────────────────────────
        # Patch: scale + window + 16 headroom (pages recycled → fixed count)
        max_patch_pages = scale_frames + sliding_window + 16  # e.g. 88
        # Special: enough for max_total_frames × 6 tokens, plus 16 headroom
        max_special_pages = (
            math.ceil(max_total_frames * num_special_tokens / self.page_size) + 16
        )
        self.max_patch_pages = max_patch_pages
        self.max_num_pages = max_patch_pages + max_special_pages

        # ── Physical paged KV caches ─────────────────────────────────────────
        # Shape per block: [max_num_pages, 2, page_size, H, D] (NHD, K=dim0, V=dim1)
        # NOTE(review): allocation uses the LOCAL `dtype`, which in force_fp32
        # mode may still be the caller's fp16/bf16 rather than self.dtype
        # (float32) — confirm whether storage should use self.dtype here.
        self.kv_caches: List[Tensor] = [
            torch.zeros(
                self.max_num_pages, 2, self.page_size, num_heads, head_dim,
                dtype=dtype, device=device,
            )
            for _ in range(num_blocks)
        ]

        # ── Per-block state ──────────────────────────────────────────────────
        # Patch pages (IDs 0 .. max_patch_pages-1)
        self.scale_patch_pages: List[collections.deque] = [
            collections.deque() for _ in range(num_blocks)
        ]
        self.live_window_patch_pages: List[collections.deque] = [
            collections.deque() for _ in range(num_blocks)
        ]
        self.free_patch_pages: List[List[int]] = [
            list(range(max_patch_pages)) for _ in range(num_blocks)
        ]

        # Special pages (IDs max_patch_pages .. max_num_pages-1)
        self.all_special_pages: List[List[int]] = [[] for _ in range(num_blocks)]
        self.free_special_pages: List[List[int]] = [
            list(range(max_patch_pages, self.max_num_pages)) for _ in range(num_blocks)
        ]
        # Running count of special tokens written per block (drives packing).
        self.special_token_count: List[int] = [0] * num_blocks

        # Frame counter per block (determines scale vs window routing)
        self.frame_count: List[int] = [0] * num_blocks

        # ── FlashInfer wrapper ───────────────────────────────────────────────
        # plan() is called once per frame step (block_idx == 0).
        # run() is called per layer, reusing the same aux structures.
        # backend: "fa2" (default) or "fa3" (SM90/H100, requires power-of-2 page_size).
        # FA2 supports non-power-of-2 page sizes and avoids a FA3 NaN bug seen in
        # FlashInfer 0.2.5 at 518×378 resolution.
        _fi_backend = "fa3" if fa3 else "fa2"
        self.workspace_buffer = torch.zeros(
            128 * 1024 * 1024, dtype=torch.uint8, device=device
        )
        self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer,
            kv_layout="NHD",
            backend=_fi_backend,
        )

        # plan() inputs (indices/indptr built fresh each step; qo_indptr is fixed)
        self._qo_indptr = torch.tensor(
            [0, tokens_per_frame], dtype=torch.int32, device=device
        )
|
||||
|
||||
# =========================================================================
|
||||
# Public API (drop-in compatible with previous FlashInferKVCacheManager)
|
||||
# =========================================================================
|
||||
|
||||
def append_frame(self, block_idx: int, k: Tensor, v: Tensor) -> None:
|
||||
"""
|
||||
Append one frame's K/V tensors to the two-stream cache.
|
||||
|
||||
Token layout must be: [camera, reg0, ..., regN, scale, patch0, ..., patchP-1]
|
||||
i.e. specials come first (matching stream.py's patch_start_idx convention).
|
||||
|
||||
Args:
|
||||
block_idx: Block/layer index (0 … num_blocks-1).
|
||||
k: [tokens_per_frame, H, D] NHD layout.
|
||||
v: [tokens_per_frame, H, D] NHD layout.
|
||||
"""
|
||||
n = self.num_special_tokens # 6
|
||||
sp_k = k[:n].to(self.dtype) # [6, H, D]
|
||||
patch_k = k[n:].to(self.dtype) # [256, H, D]
|
||||
sp_v = v[:n].to(self.dtype)
|
||||
patch_v = v[n:].to(self.dtype)
|
||||
|
||||
assert patch_k.shape[0] == self.patches_per_frame, (
|
||||
f"block {block_idx}: expected {self.patches_per_frame} patch tokens, "
|
||||
f"got {patch_k.shape[0]} (tokens_per_frame={k.shape[0]})"
|
||||
)
|
||||
|
||||
self._write_patch_page(block_idx, patch_k, patch_v)
|
||||
self._write_special_tokens(block_idx, sp_k, sp_v)
|
||||
self.frame_count[block_idx] += 1
|
||||
|
||||
def evict_frames(
|
||||
self,
|
||||
block_idx: int,
|
||||
scale_frames: int,
|
||||
sliding_window: int,
|
||||
cross_frame_special: bool = True,
|
||||
include_scale_frames: bool = True,
|
||||
camera_only: bool = False,
|
||||
num_register_tokens: int = 4,
|
||||
) -> None:
|
||||
"""
|
||||
Evict old window patch pages (recycle to free list).
|
||||
|
||||
Special pages are NEVER evicted.
|
||||
Scale pages are NEVER evicted.
|
||||
Only live_window_patch_pages beyond `sliding_window` are recycled.
|
||||
"""
|
||||
while len(self.live_window_patch_pages[block_idx]) > sliding_window:
|
||||
old_page = self.live_window_patch_pages[block_idx].popleft()
|
||||
self.free_patch_pages[block_idx].append(old_page)
|
||||
|
||||
def _gather_kv(self, block_idx: int):
|
||||
"""
|
||||
Gather all visible K and V tokens from the paged cache into dense tensors.
|
||||
|
||||
Used by force_fp32 mode to bypass the FlashInfer FA2 kernel (which only
|
||||
supports fp16/bf16) and instead run F.scaled_dot_product_attention in fp32.
|
||||
|
||||
Returns:
|
||||
k_flat: [kv_len, H, D] — all visible K tokens concatenated
|
||||
v_flat: [kv_len, H, D] — all visible V tokens concatenated
|
||||
"""
|
||||
visible = self.build_visible_page_table(block_idx)
|
||||
last_len = self.compute_last_page_len(block_idx)
|
||||
P = self.page_size
|
||||
|
||||
parts_k, parts_v = [], []
|
||||
for i, pid in enumerate(visible):
|
||||
n = last_len if (i == len(visible) - 1) else P
|
||||
parts_k.append(self.kv_caches[block_idx][pid, 0, :n]) # [n, H, D]
|
||||
parts_v.append(self.kv_caches[block_idx][pid, 1, :n])
|
||||
|
||||
k_flat = torch.cat(parts_k, dim=0) # [kv_len, H, D]
|
||||
v_flat = torch.cat(parts_v, dim=0)
|
||||
return k_flat, v_flat
|
||||
|
||||
    def compute_attention(self, block_idx: int, q: Tensor) -> Tensor:
        """
        Compute cross-frame attention using FlashInfer BatchPrefillWithPagedKVCacheWrapper.

        When self.force_fp32 is True, gathers all visible K/V into dense tensors
        and uses F.scaled_dot_product_attention in fp32 instead of the FA2 kernel.
        This is used for accuracy comparison since FlashInfer FA2 only supports fp16/bf16.

        plan() is called once per frame step (when block_idx == 0).
        All layers at the same step share the same visible page structure,
        so the plan is reused by calling run() with each layer's kv_cache.
        NOTE(review): this assumes callers always invoke block 0 first at each
        step — verify against the calling loop in stream.py.

        Args:
            block_idx: Block/layer index.
            q: [q_len, H, D] NHD layout (q_len = tokens_per_frame = 262).

        Returns:
            out: [q_len, H, D]
        """
        if self.frame_count[block_idx] == 0:
            # No KV present yet (should not occur in normal usage after append_frame)
            return torch.zeros_like(q)

        if self.force_fp32:
            # ── fp32 gather+SDPA path ─────────────────────────────────────────
            # Gather visible K/V from paged cache and run SDPA in fp32.
            # This bypasses the FlashInfer FA2 kernel (fp16/bf16 only) for accuracy.
            # q_len, H, D → 1, H, q_len, D (SDPA expects BHsD layout)
            import torch.nn.functional as F_nn
            k_flat, v_flat = self._gather_kv(block_idx)
            q_b = q.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, q_len, D]
            k_b = k_flat.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, kv_len, D]
            v_b = v_flat.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, kv_len, D]
            # No mask: the visible page table already encodes which tokens attend.
            out = F_nn.scaled_dot_product_attention(q_b, k_b, v_b)
            # Cast back to the query's incoming dtype so callers see no change.
            return out.squeeze(0).permute(1, 0, 2).to(q.dtype)  # [q_len, H, D]

        if block_idx == 0:
            # ── Plan once per frame step ──────────────────────────────────────
            # Build visible page table from block 0's state.
            # All blocks have identical page structures, so this plan is valid
            # for all subsequent run() calls (block_idx = 1, 2, ...).
            visible = self.build_visible_page_table(0)
            last_len = self.compute_last_page_len(0)

            assert visible, "visible page table is empty after append_frame"
            assert 1 <= last_len <= self.page_size, (
                f"block 0: last_page_len={last_len} out of [1, {self.page_size}]"
            )

            # Single-request batch: one indptr segment covering all visible pages.
            paged_kv_indices = torch.tensor(visible, dtype=torch.int32, device=self.device)
            paged_kv_indptr = torch.tensor([0, len(visible)], dtype=torch.int32, device=self.device)
            paged_kv_last_page_len = torch.tensor([last_len], dtype=torch.int32, device=self.device)

            self.prefill_wrapper.plan(
                self._qo_indptr,
                paged_kv_indptr,
                paged_kv_indices,
                paged_kv_last_page_len,
                num_qo_heads = self.num_heads,
                num_kv_heads = self.num_heads,
                head_dim_qk = self.head_dim,
                page_size = self.page_size,
                causal = False,  # custom page ordering; no causal mask
                pos_encoding_mode = "NONE",  # RoPE applied externally before append
                q_data_type = self.dtype,
            )

        # ── Run attention for this layer ──────────────────────────────────────
        # Cast q to storage dtype (LayerNorm may upcast to float32 under autocast).
        return self.prefill_wrapper.run(
            q = q.to(self.dtype).contiguous(),
            paged_kv_cache = self.kv_caches[block_idx],
        )  # → [q_len, H, D]
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all per-block state for a new sequence."""
|
||||
for i in range(self.num_blocks):
|
||||
self.scale_patch_pages[i].clear()
|
||||
self.live_window_patch_pages[i].clear()
|
||||
self.all_special_pages[i].clear()
|
||||
self.free_patch_pages[i] = list(range(self.max_patch_pages))
|
||||
self.free_special_pages[i] = list(range(self.max_patch_pages, self.max_num_pages))
|
||||
self.special_token_count[i] = 0
|
||||
self.frame_count[i] = 0
|
||||
|
||||
# =========================================================================
|
||||
# Helper methods
|
||||
# =========================================================================
|
||||
|
||||
def build_visible_page_table(self, block_idx: int) -> List[int]:
|
||||
"""
|
||||
Return page IDs in strict order: scale → window → special.
|
||||
|
||||
Placing special pages last means only the final page may be partially
|
||||
full, so paged_kv_last_page_len = compute_last_page_len() is sufficient
|
||||
without a custom attention mask.
|
||||
"""
|
||||
return (
|
||||
list(self.scale_patch_pages[block_idx]) +
|
||||
list(self.live_window_patch_pages[block_idx]) +
|
||||
list(self.all_special_pages[block_idx])
|
||||
)
|
||||
|
||||
def compute_last_page_len(self, block_idx: int) -> int:
|
||||
"""
|
||||
Valid token count in the last page of the visible sequence.
|
||||
|
||||
- No special pages → last page is a patch page.
|
||||
Returns patches_per_frame (real tokens written),
|
||||
which may be < page_size when page_size was rounded
|
||||
up to a power of 2.
|
||||
- Special tail partial → special_token_count % page_size.
|
||||
- Special tail exactly full → page_size.
|
||||
"""
|
||||
if not self.all_special_pages[block_idx]:
|
||||
# Last page is a patch page. We wrote patches_per_frame tokens (0..P-1);
|
||||
# positions P..page_size-1 are zero padding. Tell FlashInfer the true
|
||||
# valid count so it doesn't read beyond the real tokens.
|
||||
return self.patches_per_frame
|
||||
|
||||
tail = self.special_token_count[block_idx] % self.page_size
|
||||
return self.page_size if tail == 0 else tail
|
||||
|
||||
# ── Internal write helpers ────────────────────────────────────────────────
|
||||
|
||||
def _write_patch_page(self, block_idx: int, patch_k: Tensor, patch_v: Tensor) -> int:
|
||||
"""
|
||||
Allocate one free patch page and write patches_per_frame patch tokens.
|
||||
|
||||
Direct tensor assignment to kv_caches[block_idx][page_id, 0/1] avoids
|
||||
the Python→C++/CUDA dispatch overhead of flashinfer.page.append_paged_kv_cache.
|
||||
kv_caches layout: [max_num_pages, 2, page_size, H, D] (NHD, K=0, V=1).
|
||||
patch_k/v fill exactly one full page (patches_per_frame == page_size).
|
||||
|
||||
Routes to scale_patch_pages if still filling scale quota,
|
||||
otherwise to live_window_patch_pages.
|
||||
|
||||
Returns:
|
||||
page_id: Physical page index used.
|
||||
"""
|
||||
assert self.free_patch_pages[block_idx], (
|
||||
f"block {block_idx}: patch page pool exhausted — "
|
||||
f"scale={len(self.scale_patch_pages[block_idx])}, "
|
||||
f"window={len(self.live_window_patch_pages[block_idx])}, "
|
||||
f"free={len(self.free_patch_pages[block_idx])}"
|
||||
)
|
||||
|
||||
page_id = self.free_patch_pages[block_idx].pop()
|
||||
|
||||
# Direct slice write: positions 0..patches_per_frame-1.
|
||||
# When page_size == patches_per_frame (power-of-2 aligned, e.g. 256 for 224×224),
|
||||
# this is equivalent to a full-page write. When page_size > patches_per_frame
|
||||
# (rounded up for FA3 alignment, e.g. page_size=1024 for patches_per_frame=999),
|
||||
# positions patches_per_frame..page_size-1 remain zero (kv_caches is zero-init).
|
||||
P = self.patches_per_frame
|
||||
self.kv_caches[block_idx][page_id, 0, :P] = patch_k # K
|
||||
self.kv_caches[block_idx][page_id, 1, :P] = patch_v # V
|
||||
|
||||
if len(self.scale_patch_pages[block_idx]) < self.scale_frames:
|
||||
self.scale_patch_pages[block_idx].append(page_id)
|
||||
else:
|
||||
self.live_window_patch_pages[block_idx].append(page_id)
|
||||
|
||||
return page_id
|
||||
|
||||
def _write_special_tokens(self, block_idx: int, sp_k: Tensor, sp_v: Tensor) -> None:
|
||||
"""
|
||||
Append num_special_tokens (6) special tokens to the special stream.
|
||||
|
||||
Direct tensor slice assignment to kv_caches[block_idx][tail_page, 0/1,
|
||||
tail_offset : tail_offset+write_n] avoids the Python→C++/CUDA dispatch
|
||||
overhead of flashinfer.page.append_paged_kv_cache.
|
||||
|
||||
Handles page-boundary crossing: if 6 tokens straddle two pages, performs
|
||||
two slice writes (rare — page_size=256 >> 6).
|
||||
"""
|
||||
remaining = self.num_special_tokens # 6
|
||||
written = 0
|
||||
|
||||
while remaining > 0:
|
||||
tail_offset = self.special_token_count[block_idx] % self.page_size
|
||||
|
||||
if tail_offset == 0:
|
||||
# Current tail page is full (or no page exists) — allocate a new one
|
||||
assert self.free_special_pages[block_idx], (
|
||||
f"block {block_idx}: special page pool exhausted at "
|
||||
f"special_token_count={self.special_token_count[block_idx]}. "
|
||||
f"Increase max_total_frames."
|
||||
)
|
||||
new_page = self.free_special_pages[block_idx].pop()
|
||||
self.all_special_pages[block_idx].append(new_page)
|
||||
|
||||
tail_page = self.all_special_pages[block_idx][-1]
|
||||
space = self.page_size - tail_offset # free slots in tail page
|
||||
write_n = min(remaining, space)
|
||||
|
||||
# Direct slice write: kv_caches[block_idx][tail_page, 0/1, offset:offset+n]
|
||||
# shape: [page_size, H, D]; slice [tail_offset:tail_offset+write_n, :, :]
|
||||
end = tail_offset + write_n
|
||||
self.kv_caches[block_idx][tail_page, 0, tail_offset:end] = sp_k[written:written + write_n]
|
||||
self.kv_caches[block_idx][tail_page, 1, tail_offset:end] = sp_v[written:written + write_n]
|
||||
|
||||
self.special_token_count[block_idx] += write_n
|
||||
written += write_n
|
||||
remaining -= write_n
|
||||
|
||||
# ── Legacy property (used by stream.py) ──────────────────────────────────
|
||||
|
||||
@property
|
||||
def num_frames(self) -> int:
|
||||
"""Number of frames appended to block 0 (representative)."""
|
||||
return self.frame_count[0] if self.frame_count else 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Sanity check
|
||||
# =============================================================================
|
||||
|
||||
def _sanity_check():
    """
    Minimal smoke test for the two-stream paged KV cache.

    Appends 100 random frames per block, checks page accounting (scale /
    window / special pools), the last-page-length computation, visible page
    ordering, one plan()+run() round trip, and reset(). Requires CUDA.

    Run with: python -c "from lingbot_map.layers.flashinfer_cache import _sanity_check; _sanity_check()"
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        # FlashInfer kernels require a GPU; bail out quietly on CPU-only hosts.
        print("[sanity_check] CUDA not available — skipping.")
        return

    tokens_per_frame = 262  # 256 patch + 6 special (224×224)
    num_special = 6
    patches_per_frame = tokens_per_frame - num_special  # 256
    page_size = patches_per_frame  # 256

    mgr = FlashInferKVCacheManager(
        num_blocks = 2,
        max_num_frames = 88,
        tokens_per_frame = tokens_per_frame,
        num_heads = 16,
        head_dim = 64,
        dtype = torch.bfloat16,
        device = device,
        num_special_tokens = num_special,
        scale_frames = 8,
        sliding_window = 64,
        max_total_frames = 200,
    )

    # Random K/V pair for one frame, NHD layout [tokens, heads, head_dim].
    def make_kv():
        k = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
        v = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
        return k, v

    # Random query tensor for one frame, same layout as K/V.
    def make_q():
        return torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)

    for block in range(2):
        # Append 100 frames, evicting after each append as the stream would.
        for t in range(100):
            k, v = make_kv()
            mgr.append_frame(block, k, v)
            mgr.evict_frames(block, scale_frames=8, sliding_window=64)

        # ── Page count checks ───────────────────────────────────────────────
        n_scale = len(mgr.scale_patch_pages[block])
        n_window = len(mgr.live_window_patch_pages[block])
        n_spec = len(mgr.all_special_pages[block])
        sp_count = mgr.special_token_count[block]

        assert n_scale == 8, f"block {block}: scale pages = {n_scale}, expected 8"
        assert n_window == 64, f"block {block}: window pages = {n_window}, expected 64"
        # 100 frames × 6 specials = 600 tokens; ceil(600/256) = 3 pages
        expected_spec_pages = math.ceil(100 * num_special / page_size)
        assert n_spec == expected_spec_pages, (
            f"block {block}: special pages = {n_spec}, expected {expected_spec_pages}"
        )
        assert sp_count == 100 * num_special, (
            f"block {block}: special_token_count = {sp_count}, expected {100*num_special}"
        )

        # ── last_page_len ────────────────────────────────────────────────────
        # With special pages present, the tail length is the special-token
        # remainder (or a full page when it divides evenly).
        last_len = mgr.compute_last_page_len(block)
        tail = sp_count % page_size
        expected_len = page_size if tail == 0 else tail
        assert last_len == expected_len, f"block {block}: last_len={last_len}, expected={expected_len}"

        # ── visible page table order ─────────────────────────────────────────
        # Patch pages (IDs < max_patch_pages) must precede special pages.
        visible = mgr.build_visible_page_table(block)
        assert len(visible) == n_scale + n_window + n_spec, "visible page count mismatch"
        for pid in visible[:n_scale + n_window]:
            assert pid < mgr.max_patch_pages, f"patch page {pid} out of patch range"
        for pid in visible[n_scale + n_window:]:
            assert pid >= mgr.max_patch_pages, f"special page {pid} not in special range"

        # ── forward pass: plan() once for block 0, run() for both blocks ─────
        if block == 1:
            # Simulate the actual calling pattern: plan on block 0, run on both
            q0 = make_q()
            out0 = mgr.compute_attention(0, q0)  # triggers plan()
            q1 = make_q()
            out1 = mgr.compute_attention(1, q1)  # reuses plan, different kv_cache
            assert out0.shape == (tokens_per_frame, 16, 64)
            assert out1.shape == (tokens_per_frame, 16, 64)

        print(f"[block {block}] PASS: scale={n_scale}, window={n_window}, "
              f"special_pages={n_spec}, special_tokens={sp_count}, "
              f"last_page_len={last_len}")

    # reset() must return every counter/pool to its initial state.
    mgr.reset()
    assert mgr.frame_count[0] == 0
    print("\n[sanity_check] All assertions passed.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the smoke test when this module is executed directly.
    _sanity_check()
|
||||
22
lingbot_map/layers/layer_scale.py
Normal file
22
lingbot_map/layers/layer_scale.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
    """Per-channel learnable scaling of a residual branch (CaiT-style).

    Multiplies the input by a learned vector ``gamma`` of length ``dim``,
    initialized to ``init_values``. With ``inplace=True`` the input tensor
    is scaled in place.
    """

    def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        scale = torch.ones(dim) * init_values
        self.gamma = nn.Parameter(scale)

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
||||
40
lingbot_map/layers/mlp.py
Normal file
40
lingbot_map/layers/mlp.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
||||
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear → activation → Linear.

    A single dropout module (shared rate) is applied after the activation
    and again after the output projection. Hidden/output widths default to
    ``in_features`` when not given.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
||||
85
lingbot_map/layers/patch_embed.py
Normal file
85
lingbot_map/layers/patch_embed.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_2tuple(x):
    """Return *x* as a pair: a length-2 tuple passes through, an int is duplicated."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size (int or (H, W)).
        patch_size: Patch token size (int or (pH, pW)).
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer; defaults to identity when None.
        flatten_embedding: If False, return (B,H',W',D) instead of (B,N,D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patchify: stride == kernel == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of images; see class docstring for output shapes."""
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Approximate FLOP count for one forward pass."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # Bug fix: self.norm is never None — it defaults to nn.Identity() —
        # so the previous `is not None` check always added norm FLOPs even
        # when no normalization is applied. Count them only for a real norm.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
||||
474
lingbot_map/layers/rope.py
Normal file
474
lingbot_map/layers/rope.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# Implementation of 2D Rotary Position Embeddings (RoPE).
|
||||
|
||||
# This module provides a clean implementation of 2D Rotary Position Embeddings,
|
||||
# which extends the original RoPE concept to handle 2D spatial positions.
|
||||
|
||||
# Inspired by:
|
||||
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
||||
# https://github.com/naver-ai/rope-vit
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
||||
class PositionGetter:
    """Generates and caches 2D patch coordinates for a grid.

    The (height*width, 2) coordinate tensor is computed once per
    (height, width) pair and cached; calls then expand it across the batch
    and return a fresh copy.
    """

    def __init__(self):
        """Start with an empty (height, width) → positions cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Return (batch_size, height*width, 2) y,x coordinates for the grid.

        Args:
            batch_size: Number of samples in the batch.
            height: Grid height in patches.
            width: Grid width in patches.
            device: Device used when the grid is first materialized.
        """
        key = (height, width)
        if key not in self.position_cache:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            # Row-major (y, x) pairs covering the whole grid.
            self.position_cache[key] = torch.cartesian_prod(rows, cols)

        grid = self.position_cache[key]
        # .clone() so callers can mutate their copy without corrupting the cache.
        return grid.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on their
    2D spatial positions. The feature dimension is split in half: the first half
    is rotated by the y (vertical) coordinate, the second half by the x
    (horizontal) coordinate.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0
            (NOTE(review): stored but not used anywhere in this class — confirm
            whether it was meant to scale `inv_freq`.)

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor to scale the computed frequencies.
        frequency_cache: Cache for precomputed (cos, sin) components keyed by
            (dim, seq_len, device, dtype).
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes (and caches) cos/sin frequency tables for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length (one row per position).
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors, each of shape [seq_len, dim].
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands: inv_freq_i = base^( -2i/dim )
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent angles: angles[m, i] = m * inv_freq_i
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Duplicate along the feature axis so cos/sin cover the full dim
            # (matches the half-split layout used by _rotate_features).
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Rotates features: (x1, x2) -> (-x2, x1) on the half-split layout.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor of the same shape.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one spatial dimension.

        Args:
            tokens: Input token features, (batch, heads, n_tokens, dim).
            positions: Integer position indices, (batch, n_tokens).
            cos_comp: Cosine table from _compute_frequency_components.
            sin_comp: Sine table from _compute_frequency_components.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Look up per-token cos/sin rows; [:, None] broadcasts over heads.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]

        # Standard RoPE rotation: x*cos + rotate(x)*sin
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4 for the
                two half-splits to rotate cleanly (the assert below only
                enforces evenness).
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Half of the feature dim goes to each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Tables must cover the largest coordinate appearing in `positions`
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately per dimension: y rotates the first half,
        # x rotates the second half
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)
|
||||
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Compute the frequency tensor for 1D rotary position embeddings (RoPE).

    Core idea: positions are encoded as rotations so that relative offsets
    are preserved. For position m and dimension pair i, the frequency is
    θ_i = θ^(-2i/d), where θ is the base frequency (default 10000).

    Args:
        dim: Feature dimension; must be even (dimensions are rotated in pairs).
        pos: Position indices — an int (expands to arange(0, pos)) or an
            array of positions [S].
        theta: Base frequency controlling the periodicity of the encoding.
        use_real: Return real cos/sin tensors instead of a complex tensor.
        linear_factor: Linear scaling factor, used for context extension.
        ntk_factor: NTK-aware scaling factor for sequences longer than trained.
        repeat_interleave_real: When use_real is True, interleave (True) vs
            concatenate (False) the cos/sin halves — model-architecture dependent.
        freqs_dtype: dtype used for the frequency computation.

    Returns:
        Complex form: [S, D/2] complex tensor representing e^(i*m*θ_j).
        Real form: two [S, D] tensors (cos and sin).
    """
    # RoPE rotates feature pairs, so the dimension must be even.
    assert dim % 2 == 0

    # Normalize positions to a torch tensor.
    if isinstance(pos, int):
        pos = torch.arange(pos)  # generates [0, 1, ..., pos-1]
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # [S]

    # Apply NTK-aware scaling (handles longer sequences than seen in training).
    theta = theta * ntk_factor

    # Step 1: per-pair frequencies θ_i = 1 / (θ^(2i/d) * linear_factor),
    # with i ∈ {0, 2, 4, ..., dim-2} (even indices; pairs are rotated together).
    freqs = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2] — one frequency per dimension pair

    # Step 2: position–frequency matrix via outer product:
    # freqs[m, i] = pos[m] * θ_i, one rotation angle per (position, pair).
    freqs = torch.outer(pos, freqs)  # [S, D/2]

    # Step 3: convert to the requested output format.
    if use_real and repeat_interleave_real:
        # Layout 1 — interleaved repeat (flux, hunyuan-dit, cogvideox):
        # each pair's cos/sin is duplicated adjacently: [cos_0, cos_0, cos_1, cos_1, ...]
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # Layout 2 — concatenated repeat (stable audio, allegro):
        # all cos values followed by the same block again: [cos_0..cos_n, cos_0..cos_n]
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # Layout 3 — complex form (lumina):
        # Euler's formula e^(iθ) = cos θ + i sin θ;
        # torch.polar(r, θ) returns r·e^(iθ), here r = 1, i.e. e^(i·freqs).
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64: [S, D/2]
        return freqs_cis
|
||||
|
||||
|
||||
class WanRotaryPosEmbed(nn.Module):
    """3D rotary position embedding (3D RoPE).

    Extends RoPE to three axes (frame/time, height, width) for video-like
    token grids: the per-head feature dimension is split into three slices
    (dim_f + dim_h + dim_w == attention_head_dim), each axis is rotated
    independently with its own 1D RoPE table, and the per-axis complex
    frequencies are concatenated along the feature dimension.
    """

    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int = 1024,
        theta: float = 10000.0,
        fhw_dim: Optional[Tuple[int, int, int]] = (20, 22, 22),
    ):
        """
        Args:
            attention_head_dim: total per-head feature dimension.
            patch_size: patch size as (patch_f, patch_h, patch_w).
            max_seq_len: maximum sequence length for the precomputed table.
            theta: RoPE base frequency.
            fhw_dim: explicit (dim_f, dim_h, dim_w) split of the head dim.
                If None, h and w each get 2 * (dim // 6) and f the remainder.
                NOTE: the default was a mutable list ([20, 22, 22]); it is now
                an equivalent immutable tuple to avoid the shared-mutable-default
                pitfall. It is only unpacked, so behavior is unchanged.
        """
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        # Step 1: split the head dimension across the three axes.
        if fhw_dim is not None:
            assert attention_head_dim == sum(
                fhw_dim
            ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
            t_dim, h_dim, w_dim = fhw_dim
        else:
            # Automatic split: h and w take 2 * (dim // 6) each, t the rest
            # (e.g. attention_head_dim=64 -> h_dim=w_dim=20, t_dim=24).
            h_dim = w_dim = 2 * (attention_head_dim // 6)
            t_dim = attention_head_dim - h_dim - w_dim

        # Keep the split for use in forward().
        self.fhw_dim = (t_dim, h_dim, w_dim)

        # Step 2: precompute complex 1D RoPE frequencies per axis
        # (each [max_seq_len, dim // 2]) and concatenate along the feature
        # dim: [max_seq_len, attention_head_dim // 2].
        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
        """Build complex RoPE frequencies for a video token sequence.

        Tokens are laid out per frame as [special tokens, patches]:
        [frame0_special_0..N, frame0_patch_0..M,
         frame1_special_0..N, frame1_patch_0..M, ...].

        Args:
            ppf: number of frames; recomputed from f_start/f_end in causal mode.
            pph: patch rows per frame.
            ppw: patch columns per frame.
            patch_start_idx: number of special tokens preceding the patches
                in each frame.
            device: device the frequency table is moved to.
            f_start: first frame index (causal mode), defaults to 0.
            f_end: one-past-last frame index (causal mode); if None, frames
                [0, ppf) are used.

        Returns:
            Complex tensor of shape
            [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim // 2].
        """
        # Move the precomputed table to the target device, then split it back
        # into the per-axis (t, h, w) frequency tables.
        self.freqs = self.freqs.to(device)
        if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
            t_dim, h_dim, w_dim = self.fhw_dim
        else:
            # Defensive fallback: recompute the automatic split (__init__
            # always sets self.fhw_dim, so this branch is normally dead).
            h_dim = w_dim = 2 * (self.attention_head_dim // 6)
            t_dim = self.attention_head_dim - h_dim - w_dim

        # Each axis owns dim/2 complex frequencies.
        freqs = self.freqs.split_with_sizes(
            [
                t_dim // 2,  # time axis
                h_dim // 2,  # height axis
                w_dim // 2,  # width axis
            ],
            dim=1,
        )

        # Causal mode: take frames [f_start, f_end) and recompute ppf.
        if f_end is not None:
            ppf = f_end - f_start
            frame_slice = slice(f_start, f_end)
        else:
            # Non-causal mode: frames [0, ppf).
            frame_slice = slice(0, ppf)

        if patch_start_idx > 0:
            # Special tokens sit on the (f, i, i) diagonal so each gets a
            # unique 3D position: camera (f, 0, 0), register_0 (f, 1, 1), ...
            # Shapes below: (ppf, patch_start_idx, dim_axis).
            freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1)
            freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)
            freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)
            freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1)
            freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1)

            # Patches live at (f, patch_start_idx + h, patch_start_idx + w):
            # offsetting h/w keeps them disjoint from special-token positions
            # while treating h and w symmetrically.
            # Shapes below: (ppf, pph, ppw, dim_axis).
            freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
            freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
            freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
            freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1)
            freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1)

            # Per-frame order: [special tokens, patches]; then flatten frames
            # and add broadcastable batch/head dims.
            freqs = torch.cat([freqs_special, freqs_patches], dim=1)  # (ppf, patch_start_idx + pph * ppw, dim)
            freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
            freqs = freqs.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, dim)
            return freqs

        # No special tokens: patches occupy (f, 0:pph, 0:ppw) directly.
        freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
        return freqs
|
||||
|
||||
def apply_rotary_emb(x, freqs):
    """Rotate feature pairs of ``x`` by precomputed complex RoPE phases.

    Consecutive feature pairs (x1, x2) are treated as complex numbers
    x1 + i*x2 and multiplied by freqs = e^(i*theta), which is exactly the
    2D rotation [[cos, -sin], [sin, cos]] applied to each pair. This keeps
    relative-position information in the dot products of rotated queries
    and keys.

    Args:
        x: input features [batch, heads, seq_len, head_dim].
        freqs: complex rotation phases [1, 1, seq_len, head_dim // 2].

    Returns:
        Rotated features with the same shape and dtype as ``x``.
    """
    b, h, s, _ = x.shape
    # Pair up consecutive features as (real, imag) and view them as complex
    # numbers in float64 for precision: [b, h, s, head_dim // 2].
    pairs = torch.view_as_complex(x.to(torch.float64).reshape(b, h, s, -1, 2))
    # Complex multiply rotates each pair; view_as_real splits back into
    # (real, imag): [b, h, s, head_dim // 2, 2].
    rotated = torch.view_as_real(pairs * freqs)
    # Flatten the pair dimension and restore the caller's dtype.
    return rotated.flatten(3).to(x.dtype)
|
||||
67
lingbot_map/layers/swiglu_ffn.py
Normal file
67
lingbot_map/layers/swiglu_ffn.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
import warnings
|
||||
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: out = W3(silu(W1 x) * W2 x).

    W1 and W2 are fused into the single projection ``w12`` whose output is
    split in half along the last dimension. ``act_layer`` and ``drop`` are
    accepted for signature compatibility but unused.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features if out_features else in_features
        hidden_features = hidden_features if hidden_features else in_features
        # Fused gate+value projection, then the output projection.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        # Split the fused projection into gate and value halves.
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
||||
|
||||
|
||||
# xFormers' fused SwiGLU is intentionally disabled in this build: the
# pure-PyTorch SwiGLUFFN above is always used. XFORMERS_ENABLED is kept for
# callers that inspect the flag (removed the dead commented-out import block).
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN with the hidden width rescaled for parameter parity.

    The requested hidden width is multiplied by 2/3 (SwiGLU spends two
    projections where an MLP spends one) and rounded up to a multiple of 8
    for hardware-friendly matmul shapes, then delegated to SwiGLU.
    ``act_layer`` and ``drop`` are accepted for signature compatibility.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features if out_features else in_features
        hidden_features = hidden_features if hidden_features else in_features
        # 2/3 scaling, rounded up to the next multiple of 8.
        scaled = int(hidden_features * 2 / 3)
        hidden_features = (scaled + 7) // 8 * 8
        super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
|
||||
411
lingbot_map/layers/vision_transformer.py
Normal file
411
lingbot_map/layers/vision_transformer.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
import logging
|
||||
from typing import Sequence, Tuple, Union, Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from torch.nn.init import trunc_normal_
|
||||
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention#, NestedTensorBlock as Block
|
||||
|
||||
# TODO: Check this
|
||||
# We replace NestedTensorBlock with Block
|
||||
from .block import Block
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    ``name`` is the dotted path of ``module`` within the tree. With
    ``depth_first`` the callback fires after the children (post-order),
    otherwise before them (pre-order); ``include_root`` controls whether
    ``module`` itself is visited (children are always visited, with
    ``include_root=True``). Returns ``module`` for chaining.
    """
    pre_order = include_root and not depth_first
    if pre_order:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child, name=full_name, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
|
||||
|
||||
|
||||
class BlockChunk(nn.ModuleList):
    """A callable ModuleList: runs its member modules sequentially.

    Used to group consecutive transformer blocks into one unit (e.g. for
    FSDP wrapping) while preserving the plain ``x = block(x)`` interface.
    """

    def forward(self, x):
        # Thread the activation through every member in order.
        for block in self:
            x = block(x)
        return x
|
||||
|
||||
|
||||
class DinoVisionTransformer(nn.Module):
    """DINOv2-style Vision Transformer backbone.

    Supports register tokens, optional removal of the CLS token
    (``drop_cls_token``), chunked block lists for FSDP wrapping, masked-token
    replacement, and gradient checkpointing during training.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        drop_cls_token=False,
        qk_norm=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
            drop_cls_token: (bool) if True, no CLS token is created and the
                positional embedding covers patches only
            qk_norm: (bool) passed through to each block (query/key norm flag)
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        # One extra (CLS) token unless it is dropped.
        self.num_tokens = 1 if not drop_cls_token else 0
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.use_reentrant = False  # hardcoded to False

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.drop_cls_token = drop_cls_token

        # CLS token (optional) and learned positional embedding over
        # [CLS?] + patches. Register tokens do NOT get positional embeddings.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if not drop_cls_token else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        # Stochastic-depth rates: either uniform or linearly increasing.
        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # Resolve the FFN implementation from its string name.
        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            # Group blocks into chunks; each chunk is padded at the front with
            # nn.Identity() so absolute block indices stay consistent.
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Learned embedding substituted for masked patch tokens.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initialize pos/cls/register tokens and all Linear layers (timm scheme)."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6) if self.cls_token is not None else None
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resize the (square) patch positional grid to the input resolution.

        Args:
            x: token tensor after patch embedding (used for token count and dtype).
            w, h: input image width and height in pixels.

        Returns:
            Positional embedding matching x's token layout, in x's dtype.
        """
        previous_dtype = x.dtype
        # NOTE(review): npatch subtracts 1 unconditionally, even when
        # drop_cls_token=True and x carries no CLS token — confirm intended.
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1 if not self.drop_cls_token else self.pos_embed.shape[1]
        if npatch == N and w == h:
            # Fast path: resolution matches the training grid.
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        if not self.drop_cls_token:
            class_pos_embed = pos_embed[:, 0]
            patch_pos_embed = pos_embed[:, 1:]
        else:
            patch_pos_embed = pos_embed
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        if not self.drop_cls_token:
            x = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
        else:
            x = patch_pos_embed
        return x.to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patch-embed the image, apply masking, prepend CLS, add pos-embed, insert registers.

        Args:
            x: image batch (B, C, W, H).
            masks: optional boolean mask over patch tokens; masked positions
                are replaced by the learned mask token.
        """
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.cls_token is not None else x
        x = x + self.interpolate_pos_encoding(x, w, h)

        # Register tokens are inserted after the CLS token, AFTER the
        # positional embedding has been added (they carry no pos-embed).
        # NOTE(review): this slicing assumes a CLS token at position 0 —
        # confirm behavior when drop_cls_token=True and registers are used.
        if self.register_tokens is not None:
            x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run the backbone on a list of (image batch, mask) pairs.

        Gradient checkpointing is used while in training mode. Returns one
        output dict per input batch (see forward_features for keys).
        """
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]

        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the backbone and return normed CLS/register/patch tokens.

        Accepts either a single batch or a list of batches (delegated to
        forward_features_list). Token splits assume layout
        [CLS, registers..., patches...].
        """
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.training:
                # Trade compute for memory during training.
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect outputs of selected blocks when blocks are a flat list."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect outputs of selected blocks when blocks are chunked (with Identity padding)."""
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens (and optionally class tokens) from chosen blocks.

        Args:
            x: image batch.
            n: number of last blocks, or an explicit list of block indices.
            reshape: if True, reshape patch tokens to (B, C, H', W') maps.
            return_class_token: if True, pair each output with its class token.
            norm: if True, apply the final LayerNorm to each collected output.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Strip CLS + register tokens, keeping only patch tokens.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=True, **kwargs):
        """Full feature dict when is_training, else head applied to the normed CLS token.

        NOTE(review): is_training defaults to True here (upstream DINOv2
        defaults to False) — confirm this is the intended default for callers.
        """
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
|
||||
|
||||
|
||||
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Linear layers get truncated-normal weights (std=0.02) and zeroed biases;
    every other module type is left untouched. ``name`` is accepted for the
    named_apply protocol and unused.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
||||
|
||||
|
||||
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-S backbone: 384-dim, 12 blocks, 6 heads, mem-efficient attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-B backbone: 768-dim, 12 blocks, 12 heads, mem-efficient attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-L backbone: 1024-dim, 24 blocks, 16 heads, mem-efficient attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-giant2 backbone.

    Close to ViT-giant: embed-dim 1536 with 24 heads, i.e. 64 dims per head.
    """
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
|
||||
Reference in New Issue
Block a user