# Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the Apache License, Version 2.0 # found in the LICENSE file in the root directory of this source tree. # Implementation of 2D Rotary Position Embeddings (RoPE). # This module provides a clean implementation of 2D Rotary Position Embeddings, # which extends the original RoPE concept to handle 2D spatial positions. # Inspired by: # https://github.com/meta-llama/codellama/blob/main/llama/model.py # https://github.com/naver-ai/rope-vit import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, Tuple from typing import List, Optional, Tuple, Union class PositionGetter: """Generates and caches 2D spatial positions for patches in a grid. This class efficiently manages the generation of spatial coordinates for patches in a 2D grid, caching results to avoid redundant computations. Attributes: position_cache: Dictionary storing precomputed position tensors for different grid dimensions. """ def __init__(self): """Initializes the position generator with an empty cache.""" self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: """Generates spatial positions for a batch of patches. Args: batch_size: Number of samples in the batch. height: Height of the grid in patches. width: Width of the grid in patches. device: Target device for the position tensor. Returns: Tensor of shape (batch_size, height*width, 2) containing y,x coordinates for each position in the grid, repeated for each batch item. """ if (height, width) not in self.position_cache: y_coords = torch.arange(height, device=device) x_coords = torch.arange(width, device=device) positions = torch.cartesian_prod(y_coords, x_coords) self.position_cache[height, width] = positions cached_positions = self.position_cache[height, width] return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() class RotaryPositionEmbedding2D(nn.Module): """2D Rotary Position Embedding implementation. This module applies rotary position embeddings to input tokens based on their 2D spatial positions. It handles the position-dependent rotation of features separately for vertical and horizontal dimensions. Args: frequency: Base frequency for the position embeddings. Default: 100.0 scaling_factor: Scaling factor for frequency computation. Default: 1.0 Attributes: base_frequency: Base frequency for computing position embeddings. scaling_factor: Factor to scale the computed frequencies. frequency_cache: Cache for storing precomputed frequency components. """ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): """Initializes the 2D RoPE module.""" super().__init__() self.base_frequency = frequency self.scaling_factor = scaling_factor self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} def _compute_frequency_components( self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor]: """Computes frequency components for rotary embeddings. Args: dim: Feature dimension (must be even). seq_len: Maximum sequence length. device: Target device for computations. dtype: Data type for the computed tensors. Returns: Tuple of (cosine, sine) tensors for frequency components. """ cache_key = (dim, seq_len, device, dtype) if cache_key not in self.frequency_cache: # Compute frequency bands exponents = torch.arange(0, dim, 2, device=device).float() / dim inv_freq = 1.0 / (self.base_frequency**exponents) # Generate position-dependent frequencies positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) angles = torch.einsum("i,j->ij", positions, inv_freq) # Compute and cache frequency components angles = angles.to(dtype) angles = torch.cat((angles, angles), dim=-1) cos_components = angles.cos().to(dtype) sin_components = angles.sin().to(dtype) self.frequency_cache[cache_key] = (cos_components, sin_components) return self.frequency_cache[cache_key] @staticmethod def _rotate_features(x: torch.Tensor) -> torch.Tensor: """Performs feature rotation by splitting and recombining feature dimensions. Args: x: Input tensor to rotate. Returns: Rotated feature tensor. """ feature_dim = x.shape[-1] x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] return torch.cat((-x2, x1), dim=-1) def _apply_1d_rope( self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor ) -> torch.Tensor: """Applies 1D rotary position embeddings along one dimension. Args: tokens: Input token features. positions: Position indices. cos_comp: Cosine components for rotation. sin_comp: Sine components for rotation. Returns: Tokens with applied rotary position embeddings. """ # Embed positions with frequency components cos = F.embedding(positions, cos_comp)[:, None, :, :] sin = F.embedding(positions, sin_comp)[:, None, :, :] # Apply rotation return (tokens * cos) + (self._rotate_features(tokens) * sin) def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: """Applies 2D rotary position embeddings to input tokens. Args: tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). The feature dimension (dim) must be divisible by 4. positions: Position tensor of shape (batch_size, n_tokens, 2) containing the y and x coordinates for each token. Returns: Tensor of same shape as input with applied 2D rotary position embeddings. Raises: AssertionError: If input dimensions are invalid or positions are malformed. """ # Validate inputs assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" # Compute feature dimension for each spatial direction feature_dim = tokens.size(-1) // 2 # Get frequency components max_position = int(positions.max()) + 1 cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) # Split features for vertical and horizontal processing vertical_features, horizontal_features = tokens.chunk(2, dim=-1) # Apply RoPE separately for each dimension vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) # Combine processed features return torch.cat((vertical_features, horizontal_features), dim=-1) def get_1d_rotary_pos_embed( dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) ): """ 计算1D旋转位置编码(RoPE)的频率张量。 RoPE的核心思想:使用旋转矩阵来编码位置信息,使得相对位置关系保持不变。 公式:对于位置m和维度i,频率为 θ_i = θ^(-2i/d),其中θ是基础频率(默认10000) Args: dim: 特征维度,必须是偶数(因为要成对处理) pos: 位置索引,可以是整数(自动生成0到pos-1的序列)或位置数组 [S] theta: 基础频率,控制位置编码的周期性(默认10000) use_real: 是否返回实数形式(cos和sin分开)还是复数形式 linear_factor: 线性缩放因子,用于上下文扩展 ntk_factor: NTK-Aware缩放因子,用于处理更长的序列 repeat_interleave_real: 当use_real=True时,是否交错重复(用于某些模型架构) freqs_dtype: 频率张量的数据类型 Returns: 复数形式:[S, D/2] 的复数张量,表示 e^(i*m*θ_j) 实数形式:两个 [S, D] 的张量(cos和sin) """ # 确保维度是偶数(RoPE需要成对处理维度) assert dim % 2 == 0 # 将位置转换为torch张量 if isinstance(pos, int): pos = torch.arange(pos) # 生成 [0, 1, 2, ..., pos-1] if isinstance(pos, np.ndarray): pos = torch.from_numpy(pos) # [S] # 应用NTK缩放(Neural Tangent Kernel,用于处理训练时未见过的长序列) theta = theta * ntk_factor # 步骤1:计算频率 θ_i = 1 / (θ^(2i/d)) # 其中 i ∈ {0, 2, 4, ..., dim-2}(只取偶数索引,因为成对处理) # 公式:freq_i = 1 / (theta^(2i/d) * linear_factor) freqs = ( 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) / linear_factor ) # [D/2],每个频率对应一个维度对 # 步骤2:计算位置-频率矩阵 # 使用外积:pos[m] * freqs[i] = m * θ_i # 结果:每个位置m和每个频率i的组合 freqs = torch.outer(pos, freqs) # [S, D/2] # 步骤3:根据返回格式转换 if use_real and repeat_interleave_real: # 方式1:交错重复(用于flux, hunyuan-dit, cogvideox等模型) # 将每个频率的cos和sin交错排列:[cos_0, cos_0, cos_1, cos_1, ...] freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D] return freqs_cos, freqs_sin elif use_real: # 方式2:拼接重复(用于stable audio, allegro等模型) # 将所有cos拼接,然后是所有sin:[cos_0, cos_1, ..., cos_n, cos_0, cos_1, ..., cos_n] freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin else: # 方式3:复数形式(用于lumina等模型) # 使用欧拉公式:e^(iθ) = cos(θ) + i*sin(θ) # torch.polar(r, θ) 返回 r * e^(iθ),这里r=1,所以就是 e^(i*freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64: [S, D/2] return freqs_cis class WanRotaryPosEmbed(nn.Module): """ 3D旋转位置编码(3D RoPE)模块 核心思想:将RoPE扩展到3D空间(时间、高度、宽度),为视频或3D数据提供位置编码。 每个维度(t, h, w)独立使用RoPE,然后拼接起来。 公式: 对于3D位置 (f, h, w)(帧、高度、宽度): - 帧维度使用 dim_f 个特征维度 - 高度维度使用 dim_h 个特征维度 - 宽度维度使用 dim_w 个特征维度 其中 dim_f + dim_h + dim_w = attention_head_dim """ def __init__( self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int = 1024, theta: float = 10000.0, fhw_dim: Optional[Tuple[int, int, int]] = [20, 22, 22], ): super().__init__() self.attention_head_dim = attention_head_dim # 注意力头的总维度 self.patch_size = patch_size # patch大小 (patch_f, patch_h, patch_w) self.max_seq_len = max_seq_len # 最大序列长度(用于预计算频率) # 步骤1:分配维度给三个空间维度 if fhw_dim is not None: # 如果指定了维度分配,使用指定的 assert attention_head_dim == sum( fhw_dim ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}" t_dim, h_dim, w_dim = fhw_dim else: # 否则自动分配:h和w各占1/3,t占剩余 # 例如:如果attention_head_dim=64,则 h_dim=w_dim=21,t_dim=22 h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim # 保存维度分配以便在forward中使用 self.fhw_dim = (t_dim, h_dim, w_dim) # 步骤2:为每个维度预计算频率 # 分别计算时间、高度、宽度三个维度的RoPE频率 freqs = [] for dim in [t_dim, h_dim, w_dim]: # 每个维度独立调用1D RoPE # 返回复数形式的频率: [max_seq_len, dim//2] freq = get_1d_rotary_pos_embed( dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64 ) freqs.append(freq) # 将三个维度的频率在最后一维拼接: [max_seq_len, (t_dim + h_dim + w_dim)//2] self.freqs = torch.cat(freqs, dim=1) def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor: """ 前向传播:为3D输入(视频帧+patch)生成旋转位置编码 参数: - ppf (int): 帧数(patches per frame),当f_end为None时使用 - pph (int): 每帧的patch高度数量 - ppw (int): 每帧的patch宽度数量 - patch_start_idx (int): 每帧的特殊token数量(在patches之前) - device: 计算设备(CPU/GPU) - f_start (int): 起始帧索引(用于causal模式),默认为0 - f_end (Optional[int]): 结束帧索引(用于causal模式),如果为None则使用ppf作为帧数 返回: - freqs: [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim//2] 复数频率tensor Token排列顺序: [frame0_special_token_0, ..., frame0_special_token_N, frame0_patch_0, ..., frame0_patch_M, frame1_special_token_0, ..., frame1_special_token_N, frame1_patch_0, ..., frame1_patch_M, ...] 模式: - 非causal模式:f_end=None,使用ppf作为帧数,从位置0开始 - Causal模式:f_end不为None,使用[f_start, f_end)范围的帧,ppf会被重新计算 """ # 步骤1:将预计算的频率移到目标设备,并分割成三个维度 self.freqs = self.freqs.to(device) # 获取实际的维度分配 if hasattr(self, 'fhw_dim') and self.fhw_dim is not None: t_dim, h_dim, w_dim = self.fhw_dim else: # 自动分配的情况 h_dim = w_dim = 2 * (self.attention_head_dim // 6) t_dim = self.attention_head_dim - h_dim - w_dim # 使用正确的split sizes(每个维度的一半) freqs = self.freqs.split_with_sizes( [ t_dim // 2, # 时间维度 h_dim // 2, # 高度维度 w_dim // 2, # 宽度维度 ], dim=1, ) # 处理causal模式:如果指定了f_end,重新计算ppf和帧范围 if f_end is not None: ppf = f_end - f_start frame_slice = slice(f_start, f_end) else: # 非causal模式:使用从0开始的ppf个帧 frame_slice = slice(0, ppf) # 步骤2:处理特殊token(如果存在) ## For other tokens if patch_start_idx > 0: # 2.1 为特殊token生成位置编码 # 特殊token位于对角线位置 (f, i, i),每个特殊token有唯一位置 # camera: (f, 0, 0), register_0: (f, 1, 1), ..., scale: (f, 5, 5) # Shape: (ppf, patch_start_idx, dim) freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_f) 帧维度变化 freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_h) 高度=0,1,2,... freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim_w) 宽度=0,1,2,... freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1) # (ppf, patch_start_idx, dim) 拼接三维 freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1) # (ppf, patch_start_idx, dim) # 2.2 为图像patch生成位置编码 # Patch位于 (f, patch_start_idx+h, patch_start_idx+w),h,w 整体偏移 patch_start_idx # 这样 patches 与 special tokens 位置不冲突,且 h,w 对称处理 # Shape: (ppf, pph, ppw, dim) freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_f) 帧维度 freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_h) 高度从patch_start_idx开始 freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_w) 宽度从patch_start_idx开始 freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1) # (ppf, pph, ppw, dim) 拼接三维 freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1) # (ppf, pph * ppw, dim) 展平空间维度 # 步骤3:按照正确的顺序组合特殊token和patches # 每帧内部顺序:[特殊tokens, patches] # Concatenate special tokens and patches for each frame along the second dimension # Shape: (ppf, patch_start_idx + pph * ppw, dim) freqs = torch.cat([freqs_special, freqs_patches], dim=1) # (ppf, patch_start_idx + pph * ppw, dim) # 步骤4:展平为最终形状并添加batch和head维度 # Flatten to get final shape: (ppf * (patch_start_idx + pph * ppw), dim) freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1) freqs = freqs.unsqueeze(0).unsqueeze(0) # (1, 1, ppf * (patch_start_idx + pph * ppw), dim) 添加batch和head维度 return freqs # 如果没有特殊token(patch_start_idx == 0),只处理图像patches # 所有patches位于 (f, 0:pph, 0:ppw) freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_f) 帧维度 freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_h) 高度从0开始 freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) # (ppf, pph, ppw, dim_w) 宽度从0开始 freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) # (1, 1, ppf * pph * ppw, dim) return freqs def apply_rotary_emb(x, freqs): """ 应用旋转位置编码到输入特征 核心思想:使用复数乘法实现特征旋转,保持相对位置信息 数学原理: 对于2D向量 [x1, x2],旋转θ角度可以表示为复数乘法: (x1 + ix2) * e^(iθ) = (x1 + ix2) * (cos(θ) + i*sin(θ)) = (x1*cos(θ) - x2*sin(θ)) + i*(x1*sin(θ) + x2*cos(θ)) 这等价于旋转矩阵: [cos(θ) -sin(θ)] [x1] [sin(θ) cos(θ)] [x2] 参数: - x: 输入特征 [batch, heads, seq_len, head_dim] - freqs: 旋转频率(复数) [1, 1, seq_len, head_dim//2] 返回: - x_out: 旋转后的特征 [batch, heads, seq_len, head_dim] 实现步骤: 1. 将x的每两个连续特征看作一个复数 (real, imag) 2. 与预计算的复数频率 e^(iθ) 相乘 3. 转回实数表示 """ # 步骤1:reshape成 [..., head_dim//2, 2] 形式,最后一维表示(real, imag) # 例如:[b, h, seq, 64] -> [b, h, seq, 32, 2] x_reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2) # 步骤2:转换为复数表示 [b, h, seq, 32] # 每个元素是 real + imag*i x_complex = torch.view_as_complex(x_reshaped) # 步骤3:复数乘法实现旋转 # x_complex * freqs 相当于将每对特征旋转θ角度 # freqs已经是 e^(iθ) = cos(θ) + i*sin(θ) 的形式 x_rotated = x_complex * freqs # 步骤4:转回实数表示 [b, h, seq, 32, 2] x_real = torch.view_as_real(x_rotated) # 步骤5:展平最后两维 [b, h, seq, 64] x_out = x_real.flatten(3) # 步骤6:转回原始数据类型 return x_out.to(x.dtype)