first commit
This commit is contained in:
474
lingbot_map/layers/rope.py
Normal file
474
lingbot_map/layers/rope.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# Implementation of 2D Rotary Position Embeddings (RoPE).
|
||||
|
||||
# This module provides a clean implementation of 2D Rotary Position Embeddings,
|
||||
# which extends the original RoPE concept to handle 2D spatial positions.
|
||||
|
||||
# Inspired by:
|
||||
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
||||
# https://github.com/naver-ai/rope-vit
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
||||
class PositionGetter:
|
||||
"""Generates and caches 2D spatial positions for patches in a grid.
|
||||
|
||||
This class efficiently manages the generation of spatial coordinates for patches
|
||||
in a 2D grid, caching results to avoid redundant computations.
|
||||
|
||||
Attributes:
|
||||
position_cache: Dictionary storing precomputed position tensors for different
|
||||
grid dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the position generator with an empty cache."""
|
||||
self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
|
||||
|
||||
def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
|
||||
"""Generates spatial positions for a batch of patches.
|
||||
|
||||
Args:
|
||||
batch_size: Number of samples in the batch.
|
||||
height: Height of the grid in patches.
|
||||
width: Width of the grid in patches.
|
||||
device: Target device for the position tensor.
|
||||
|
||||
Returns:
|
||||
Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
|
||||
for each position in the grid, repeated for each batch item.
|
||||
"""
|
||||
if (height, width) not in self.position_cache:
|
||||
y_coords = torch.arange(height, device=device)
|
||||
x_coords = torch.arange(width, device=device)
|
||||
positions = torch.cartesian_prod(y_coords, x_coords)
|
||||
self.position_cache[height, width] = positions
|
||||
|
||||
cached_positions = self.position_cache[height, width]
|
||||
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(nn.Module):
|
||||
"""2D Rotary Position Embedding implementation.
|
||||
|
||||
This module applies rotary position embeddings to input tokens based on their
|
||||
2D spatial positions. It handles the position-dependent rotation of features
|
||||
separately for vertical and horizontal dimensions.
|
||||
|
||||
Args:
|
||||
frequency: Base frequency for the position embeddings. Default: 100.0
|
||||
scaling_factor: Scaling factor for frequency computation. Default: 1.0
|
||||
|
||||
Attributes:
|
||||
base_frequency: Base frequency for computing position embeddings.
|
||||
scaling_factor: Factor to scale the computed frequencies.
|
||||
frequency_cache: Cache for storing precomputed frequency components.
|
||||
"""
|
||||
|
||||
def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
|
||||
"""Initializes the 2D RoPE module."""
|
||||
super().__init__()
|
||||
self.base_frequency = frequency
|
||||
self.scaling_factor = scaling_factor
|
||||
self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
|
||||
|
||||
def _compute_frequency_components(
|
||||
self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Computes frequency components for rotary embeddings.
|
||||
|
||||
Args:
|
||||
dim: Feature dimension (must be even).
|
||||
seq_len: Maximum sequence length.
|
||||
device: Target device for computations.
|
||||
dtype: Data type for the computed tensors.
|
||||
|
||||
Returns:
|
||||
Tuple of (cosine, sine) tensors for frequency components.
|
||||
"""
|
||||
cache_key = (dim, seq_len, device, dtype)
|
||||
if cache_key not in self.frequency_cache:
|
||||
# Compute frequency bands
|
||||
exponents = torch.arange(0, dim, 2, device=device).float() / dim
|
||||
inv_freq = 1.0 / (self.base_frequency**exponents)
|
||||
|
||||
# Generate position-dependent frequencies
|
||||
positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
angles = torch.einsum("i,j->ij", positions, inv_freq)
|
||||
|
||||
# Compute and cache frequency components
|
||||
angles = angles.to(dtype)
|
||||
angles = torch.cat((angles, angles), dim=-1)
|
||||
cos_components = angles.cos().to(dtype)
|
||||
sin_components = angles.sin().to(dtype)
|
||||
self.frequency_cache[cache_key] = (cos_components, sin_components)
|
||||
|
||||
return self.frequency_cache[cache_key]
|
||||
|
||||
@staticmethod
|
||||
def _rotate_features(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Performs feature rotation by splitting and recombining feature dimensions.
|
||||
|
||||
Args:
|
||||
x: Input tensor to rotate.
|
||||
|
||||
Returns:
|
||||
Rotated feature tensor.
|
||||
"""
|
||||
feature_dim = x.shape[-1]
|
||||
x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_1d_rope(
|
||||
self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Applies 1D rotary position embeddings along one dimension.
|
||||
|
||||
Args:
|
||||
tokens: Input token features.
|
||||
positions: Position indices.
|
||||
cos_comp: Cosine components for rotation.
|
||||
sin_comp: Sine components for rotation.
|
||||
|
||||
Returns:
|
||||
Tokens with applied rotary position embeddings.
|
||||
"""
|
||||
# Embed positions with frequency components
|
||||
cos = F.embedding(positions, cos_comp)[:, None, :, :]
|
||||
sin = F.embedding(positions, sin_comp)[:, None, :, :]
|
||||
|
||||
# Apply rotation
|
||||
return (tokens * cos) + (self._rotate_features(tokens) * sin)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies 2D rotary position embeddings to input tokens.
|
||||
|
||||
Args:
|
||||
tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
|
||||
The feature dimension (dim) must be divisible by 4.
|
||||
positions: Position tensor of shape (batch_size, n_tokens, 2) containing
|
||||
the y and x coordinates for each token.
|
||||
|
||||
Returns:
|
||||
Tensor of same shape as input with applied 2D rotary position embeddings.
|
||||
|
||||
Raises:
|
||||
AssertionError: If input dimensions are invalid or positions are malformed.
|
||||
"""
|
||||
# Validate inputs
|
||||
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
|
||||
assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
|
||||
|
||||
# Compute feature dimension for each spatial direction
|
||||
feature_dim = tokens.size(-1) // 2
|
||||
|
||||
# Get frequency components
|
||||
max_position = int(positions.max()) + 1
|
||||
cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
|
||||
|
||||
# Split features for vertical and horizontal processing
|
||||
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
|
||||
|
||||
# Apply RoPE separately for each dimension
|
||||
vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
|
||||
horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
|
||||
|
||||
# Combine processed features
|
||||
return torch.cat((vertical_features, horizontal_features), dim=-1)
|
||||
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[np.ndarray, int],
|
||||
theta: float = 10000.0,
|
||||
use_real=False,
|
||||
linear_factor=1.0,
|
||||
ntk_factor=1.0,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
|
||||
):
|
||||
"""
|
||||
计算1D旋转位置编码(RoPE)的频率张量。
|
||||
|
||||
RoPE的核心思想:使用旋转矩阵来编码位置信息,使得相对位置关系保持不变。
|
||||
公式:对于位置m和维度i,频率为 θ_i = θ^(-2i/d),其中θ是基础频率(默认10000)
|
||||
|
||||
Args:
|
||||
dim: 特征维度,必须是偶数(因为要成对处理)
|
||||
pos: 位置索引,可以是整数(自动生成0到pos-1的序列)或位置数组 [S]
|
||||
theta: 基础频率,控制位置编码的周期性(默认10000)
|
||||
use_real: 是否返回实数形式(cos和sin分开)还是复数形式
|
||||
linear_factor: 线性缩放因子,用于上下文扩展
|
||||
ntk_factor: NTK-Aware缩放因子,用于处理更长的序列
|
||||
repeat_interleave_real: 当use_real=True时,是否交错重复(用于某些模型架构)
|
||||
freqs_dtype: 频率张量的数据类型
|
||||
|
||||
Returns:
|
||||
复数形式:[S, D/2] 的复数张量,表示 e^(i*m*θ_j)
|
||||
实数形式:两个 [S, D] 的张量(cos和sin)
|
||||
"""
|
||||
# 确保维度是偶数(RoPE需要成对处理维度)
|
||||
assert dim % 2 == 0
|
||||
|
||||
# 将位置转换为torch张量
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos) # 生成 [0, 1, 2, ..., pos-1]
|
||||
if isinstance(pos, np.ndarray):
|
||||
pos = torch.from_numpy(pos) # [S]
|
||||
|
||||
# 应用NTK缩放(Neural Tangent Kernel,用于处理训练时未见过的长序列)
|
||||
theta = theta * ntk_factor
|
||||
|
||||
# 步骤1:计算频率 θ_i = 1 / (θ^(2i/d))
|
||||
# 其中 i ∈ {0, 2, 4, ..., dim-2}(只取偶数索引,因为成对处理)
|
||||
# 公式:freq_i = 1 / (theta^(2i/d) * linear_factor)
|
||||
freqs = (
|
||||
1.0
|
||||
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
|
||||
/ linear_factor
|
||||
) # [D/2],每个频率对应一个维度对
|
||||
|
||||
# 步骤2:计算位置-频率矩阵
|
||||
# 使用外积:pos[m] * freqs[i] = m * θ_i
|
||||
# 结果:每个位置m和每个频率i的组合
|
||||
freqs = torch.outer(pos, freqs) # [S, D/2]
|
||||
|
||||
# 步骤3:根据返回格式转换
|
||||
if use_real and repeat_interleave_real:
|
||||
# 方式1:交错重复(用于flux, hunyuan-dit, cogvideox等模型)
|
||||
# 将每个频率的cos和sin交错排列:[cos_0, cos_0, cos_1, cos_1, ...]
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
# 方式2:拼接重复(用于stable audio, allegro等模型)
|
||||
# 将所有cos拼接,然后是所有sin:[cos_0, cos_1, ..., cos_n, cos_0, cos_1, ..., cos_n]
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
# 方式3:复数形式(用于lumina等模型)
|
||||
# 使用欧拉公式:e^(iθ) = cos(θ) + i*sin(θ)
|
||||
# torch.polar(r, θ) 返回 r * e^(iθ),这里r=1,所以就是 e^(i*freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64: [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
"""
|
||||
3D旋转位置编码(3D RoPE)模块
|
||||
|
||||
核心思想:将RoPE扩展到3D空间(时间、高度、宽度),为视频或3D数据提供位置编码。
|
||||
每个维度(t, h, w)独立使用RoPE,然后拼接起来。
|
||||
|
||||
公式:
|
||||
对于3D位置 (f, h, w)(帧、高度、宽度):
|
||||
- 帧维度使用 dim_f 个特征维度
|
||||
- 高度维度使用 dim_h 个特征维度
|
||||
- 宽度维度使用 dim_w 个特征维度
|
||||
其中 dim_f + dim_h + dim_w = attention_head_dim
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int,
|
||||
patch_size: Tuple[int, int, int],
|
||||
max_seq_len: int = 1024,
|
||||
theta: float = 10000.0,
|
||||
fhw_dim: Optional[Tuple[int, int, int]] = [20, 22, 22],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_head_dim = attention_head_dim # 注意力头的总维度
|
||||
self.patch_size = patch_size # patch大小 (patch_f, patch_h, patch_w)
|
||||
self.max_seq_len = max_seq_len # 最大序列长度(用于预计算频率)
|
||||
|
||||
# 步骤1:分配维度给三个空间维度
|
||||
if fhw_dim is not None:
|
||||
# 如果指定了维度分配,使用指定的
|
||||
assert attention_head_dim == sum(
|
||||
fhw_dim
|
||||
), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
|
||||
t_dim, h_dim, w_dim = fhw_dim
|
||||
else:
|
||||
# 否则自动分配:h和w各占1/3,t占剩余
|
||||
# 例如:如果attention_head_dim=64,则 h_dim=w_dim=21,t_dim=22
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
# 保存维度分配以便在forward中使用
|
||||
self.fhw_dim = (t_dim, h_dim, w_dim)
|
||||
|
||||
# 步骤2:为每个维度预计算频率
|
||||
# 分别计算时间、高度、宽度三个维度的RoPE频率
|
||||
freqs = []
|
||||
for dim in [t_dim, h_dim, w_dim]:
|
||||
# 每个维度独立调用1D RoPE
|
||||
# 返回复数形式的频率: [max_seq_len, dim//2]
|
||||
freq = get_1d_rotary_pos_embed(
|
||||
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
|
||||
)
|
||||
freqs.append(freq)
|
||||
# 将三个维度的频率在最后一维拼接: [max_seq_len, (t_dim + h_dim + w_dim)//2]
|
||||
self.freqs = torch.cat(freqs, dim=1)
|
||||
|
||||
def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
|
||||
"""
|
||||
前向传播:为3D输入(视频帧+patch)生成旋转位置编码
|
||||
|
||||
参数:
|
||||
- ppf (int): 帧数(patches per frame),当f_end为None时使用
|
||||
- pph (int): 每帧的patch高度数量
|
||||
- ppw (int): 每帧的patch宽度数量
|
||||
- patch_start_idx (int): 每帧的特殊token数量(在patches之前)
|
||||
- device: 计算设备(CPU/GPU)
|
||||
- f_start (int): 起始帧索引(用于causal模式),默认为0
|
||||
- f_end (Optional[int]): 结束帧索引(用于causal模式),如果为None则使用ppf作为帧数
|
||||
|
||||
返回:
|
||||
- freqs: [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim//2] 复数频率tensor
|
||||
|
||||
Token排列顺序:
|
||||
[frame0_special_token_0, ..., frame0_special_token_N,
|
||||
frame0_patch_0, ..., frame0_patch_M,
|
||||
frame1_special_token_0, ..., frame1_special_token_N,
|
||||
frame1_patch_0, ..., frame1_patch_M,
|
||||
...]
|
||||
|
||||
模式:
|
||||
- 非causal模式:f_end=None,使用ppf作为帧数,从位置0开始
|
||||
- Causal模式:f_end不为None,使用[f_start, f_end)范围的帧,ppf会被重新计算
|
||||
"""
|
||||
|
||||
# 步骤1:将预计算的频率移到目标设备,并分割成三个维度
|
||||
self.freqs = self.freqs.to(device)
|
||||
# 获取实际的维度分配
|
||||
if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
|
||||
t_dim, h_dim, w_dim = self.fhw_dim
|
||||
else:
|
||||
# 自动分配的情况
|
||||
h_dim = w_dim = 2 * (self.attention_head_dim // 6)
|
||||
t_dim = self.attention_head_dim - h_dim - w_dim
|
||||
|
||||
# 使用正确的split sizes(每个维度的一半)
|
||||
freqs = self.freqs.split_with_sizes(
|
||||
[
|
||||
t_dim // 2, # 时间维度
|
||||
h_dim // 2, # 高度维度
|
||||
w_dim // 2, # 宽度维度
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# 处理causal模式:如果指定了f_end,重新计算ppf和帧范围
|
||||
if f_end is not None:
|
||||
ppf = f_end - f_start
|
||||
frame_slice = slice(f_start, f_end)
|
||||
else:
|
||||
# 非causal模式:使用从0开始的ppf个帧
|
||||
frame_slice = slice(0, ppf)
|
||||
|
||||
# 步骤2:处理特殊token(如果存在)
|
||||
## For other tokens
|
||||
if patch_start_idx > 0:
|
||||
# 2.1 为特殊token生成位置编码
|
||||
# 特殊token位于对角线位置 (f, i, i),每个特殊token有唯一位置
|
||||
# camera: (f, 0, 0), register_0: (f, 1, 1), ..., scale: (f, 5, 5)
|
||||
# Shape: (ppf, patch_start_idx, dim)
|
||||
freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_f) 帧维度变化
|
||||
freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_h) 高度=0,1,2,...
|
||||
freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_w) 宽度=0,1,2,...
|
||||
freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1) # (ppf, patch_start_idx, dim) 拼接三维
|
||||
freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim)
|
||||
|
||||
# 2.2 为图像patch生成位置编码
|
||||
# Patch位于 (f, patch_start_idx+h, patch_start_idx+w),h,w 整体偏移 patch_start_idx
|
||||
# 这样 patches 与 special tokens 位置不冲突,且 h,w 对称处理
|
||||
# Shape: (ppf, pph, ppw, dim)
|
||||
freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_f) 帧维度
|
||||
freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_h) 高度从patch_start_idx开始
|
||||
freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_w) 宽度从patch_start_idx开始
|
||||
freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1) # (ppf, pph, ppw, dim) 拼接三维
|
||||
freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1) # (ppf, pph * ppw, dim) 展平空间维度
|
||||
|
||||
# 步骤3:按照正确的顺序组合特殊token和patches
|
||||
# 每帧内部顺序:[特殊tokens, patches]
|
||||
# Concatenate special tokens and patches for each frame along the second dimension
|
||||
# Shape: (ppf, patch_start_idx + pph * ppw, dim)
|
||||
freqs = torch.cat([freqs_special, freqs_patches], dim=1) # (ppf, patch_start_idx + pph * ppw, dim)
|
||||
|
||||
# 步骤4:展平为最终形状并添加batch和head维度
|
||||
# Flatten to get final shape: (ppf * (patch_start_idx + pph * ppw), dim)
|
||||
freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
|
||||
freqs = freqs.unsqueeze(0).unsqueeze(0) # (1, 1, ppf * (patch_start_idx + pph * ppw), dim) 添加batch和head维度
|
||||
return freqs
|
||||
|
||||
# 如果没有特殊token(patch_start_idx == 0),只处理图像patches
|
||||
# 所有patches位于 (f, 0:pph, 0:ppw)
|
||||
freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_f) 帧维度
|
||||
freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_h) 高度从0开始
|
||||
freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_w) 宽度从0开始
|
||||
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) # (1, 1, ppf * pph * ppw, dim)
|
||||
return freqs
|
||||
|
||||
def apply_rotary_emb(x, freqs):
|
||||
"""
|
||||
应用旋转位置编码到输入特征
|
||||
|
||||
核心思想:使用复数乘法实现特征旋转,保持相对位置信息
|
||||
|
||||
数学原理:
|
||||
对于2D向量 [x1, x2],旋转θ角度可以表示为复数乘法:
|
||||
(x1 + ix2) * e^(iθ) = (x1 + ix2) * (cos(θ) + i*sin(θ))
|
||||
= (x1*cos(θ) - x2*sin(θ)) + i*(x1*sin(θ) + x2*cos(θ))
|
||||
|
||||
这等价于旋转矩阵:
|
||||
[cos(θ) -sin(θ)] [x1]
|
||||
[sin(θ) cos(θ)] [x2]
|
||||
|
||||
参数:
|
||||
- x: 输入特征 [batch, heads, seq_len, head_dim]
|
||||
- freqs: 旋转频率(复数) [1, 1, seq_len, head_dim//2]
|
||||
|
||||
返回:
|
||||
- x_out: 旋转后的特征 [batch, heads, seq_len, head_dim]
|
||||
|
||||
实现步骤:
|
||||
1. 将x的每两个连续特征看作一个复数 (real, imag)
|
||||
2. 与预计算的复数频率 e^(iθ) 相乘
|
||||
3. 转回实数表示
|
||||
"""
|
||||
# 步骤1:reshape成 [..., head_dim//2, 2] 形式,最后一维表示(real, imag)
|
||||
# 例如:[b, h, seq, 64] -> [b, h, seq, 32, 2]
|
||||
x_reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
|
||||
|
||||
# 步骤2:转换为复数表示 [b, h, seq, 32]
|
||||
# 每个元素是 real + imag*i
|
||||
x_complex = torch.view_as_complex(x_reshaped)
|
||||
|
||||
# 步骤3:复数乘法实现旋转
|
||||
# x_complex * freqs 相当于将每对特征旋转θ角度
|
||||
# freqs已经是 e^(iθ) = cos(θ) + i*sin(θ) 的形式
|
||||
x_rotated = x_complex * freqs
|
||||
|
||||
# 步骤4:转回实数表示 [b, h, seq, 32, 2]
|
||||
x_real = torch.view_as_real(x_rotated)
|
||||
|
||||
# 步骤5:展平最后两维 [b, h, seq, 64]
|
||||
x_out = x_real.flatten(3)
|
||||
|
||||
# 步骤6:转回原始数据类型
|
||||
return x_out.to(x.dtype)
|
||||
Reference in New Issue
Block a user