# Files
# lingbot-map/lingbot_map/layers/rope.py
# LinZhuoChen f9b3ae457a first commit
# 2026-04-16 09:51:30 +08:00
#
# 475 lines
# 21 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Implementation of 2D Rotary Position Embeddings (RoPE).
# This module provides a clean implementation of 2D Rotary Position Embeddings,
# which extends the original RoPE concept to handle 2D spatial positions.
# Inspired by:
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
# https://github.com/naver-ai/rope-vit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple
from typing import List, Optional, Tuple, Union
class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    This class manages the generation of (y, x) coordinates for patches in a 2D
    grid, caching one grid per (height, width, device) to avoid recomputation.

    Attributes:
        position_cache: Dictionary mapping ``(height, width, device)`` to a
            ``(height * width, 2)`` int64 tensor of (y, x) coordinates.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int, torch.device], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) on ``device`` containing
            int64 (y, x) coordinates in row-major order, repeated for each batch
            item. The result is a fresh copy and may be modified by the caller.
        """
        # The device is part of the key. Otherwise a grid first built on one device
        # would be handed back for a request targeting a different device.
        cache_key = (height, width, torch.device(device))
        positions = self.position_cache.get(cache_key)
        if positions is None:
            y_coords = torch.arange(height, device=device)
            x_coords = torch.arange(width, device=device)
            # cartesian_prod yields y-major pairs: (0, 0), (0, 1), ..., (1, 0), ...
            positions = torch.cartesian_prod(y_coords, x_coords)
            self.position_cache[cache_key] = positions
        # clone() so in-place edits by the caller cannot corrupt the cached grid.
        return positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on their
    2D spatial positions. The last (feature) dimension is split in half. The first
    half is rotated according to the y coordinate and the second half according
    to the x coordinate.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0.
            NOTE(review): stored on the module but not read by any method of this
            class; confirm whether external code relies on it.

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Stored scaling factor (see above).
        frequency_cache: Cache of (cos, sin) tables keyed by
            (dim, seq_len, device, dtype).
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes (and caches) cosine/sine tables for rotary embeddings.

        Args:
            dim: Feature dimension the tables are built for (must be even).
            seq_len: Number of positions (rows) in the tables.
            device: Target device for computations.
            dtype: Data type of the returned tables.

        Returns:
            Tuple of (cos, sin) tensors, each of shape (seq_len, dim). Column j and
            column j + dim // 2 share the same frequency.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Inverse frequencies base^(-2k/dim) for k = 0 .. dim/2 - 1.
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)
            # angles[p, k] = p * inv_freq[k]
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)
            # NOTE(review): angles are cast to `dtype` *before* cos/sin, so with
            # fp16/bf16 inputs large angles get quantized. Kept as-is so results
            # match the existing implementation.
            angles = angles.to(dtype)
            # Duplicate so both halves of the rotated vector use the same frequencies.
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)
        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Returns ``cat(-x2, x1)`` where x1/x2 are the two halves of the last dim.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor of the same shape.
        """
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one spatial axis.

        Args:
            tokens: Token features, (batch_size, n_heads, n_tokens, dim).
            positions: Integer position indices, (batch_size, n_tokens).
            cos_comp: Cosine table with one row per position.
            sin_comp: Sine table with one row per position.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Look up per-token rows and add a singleton head axis for broadcasting.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4.
            positions: Integer tensor of shape (batch_size, n_tokens, 2) containing
                the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Each spatial half (dim / 2) is itself rotated in (j, j + dim / 4) pairs,
        # so dim must split evenly into four parts. With dim % 4 == 2 the frequency
        # tables would not match the feature halves.
        assert tokens.size(-1) % 4 == 0, "Feature dimension must be divisible by 4"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        feature_dim = tokens.size(-1) // 2

        # The tables must cover the largest index. An empty positions tensor has no max.
        max_position = int(positions.max()) + 1 if positions.numel() > 0 else 1
        cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)

        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
        vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
        horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
        return torch.cat((vertical_features, horizontal_features), dim=-1)
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """Compute the frequency tensor for 1D rotary position embeddings (RoPE).

    RoPE encodes position by rotating pairs of features by a position-dependent
    angle. For feature pair ``j`` (``j = 0 .. dim/2 - 1``) at position ``m``, the
    angle is ``m * freq_j`` with

        freq_j = 1 / ((theta * ntk_factor) ** (2j / dim) * linear_factor)

    Args:
        dim: Feature dimension; must be even because features are processed in pairs.
        pos: Either an int ``S`` (positions ``0 .. S-1`` are generated) or a 1D
            collection of ``S`` positions (numpy array, torch tensor, or sequence
            of numbers).
        theta: Base frequency (default 10000).
        use_real: If True, return separate cosine and sine tensors; otherwise a
            complex tensor.
        linear_factor: Divides every frequency (linear position scaling, e.g. for
            context extension).
        ntk_factor: Multiplies ``theta`` (NTK-aware scaling for longer sequences).
        repeat_interleave_real: Only used when ``use_real`` is True. If True, each
            frequency is repeated in place ``[f0, f0, f1, f1, ...]``. Otherwise the
            whole frequency block is concatenated with itself
            ``[f0, f1, ..., f0, f1, ...]``.
        freqs_dtype: Floating dtype used to compute the frequencies.

    Returns:
        If ``use_real`` is False: a complex tensor of shape ``[S, dim/2]`` holding
        ``exp(i * m * freq_j)`` (complex64 for float32 frequencies, complex128 for
        float64).
        If ``use_real`` is True: a ``(cos, sin)`` pair of float32 tensors of shape
        ``[S, dim]``.
    """
    # RoPE rotates features in pairs, so the dimension must be even.
    assert dim % 2 == 0

    # Normalize `pos` to a 1D torch tensor.
    if isinstance(pos, (int, np.integer)):
        pos = torch.arange(int(pos))  # [0, 1, ..., S-1]
    elif isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)
    elif not isinstance(pos, torch.Tensor):
        pos = torch.as_tensor(pos)

    # NTK-aware scaling stretches the base frequency.
    theta = theta * ntk_factor

    # arange(0, dim, 2) yields 2j for j = 0 .. dim/2 - 1, i.e. exactly dim/2 frequencies.
    exponents = torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim
    freqs = 1.0 / (theta**exponents) / linear_factor  # [D/2]

    # Outer product: angle[m, j] = pos[m] * freq_j.
    freqs = torch.outer(pos, freqs)  # [S, D/2]

    if use_real and repeat_interleave_real:
        # Interleaved layout [f0, f0, f1, f1, ...] (flux, hunyuan-dit, cogvideox).
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        return freqs_cos, freqs_sin
    if use_real:
        # Concatenated layout [f0, ..., fn, f0, ..., fn] (stable audio, allegro).
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    # Complex layout: polar(1, angle) = e^(i * angle) = cos(angle) + i * sin(angle).
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # [S, D/2]
    return freqs_cis
