first commit
This commit is contained in:
514
lingbot_map/layers/block.py
Normal file
514
lingbot_map/layers/block.py
Normal file
@@ -0,0 +1,514 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, List, Any, Tuple, Dict
|
||||
import warnings
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention
|
||||
from functools import lru_cache, partial
|
||||
from torch.nn.attention.flex_attention import BlockMask, create_mask
|
||||
from .drop_path import DropPath
|
||||
from .layer_scale import LayerScale
|
||||
from .mlp import Mlp
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Transformer block with a normalised attention branch and a normalised MLP branch.

    Each branch is computed as ``x + drop_path(layer_scale(fn(norm(x))))``.
    ``LayerScale`` is only instantiated when ``init_values`` is truthy and
    ``DropPath`` only when ``drop_path > 0``; otherwise ``nn.Identity`` is used.

    Args:
        dim: Channel width handed to the norm, attention and MLP layers.
        num_heads: Number of attention heads (forwarded to ``attn_class``).
        mlp_ratio: MLP hidden width is ``int(dim * mlp_ratio)``.
        qkv_bias, proj_bias, attn_drop, qk_norm, fused_attn, rope: Forwarded
            unchanged to ``attn_class``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Used both as ``proj_drop`` of the attention and ``drop`` of the MLP.
        init_values: Forwarded to ``LayerScale``; falsy disables layer scale.
        drop_path: Stochastic-depth rate; also stored as ``sample_drop_ratio``.
        act_layer, norm_layer, attn_class, ffn_layer: Factories for sub-modules.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(
        self,
        x: Tensor,
        pos=None,
        enable_ulysses_cp=False,
        num_patches=None,
        num_special=None,
        num_frames=None,
        enable_3d_rope=False,
    ) -> Tensor:
        """Apply the attention branch followed by the MLP branch.

        ``pos`` and the remaining keyword arguments are forwarded unchanged to
        ``self.attn``. Three regimes are used:

        * training with ``sample_drop_ratio > 0.1``: batch-subset stochastic
          depth via ``drop_add_residual_stochastic_depth``;
        * training with ``0 < sample_drop_ratio <= 0.1``: per-branch ``DropPath``;
        * otherwise: plain residual connections.
        """

        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(
                self.attn(
                    self.norm1(x),
                    pos=pos,
                    enable_ulysses_cp=enable_ulysses_cp,
                    num_patches=num_patches,
                    num_special=num_special,
                    num_frames=num_frames,
                    enable_3d_rope=enable_3d_rope,
                )
            )

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            # The FFN branch uses its own DropPath module (was drop_path1, flagged FIXME).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth(
    x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
) -> Tensor:
    """Add a residual computed on a random subset of the batch (stochastic depth).

    A random subset of ``max(int(b * (1 - sample_drop_ratio)), 1)`` batch rows
    is selected, ``residual_func`` is evaluated on those rows only, and the
    result is added back to the selected rows scaled by ``b / subset_size``.
    Rows that were not selected are returned unchanged.

    Args:
        x: 3-D input tensor; the first dimension is treated as the batch.
        residual_func: Callable producing the residual for a batch subset. When
            ``pos`` is given it is called as ``residual_func(x_subset, pos=...)``.
        sample_drop_ratio: Fraction of batch rows to skip.
        pos: Optional tensor indexed along its first dimension with the same
            row selection as ``x`` (assumes it is batch-first - TODO confirm
            against callers).

    Returns:
        A new tensor with the same shape as ``x``.
    """
    # 1) extract subset using permutation (the unpack also enforces a 3-D input)
    batch_size, _, _ = x.shape
    subset_size = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch_size, device=x.device)[:subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    if pos is not None:
        # keep positions aligned with the sampled rows
        residual = residual_func(x_subset, pos=pos[brange])
    else:
        residual = residual_func(x_subset)

    # Rescale so the expected magnitude of the update matches the full-batch one.
    residual_scale_factor = batch_size / subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(
        x.flatten(1),
        0,
        brange,
        residual.flatten(1).to(dtype=x.dtype),
        alpha=residual_scale_factor,
    )
    return x_plus_residual.view_as(x)
|
||||
|
||||
|
||||
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Sample the batch rows kept for stochastic depth and the matching rescale factor.

    Args:
        x: 3-D tensor whose first dimension is the batch.
        sample_drop_ratio: Fraction of batch rows to skip.

    Returns:
        ``(brange, residual_scale_factor)`` where ``brange`` holds
        ``max(int(b * (1 - sample_drop_ratio)), 1)`` distinct random row indices
        on ``x.device`` and ``residual_scale_factor`` is ``b / len(brange)``.
    """
    # The three-way unpack keeps the requirement that x is 3-D.
    batch_size, _, _ = x.shape
    subset_size = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch_size, device=x.device)[:subset_size]
    return brange, batch_size / subset_size
|
||||
|
||||
|
||||
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` to the rows of ``x`` selected by ``brange``.

    Args:
        x: Input tensor; rows are indexed along dimension 0.
        brange: 1-D index tensor naming the rows of ``x`` that ``residual`` belongs to.
        residual: Residual for the selected rows; cast to ``x.dtype`` before use.
        residual_scale_factor: Scalar multiplier applied to the residual.
        scaling_vector: Optional tensor multiplied into the residual by
            broadcasting (e.g. a per-channel layer-scale weight).

    Returns:
        A new tensor (``x`` is not modified). Without ``scaling_vector`` the
        result is flattened from dimension 1 onward, as in the original code;
        with ``scaling_vector`` it keeps the shape of ``x``.

    Note:
        The scaled branch previously called ``scaled_index_add``, a name that is
        neither imported nor defined in this module, so it raised ``NameError``.
        It is now computed with ``torch.index_add`` as
        ``x[brange] + residual_scale_factor * (scaling_vector * residual)``.
    """
    residual = residual.to(dtype=x.dtype)
    if scaling_vector is None:
        return torch.index_add(
            x.flatten(1), 0, brange, residual.flatten(1), alpha=residual_scale_factor
        )

    # Cast after the multiply: index_add requires source and input dtypes to match,
    # and scaling_vector may be stored in a different precision than x.
    scaled_residual = (residual * scaling_vector).to(dtype=x.dtype)
    return torch.index_add(x, 0, brange, scaled_residual, alpha=residual_scale_factor)
|
||||
|
||||
|
||||
class FlashInferBlock(nn.Module):
    """
    FlashInfer variant of causal block for GCT.

    Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels).
    Supports optimized token layout and KV cache streaming inference.

    NOTE(review): ``drop_path1`` / ``drop_path2`` are constructed here but
    ``forward`` never applies them, so ``drop_path`` has no effect in this block.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        """Build the attention and MLP branches.

        All ``kv_cache_*`` options are forwarded to ``FlashInferAttention`` and
        later read back from ``self.attn`` when evicting frames in ``forward``.
        """
        super().__init__()

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = FlashInferAttention(
            dim=dim,
            num_heads=num_heads,
            qk_norm=qk_norm,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )
        if init_values:
            self.ls1 = LayerScale(dim, init_values=init_values)
        else:
            self.ls1 = nn.Identity()
        if drop_path > 0.0:
            self.drop_path1 = DropPath(drop_path)
        else:
            self.drop_path1 = nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        if init_values:
            self.ls2 = LayerScale(dim, init_values=init_values)
        else:
            self.ls2 = nn.Identity()
        if drop_path > 0.0:
            self.drop_path2 = DropPath(drop_path)
        else:
            self.drop_path2 = nn.Identity()

        self.sample_drop_ratio = drop_path

    def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
        """Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit.

        Extracted as a named method so torch.compile can capture norm1 + qkv-linear +
        reshape + q_norm + k_norm + RoPE + format as a single CUDA graph.

        Returns:
            (q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim],
            ready for manager.append_frame + manager.compute_attention.
        """
        normed = self.norm1(x)
        return self.attn.prepare_qkv(normed, pos=pos, enable_3d_rope=enable_3d_rope)

    def forward(
        self,
        x: Tensor,
        pos=None,
        enable_ulysses_cp=False,
        num_patches=None,
        num_special=None,
        num_frames=None,
        enable_3d_rope=False,
        kv_cache=None,
        global_idx=0,
        num_frame_per_block=1,
        num_frame_for_scale=-1,
        num_register_tokens=4,
    ) -> Tensor:
        """Run attention (streaming or regular) followed by the FFN residual.

        The streaming path is taken when a ``kv_cache`` is supplied and
        ``num_frames`` is ``None`` or at most 1; ``kv_cache`` is then used as the
        cache manager (``append_frame`` / ``evict_frames`` / ``compute_attention``).
        Every other call is delegated to ``self.attn`` with all arguments
        forwarded.
        """
        single_frame = num_frames is None or num_frames <= 1

        if kv_cache is None or not single_frame:
            # Phase 1 (multi-frame scale pass) or non-streaming training path
            attn_out = self.attn(
                self.norm1(x),
                pos=pos,
                enable_ulysses_cp=enable_ulysses_cp,
                num_patches=num_patches,
                num_special=num_special,
                num_frames=num_frames,
                enable_3d_rope=enable_3d_rope,
                kv_cache=kv_cache,
                global_idx=global_idx,
                num_frame_per_block=num_frame_per_block,
                num_frame_for_scale=num_frame_for_scale,
                num_register_tokens=num_register_tokens,
            )
        else:
            # Phase 2 (streaming): single-frame FlashInfer paged attention.
            # Handled inline so attn_pre (norm1+prepare_qkv) can be compiled as one CUDA graph.
            manager = kv_cache
            attn_module = self.attn

            # Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format
            q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope)

            # Eager: write frame K/V to paged cache
            manager.append_frame(global_idx, k_nhd, v_nhd)

            # CPU-only: update eviction state (deque ops, no GPU kernel)
            manager.evict_frames(
                block_idx=global_idx,
                scale_frames=attn_module.kv_cache_scale_frames,
                sliding_window=attn_module.kv_cache_sliding_window,
                cross_frame_special=attn_module.kv_cache_cross_frame_special,
                include_scale_frames=attn_module.kv_cache_include_scale_frames,
                camera_only=attn_module.kv_cache_camera_only,
                num_register_tokens=num_register_tokens,
            )

            # Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper
            attended = manager.compute_attention(global_idx, q_nhd)

            # [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output)
            channels = attn_module.num_heads * attn_module.head_dim
            attended = attended.reshape(x.shape[0], q_nhd.shape[0], channels)

            # Compiled: output projection
            attn_out = attn_module.proj(attended)

        x = x + self.ls1(attn_out)
        return self.ffn_residual(x)

    def ffn_residual(self, x: Tensor) -> Tensor:
        """FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in.

        Includes the residual add (x + ...) so torch.compile captures the entire
        ffn branch as one CUDA graph.
        """
        update = self.ls2(self.mlp(self.norm2(x)))
        return x + update
|
||||
|
||||
|
||||
class CameraBlock(nn.Module):
    """Transformer block built around ``CausalAttention`` with a block-wise causal mask.

    ``forward`` either runs full attention (``full_attention=True``, no mask is
    built) or builds a block-wise causal mask for the requested frame layout and
    forwards it to the attention module together with the KV-cache arguments.

    NOTE(review): ``proj_bias``, ``attn_drop``, ``attn_class`` and ``fused_attn``
    are accepted for signature parity with the other blocks but are not used
    here; the attention module is always ``CausalAttention``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = CausalAttention(
            dim=dim,
            num_heads=num_heads,
            qk_norm=qk_norm,
            qkv_bias=qkv_bias,
            rope=rope,
            elementwise_attn_output_gate=elementwise_attn_output_gate,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )

        # Defaults forwarded to the attention call in forward().
        self.sliding_window_size = sliding_window_size
        self.attend_to_scale_frames = attend_to_scale_frames
        self.num_random_frames = num_random_frames

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path
        # Kept for compatibility; nothing in this class reads or fills it.
        self.masks = {}

    @torch.no_grad()
    def _prepare_blockwise_causal_attn_mask(
        self,
        device: torch.device | str,
        num_frames: int = 21,
        frame_seqlen: int = 1560,
        num_frame_per_block=1,
    ) -> Tensor:
        """Build the block-wise causal attention mask for a frame sequence.

        The token sequence is laid out as
        ``[1 latent frame] [1 latent frame] ... [1 latent frame]`` and grouped
        into blocks of ``num_frame_per_block`` frames. A query may attend to
        every key located before the end of the query's own block, plus itself.
        The sequence is right-padded to a multiple of 128; padded positions keep
        an end of 0 and therefore only attend to themselves.

        Returns:
            The result of ``flex_attention.create_mask`` with
            ``Q_LEN == KV_LEN == padded sequence length``. ``create_mask``
            produces a dense mask tensor, not a ``BlockMask``.
        """
        total_length = num_frames * frame_seqlen

        # we do right padding to get to a multiple of 128
        padded_length = math.ceil(total_length / 128) * 128 - total_length
        full_length = total_length + padded_length
        block_len = frame_seqlen * num_frame_per_block

        ends = torch.zeros(full_length, device=device, dtype=torch.long)

        # Block-wise causal mask will attend to all elements that are before the
        # end of the current chunk. Block starts are pulled to Python ints once
        # instead of iterating 0-dim tensors.
        block_starts = torch.arange(
            start=0,
            end=total_length,
            step=block_len,
            device=device,
        ).tolist()
        for block_start in block_starts:
            block_end = block_start + block_len
            ends[block_start:block_end] = block_end

        def attention_mask(b, h, q_idx, kv_idx):
            return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)

        return create_mask(
            attention_mask,
            B=None,
            H=None,
            Q_LEN=full_length,
            KV_LEN=full_length,
            device=device,
        )

    def forward(
        self,
        x: Tensor,
        pos=None,
        video_mask=None,
        num_frames=0,
        frame_seqlen=0,
        kv_cache=None,
        current_start=0,
        current_end=0,
        global_idx=0,
        num_frame_per_block=8,
        num_frame_for_scale=-1,
        sliding_window_size=None,
        enable_ulysses_cp=False,
        full_attention=False,
        enable_3d_rope=False,
        is_scale_frames=False,
    ) -> Tensor:
        """Apply the attention branch and then the MLP branch.

        With ``full_attention=True`` no mask is built and only ``pos``,
        ``enable_ulysses_cp`` and ``enable_3d_rope`` are forwarded to the
        attention module. Otherwise a block-wise causal mask is built from
        ``num_frames`` / ``frame_seqlen`` / ``num_frame_per_block`` and passed as
        ``block_mask`` together with the remaining arguments.

        ``sliding_window_size=None`` falls back to the value given at construction.
        """
        # Use passed sliding_window_size if provided, otherwise use self.sliding_window_size
        if sliding_window_size is not None:
            effective_sliding_window_size = sliding_window_size
        else:
            effective_sliding_window_size = self.sliding_window_size

        if full_attention:
            # Fast path for full attention (camera head) - skip mask computation
            def attn_residual_func(x: Tensor, pos=None) -> Tensor:
                return self.ls1(
                    self.attn(
                        self.norm1(x),
                        pos=pos,
                        full_attention=True,
                        enable_ulysses_cp=enable_ulysses_cp,
                        enable_3d_rope=enable_3d_rope,
                    )
                )
        else:
            mask_block = self._prepare_blockwise_causal_attn_mask(
                device=x.device,
                num_frames=num_frames,
                frame_seqlen=frame_seqlen,
                num_frame_per_block=num_frame_per_block,
            )

            def attn_residual_func(x: Tensor, pos=None) -> Tensor:
                return self.ls1(
                    self.attn(
                        self.norm1(x),
                        pos=pos,
                        block_mask=mask_block,
                        frame_seqlen=frame_seqlen,
                        video_mask=video_mask,
                        current_start=current_start,
                        current_end=current_end,
                        kv_cache=kv_cache,
                        global_idx=global_idx,
                        num_frame_per_block=num_frame_per_block,
                        num_frame_for_scale=num_frame_for_scale,
                        sliding_window_size=effective_sliding_window_size,
                        attend_to_scale_frames=self.attend_to_scale_frames,
                        num_random_frames=self.num_random_frames,
                        enable_ulysses_cp=enable_ulysses_cp,
                        enable_3d_rope=enable_3d_rope,
                        is_scale_frames=is_scale_frames,
                    )
                )

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            # The FFN branch uses its own DropPath module (was drop_path1, flagged FIXME).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
|
||||
|
||||
|
||||
class SDPABlock(nn.Module):
    """
    SDPA variant for streaming inference. Uses F.scaled_dot_product_attention
    with dict-based KV cache. No FlashInfer dependency required.

    The attention itself lives in ``SDPAAttention``; this block adds the
    surrounding norms, optional layer scale, optional DropPath and the MLP.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        """Build the attention and MLP branches.

        ``drop`` is used as ``proj_drop`` of the attention and ``drop`` of the
        MLP; all ``kv_cache_*`` options are forwarded to ``SDPAAttention``.
        """
        super().__init__()

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = SDPAAttention(
            dim=dim,
            num_heads=num_heads,
            qk_norm=qk_norm,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(
        self,
        x: Tensor,
        pos=None,
        enable_ulysses_cp=False,
        num_patches=None,
        num_special=None,
        num_frames=None,
        enable_3d_rope=False,
        kv_cache=None,
        global_idx=0,
        num_frame_per_block=1,
        num_frame_for_scale=-1,
        num_register_tokens=4,
    ) -> Tensor:
        """Apply the attention branch and then the MLP branch.

        Every argument other than ``x`` is forwarded unchanged to ``self.attn``.
        DropPath is applied to both branches only while training with
        ``sample_drop_ratio > 0``.
        """

        def attn_residual_func(x, pos=None):
            return self.ls1(
                self.attn(
                    self.norm1(x),
                    pos=pos,
                    enable_ulysses_cp=enable_ulysses_cp,
                    num_patches=num_patches,
                    num_special=num_special,
                    num_frames=num_frames,
                    enable_3d_rope=enable_3d_rope,
                    kv_cache=kv_cache,
                    global_idx=global_idx,
                    num_frame_per_block=num_frame_per_block,
                    num_frame_for_scale=num_frame_for_scale,
                    num_register_tokens=num_register_tokens,
                )
            )

        def ffn_residual_func(x):
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            # The FFN branch uses its own DropPath module (was drop_path1).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
|
||||
Reference in New Issue
Block a user