# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import math
import os
import warnings
from functools import lru_cache, partial
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch import nn, Tensor
from torch.nn.attention.flex_attention import BlockMask, create_mask

from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
                                      num_patches=num_patches, num_special=num_special, num_frames=num_frames,
                                      enable_3d_rope=enable_3d_rope))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
            x = drop_add_residual_stochastic_depth(
                x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x


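# Illustrative sketch (not part of the original module): constructing and calling
# a Block with the default Attention class. This assumes the default attention
# path works without RoPE (rope=None) and with pos=None, which is exactly how
# Block.forward invokes it above.
def _example_block_forward() -> Tensor:
    block = Block(dim=64, num_heads=4)
    x = torch.randn(2, 10, 64)  # (batch, tokens, dim)
    return block(x)  # residual attention branch + FFN branch

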
def drop_add_residual_stochastic_depth(
    x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
) -> Tensor:
    # 1) extract a random subset of the batch using a permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get the residual
    if pos is not None:
        # if necessary, apply rope to the subset
        pos = pos[brange]
        residual = residual_func(x_subset, pos=pos)
    else:
        residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual back, rescaled so the update matches full drop-path in expectation
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


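# Illustrative sketch (not part of the original module): a minimal, hypothetical
# use of drop_add_residual_stochastic_depth. The toy residual_func stands in for
# the attention/FFN branches defined in Block.forward; it shows how surviving
# samples' residuals are rescaled by b / sample_subset_size.
def _example_drop_add_residual_stochastic_depth() -> Tensor:
    x = torch.randn(8, 16, 32)  # (batch, tokens, dim)

    def toy_residual_func(t: Tensor) -> Tensor:
        return 0.1 * t  # hypothetical residual branch

    # With sample_drop_ratio=0.5, only 4 of the 8 samples run the branch;
    # their residuals are added back with alpha = 8 / 4 = 2.
    return drop_add_residual_stochastic_depth(x, toy_residual_func, sample_drop_ratio=0.5)

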
def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        # NOTE: scaled_index_add is not defined or imported in this module; in the
        # reference DINOv2 implementation it comes from xformers.ops, so this branch
        # requires xformers to be available.
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


class FlashInferBlock(nn.Module):
    """
    FlashInfer variant of causal block for GCT.
    Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels).
    Supports optimized token layout and KV cache streaming inference.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = FlashInferAttention(
            dim=dim,
            num_heads=num_heads,
            qk_norm=qk_norm,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
        """Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit.

        Extracted as a named method so torch.compile can capture norm1 + qkv-linear +
        reshape + q_norm + k_norm + RoPE + format as a single CUDA graph.

        Returns:
            (q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim],
            ready for manager.append_frame + manager.compute_attention.
        """
        return self.attn.prepare_qkv(self.norm1(x), pos=pos, enable_3d_rope=enable_3d_rope)

    def forward(
        self,
        x: Tensor,
        pos=None,
        enable_ulysses_cp=False,
        num_patches=None,
        num_special=None,
        num_frames=None,
        enable_3d_rope=False,
        kv_cache=None,
        global_idx=0,
        num_frame_per_block=1,
        num_frame_for_scale=-1,
        num_register_tokens=4,
    ) -> Tensor:
        # Phase 2 (streaming): single-frame FlashInfer paged attention.
        # Handle inline so attn_pre (norm1 + prepare_qkv) can be compiled as one CUDA graph.
        is_streaming = (kv_cache is not None and (num_frames is None or num_frames <= 1))
        if is_streaming:
            manager = kv_cache
            # Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format
            q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope)
            # Eager: write frame K/V to paged cache
            manager.append_frame(global_idx, k_nhd, v_nhd)
            # CPU-only: update eviction state (deque ops, no GPU kernel)
            manager.evict_frames(
                block_idx=global_idx,
                scale_frames=self.attn.kv_cache_scale_frames,
                sliding_window=self.attn.kv_cache_sliding_window,
                cross_frame_special=self.attn.kv_cache_cross_frame_special,
                include_scale_frames=self.attn.kv_cache_include_scale_frames,
                camera_only=self.attn.kv_cache_camera_only,
                num_register_tokens=num_register_tokens,
            )
            # Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper
            attn_x = manager.compute_attention(global_idx, q_nhd)
            # [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output)
            attn_x = attn_x.reshape(x.shape[0], q_nhd.shape[0],
                                    self.attn.num_heads * self.attn.head_dim)
            # Compiled: output projection
            attn_x = self.attn.proj(attn_x)
            x = x + self.ls1(attn_x)
        else:
            # Phase 1 (multi-frame scale pass) or non-streaming training path
            x = x + self.ls1(self.attn(
                self.norm1(x),
                pos=pos,
                enable_ulysses_cp=enable_ulysses_cp,
                num_patches=num_patches,
                num_special=num_special,
                num_frames=num_frames,
                enable_3d_rope=enable_3d_rope,
                kv_cache=kv_cache,
                global_idx=global_idx,
                num_frame_per_block=num_frame_per_block,
                num_frame_for_scale=num_frame_for_scale,
                num_register_tokens=num_register_tokens,
            ))
        x = self.ffn_residual(x)
        return x

    def ffn_residual(self, x: Tensor) -> Tensor:
        """FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in.

        Includes the residual add (x + ...) so torch.compile captures the entire
        ffn branch as one CUDA graph.
        """
        return x + self.ls2(self.mlp(self.norm2(x)))


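# Illustrative sketch (not part of the original module): the intended per-frame
# streaming call pattern for FlashInferBlock. The `kv_cache` manager (providing
# append_frame / evict_frames / compute_attention) is constructed outside this
# module, so it is taken as a given here; all names below are assumptions.
# attn_pre and ffn_residual are written as single units so they can be wrapped
# with torch.compile (e.g. block.attn_pre = torch.compile(block.attn_pre)),
# which is an assumed usage rather than something this module sets up itself.
def _example_flashinfer_streaming(block: "FlashInferBlock", frames: List[Tensor], kv_cache) -> List[Tensor]:
    outputs = []
    for global_idx, frame_tokens in enumerate(frames):
        # A kv_cache together with num_frames <= 1 selects the Phase-2 streaming path.
        outputs.append(block(frame_tokens, kv_cache=kv_cache, global_idx=global_idx, num_frames=1))
    return outputs

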
class CameraBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = CausalAttention(dim=dim, num_heads=num_heads,
                                    qk_norm=qk_norm, qkv_bias=qkv_bias,
                                    rope=rope, elementwise_attn_output_gate=elementwise_attn_output_gate,
                                    kv_cache_sliding_window=kv_cache_sliding_window,
                                    kv_cache_scale_frames=kv_cache_scale_frames,
                                    kv_cache_cross_frame_special=kv_cache_cross_frame_special,
                                    kv_cache_include_scale_frames=kv_cache_include_scale_frames,
                                    kv_cache_camera_only=kv_cache_camera_only)

        self.sliding_window_size = sliding_window_size
        self.attend_to_scale_frames = attend_to_scale_frames
        self.num_random_frames = num_random_frames

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path
        self.masks = {}

    @torch.no_grad()
    def _prepare_blockwise_causal_attn_mask(self,
                                            device: torch.device | str, num_frames: int = 21,
                                            frame_seqlen: int = 1560, num_frame_per_block=1
                                            ) -> Tensor:
        """
        We divide the token sequence into the following format:
        [1 latent frame] [1 latent frame] ... [1 latent frame]
        and use flex attention's mask utilities to construct the attention mask.
        """
        total_length = num_frames * frame_seqlen

        # amount of right padding needed to reach a multiple of 128
        padded_length = math.ceil(total_length / 128) * 128 - total_length

        ends = torch.zeros(total_length + padded_length,
                           device=device, dtype=torch.long)

        # Block-wise causal mask: each query attends to all tokens up to the end of its own chunk
        frame_indices = torch.arange(
            start=0,
            end=total_length,
            step=frame_seqlen * num_frame_per_block,
            device=device
        )

        for tmp in frame_indices:
            ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
                frame_seqlen * num_frame_per_block

        def attention_mask(b, h, q_idx, kv_idx):
            return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
            # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx)  # bidirectional mask

        block_mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
                                 KV_LEN=total_length + padded_length, device=device)

        return block_mask

    def forward(
        self,
        x: Tensor,
        pos=None,
        video_mask=None,
        num_frames=0,
        frame_seqlen=0,
        kv_cache=None,
        current_start=0,
        current_end=0,
        global_idx=0,
        num_frame_per_block=8,
        num_frame_for_scale=-1,
        sliding_window_size=None,
        enable_ulysses_cp=False,
        full_attention=False,
        enable_3d_rope=False,
        is_scale_frames=False,
    ) -> Tensor:
        # Use the passed sliding_window_size if provided, otherwise fall back to self.sliding_window_size
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size

        # Fast path for full attention (camera head) - skip mask computation
        if full_attention:
            def attn_residual_func(x: Tensor, pos=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), pos=pos, full_attention=True,
                                          enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope))

            def ffn_residual_func(x: Tensor) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            if self.training and self.sample_drop_ratio > 0.0:
                x = x + self.drop_path1(attn_residual_func(x, pos=pos))
                x = x + self.drop_path2(ffn_residual_func(x))
            else:
                x = x + attn_residual_func(x, pos=pos)
                x = x + ffn_residual_func(x)
            return x

        mask_block = self._prepare_blockwise_causal_attn_mask(
            device=x.device, num_frames=num_frames, frame_seqlen=frame_seqlen, num_frame_per_block=num_frame_per_block)

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(
                self.norm1(x), pos=pos, block_mask=mask_block, frame_seqlen=frame_seqlen,
                video_mask=video_mask, current_start=current_start, current_end=current_end,
                kv_cache=kv_cache, global_idx=global_idx, num_frame_per_block=num_frame_per_block,
                num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size,
                attend_to_scale_frames=self.attend_to_scale_frames, num_random_frames=self.num_random_frames,
                enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope, is_scale_frames=is_scale_frames))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x


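# Illustrative sketch (not part of the original module): the block-wise causal
# rule used by CameraBlock._prepare_blockwise_causal_attn_mask, written out
# densely for a toy case and without the 128-token right padding. Tokens in
# frame-block i may attend to every token in blocks 0..i; the q_idx == kv_idx
# clause additionally keeps the diagonal for padded positions.
def _example_blockwise_causal_rule(num_frames: int = 4, frame_seqlen: int = 2,
                                   num_frame_per_block: int = 2) -> Tensor:
    total = num_frames * frame_seqlen
    block_len = frame_seqlen * num_frame_per_block
    # ends[i] = index one past the last token of the frame-block containing token i
    ends = ((torch.arange(total) // block_len) + 1) * block_len
    q_idx = torch.arange(total).unsqueeze(1)   # (total, 1)
    kv_idx = torch.arange(total).unsqueeze(0)  # (1, total)
    return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)  # boolean (total, total) mask

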
class SDPABlock(nn.Module):
    """
    SDPA variant for streaming inference. Uses F.scaled_dot_product_attention
    with dict-based KV cache. No FlashInfer dependency required.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SDPAAttention(
            dim=dim, num_heads=num_heads, qk_norm=qk_norm, qkv_bias=qkv_bias,
            proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = ffn_layer(in_features=dim, hidden_features=int(dim * mlp_ratio),
                             act_layer=act_layer, drop=drop, bias=ffn_bias)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
                kv_cache=None, global_idx=0, num_frame_per_block=1,
                num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
        def attn_residual_func(x, pos=None):
            return self.ls1(self.attn(
                self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
                num_patches=num_patches, num_special=num_special, num_frames=num_frames,
                enable_3d_rope=enable_3d_rope, kv_cache=kv_cache, global_idx=global_idx,
                num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale,
                num_register_tokens=num_register_tokens,
            ))

        def ffn_residual_func(x):
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
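

# Illustrative sketch (not part of the original module): streaming SDPABlock over
# a sequence of single-frame token tensors. The assumption here is that
# SDPAAttention treats `kv_cache` as a plain dict it populates and reads per
# frame; the exact cache layout is defined in .attention, not in this sketch.
def _example_sdpa_streaming(block: "SDPABlock", frames: List[Tensor]) -> List[Tensor]:
    kv_cache: Dict[str, Any] = {}
    outputs = []
    for global_idx, frame_tokens in enumerate(frames):
        outputs.append(block(frame_tokens, kv_cache=kv_cache, global_idx=global_idx))
    return outputs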