class WanRotaryPosEmbed(nn.Module):
    """3D rotary position embedding (3D RoPE) over (frame, height, width).

    Extends RoPE to video-like inputs. Each of the three axes gets its own slice
    of the head dimension, a 1D RoPE table is built per axis, and the three
    tables are concatenated along the feature axis.

    For a token at 3D position (f, h, w):
        - the frame axis uses ``t_dim`` feature dimensions,
        - the height axis uses ``h_dim`` feature dimensions,
        - the width axis uses ``w_dim`` feature dimensions,
    with ``t_dim + h_dim + w_dim == attention_head_dim``.

    Args:
        attention_head_dim: Total per-head feature dimension.
        patch_size: Patch size ``(patch_f, patch_h, patch_w)``; stored on the
            module, not used by this class's computations.
        max_seq_len: Number of positions precomputed per axis. Every frame index
            and every spatial index used in ``forward`` must be below this.
        theta: RoPE base frequency.
        fhw_dim: Explicit ``(t_dim, h_dim, w_dim)`` split. If None, ``h_dim =
            w_dim = 2 * (attention_head_dim // 6)`` and ``t_dim`` takes the rest.
    """

    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int = 1024,
        theta: float = 10000.0,
        fhw_dim: Optional[Tuple[int, int, int]] = (20, 22, 22),
    ):
        super().__init__()
        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        # Step 1: split the head dimension between the three axes.
        if fhw_dim is not None:
            assert attention_head_dim == sum(
                fhw_dim
            ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
            t_dim, h_dim, w_dim = fhw_dim
        else:
            # e.g. attention_head_dim=64 -> h_dim = w_dim = 20, t_dim = 24
            h_dim = w_dim = 2 * (attention_head_dim // 6)
            t_dim = attention_head_dim - h_dim - w_dim
        self.fhw_dim = (t_dim, h_dim, w_dim)

        # Step 2: one complex (complex128) table per axis, each [max_seq_len, axis_dim // 2].
        freqs = [
            get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            for dim in (t_dim, h_dim, w_dim)
        ]
        # [max_seq_len, t_dim//2 + h_dim//2 + w_dim//2]. Plain attribute, not a buffer,
        # so forward() moves it to the requested device explicitly.
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
        """Build complex RoPE frequencies for a run of frames.

        Each frame is laid out as ``patch_start_idx`` special tokens followed by a
        ``pph x ppw`` patch grid.

        Args:
            ppf: Number of frames. Ignored (recomputed as ``f_end - f_start``) when
                ``f_end`` is given.
            pph: Patch rows per frame.
            ppw: Patch columns per frame.
            patch_start_idx: Number of special tokens preceding the patches in
                every frame.
            device: Device the frequency table is moved to (the output lives there).
            f_start: First frame index; only used when ``f_end`` is given.
            f_end: Exclusive end frame index (causal / streaming use). If None,
                frames ``0 .. ppf-1`` are used.

        Returns:
            Complex tensor of shape
            ``[1, 1, ppf * (patch_start_idx + pph * ppw), head_dim // 2]``.
            Token order is [frame0 specials, frame0 patches (row-major), frame1
            specials, frame1 patches, ...]. Special token k of frame f sits at 3D
            position (f, k, k). The patch at (row, col) sits at
            (f, patch_start_idx + row, patch_start_idx + col), so patches never
            share a position with special tokens.

        Raises:
            ValueError: If a frame or spatial index falls outside the precomputed
                table (``self.freqs.shape[0]`` rows per axis).
        """
        self.freqs = self.freqs.to(device)

        if getattr(self, "fhw_dim", None) is not None:
            t_dim, h_dim, w_dim = self.fhw_dim
        else:
            # Same automatic split as in __init__.
            h_dim = w_dim = 2 * (self.attention_head_dim // 6)
            t_dim = self.attention_head_dim - h_dim - w_dim

        # Complex tables hold half as many columns as real feature dims.
        freqs_t, freqs_h, freqs_w = self.freqs.split_with_sizes([t_dim // 2, h_dim // 2, w_dim // 2], dim=1)

        if f_end is not None:
            ppf = f_end - f_start
            frame_start = f_start
        else:
            frame_start = 0
        frame_end = frame_start + ppf

        # Slicing past the table end silently truncates, which would otherwise
        # surface as a confusing reshape error or a wrongly shaped result.
        table_len = self.freqs.shape[0]
        if frame_start < 0 or frame_end > table_len:
            raise ValueError(
                f"frame range [{frame_start}, {frame_end}) is outside the precomputed RoPE table of {table_len} positions"
            )
        if patch_start_idx + max(pph, ppw) > table_len:
            raise ValueError(
                f"spatial index {patch_start_idx + max(pph, ppw) - 1} is outside the precomputed RoPE table "
                f"of {table_len} positions"
            )

        frame_freqs = freqs_t[frame_start:frame_end]  # (ppf, t_dim//2)

        # Patches: (ppf, pph, ppw, dim) -> (ppf, pph * ppw, dim), with h and w offset
        # by patch_start_idx.
        patch_freqs = torch.cat(
            [
                frame_freqs.reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1),
                freqs_h[patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1),
                freqs_w[patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1),
            ],
            dim=-1,
        ).reshape(ppf, pph * ppw, -1)

        if patch_start_idx > 0:
            # Special tokens on the diagonal (f, k, k): (ppf, patch_start_idx, dim).
            special_freqs = torch.cat(
                [
                    frame_freqs.reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1),
                    freqs_h[:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1),
                    freqs_w[:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1),
                ],
                dim=-1,
            )
            # Within each frame: [special tokens, patches].
            per_frame = torch.cat([special_freqs, patch_freqs], dim=1)
        else:
            per_frame = patch_freqs

        # Flatten frames and add singleton batch and head axes.
        return per_frame.reshape(1, 1, ppf * (patch_start_idx + pph * ppw), -1)
def apply_rotary_emb(x, freqs):
    """Apply complex-valued rotary position embeddings to ``x``.

    Consecutive feature pairs ``(x[2k], x[2k+1])`` are treated as the complex
    number ``x[2k] + i * x[2k+1]`` and multiplied by ``freqs`` (unit complex
    numbers ``e^{i*t}``). This rotates each pair by its angle:

        (x1 + i x2) * (cos t + i sin t) = (x1 cos t - x2 sin t) + i (x1 sin t + x2 cos t)

    which equals the rotation matrix ``[[cos t, -sin t], [sin t, cos t]]``
    applied to ``[x1, x2]``.

    Args:
        x: Real tensor with an even last dimension (head_dim), typically
            ``[batch, heads, seq_len, head_dim]``. Any number of leading
            dimensions is accepted.
        freqs: Complex tensor broadcastable to ``x.shape[:-1] + (head_dim // 2,)``,
            e.g. ``[1, 1, seq_len, head_dim // 2]``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # Compute in float64/complex128 and cast back at the end. contiguous() ensures
    # the trailing pair axis has stride 1, as view_as_complex requires.
    x_pairs = x.to(torch.float64).contiguous().reshape(*x.shape[:-1], -1, 2)
    x_rotated = torch.view_as_complex(x_pairs) * freqs
    # Back to interleaved real pairs, then merge (head_dim // 2, 2) -> head_dim.
    return torch.view_as_real(x_rotated).flatten(-2).to(x.dtype)