# lingbot-map/lingbot_map/layers/flashinfer_cache.py
# Last commit: 42de2badd2 "update viser" (LinZhuoChen, 2026-04-16 12:27:39 +08:00)
# Original listing: 641 lines, 30 KiB, Python.
# NOTE: comments and docstrings contain non-ASCII characters (e.g. ×, —, →)
# that some repository viewers flag as "ambiguous Unicode".
"""
FlashInfer KV Cache Manager — Two-Stream Paged Design.
Two logical streams sharing one physical page pool per layer:
Patch stream (recyclable):
- page_size = patches_per_frame (256 for 224×224; 972 for 504×378),
  rounded up to the next power of two when fa3=True (e.g. 999 → 1024)
- Exactly 1 patch page per frame
- Scale frames → scale_patch_pages (never evicted, maxlen=scale_frames)
- Recent frames → live_window_patch_pages (evicted when > sliding_window)
Special stream (append-only, never recycled):
- num_special_tokens (6) special tokens per frame
- Packed continuously across page boundaries: one frame's specials may
straddle two special pages, so no slots are wasted
e.g. page_size=256 → each special page holds 256 tokens ≈ 42.7 frames' worth
- Specials written for EVERY frame (including scale + window), not just evicted ones.
Physical layout per block:
kv_caches[block_idx]: [max_num_pages, 2, page_size, H, D]
Pages 0 .. max_patch_pages-1 : patch page pool (recyclable)
Pages max_patch_pages .. max_num_pages-1: special page pool (append-only)
dim 1: 0=K 1=V
Attention computation:
visible = scale_patch_pages + live_window_patch_pages + all_special_pages
Special pages placed LAST → paged_kv_last_page_len naturally describes
the partial special-tail without a custom mask.
plan() is called ONCE per frame step (when block_idx == 0).
run() is called per layer, reusing the same plan. All layers at the
same frame step have identical page structures (same page IDs in same
positions), so reusing the plan across layers is correct.
Public API is drop-in compatible with the previous FlashInferKVCacheManager:
append_frame(block_idx, k, v)
evict_frames(block_idx, scale_frames, sliding_window, ...)
compute_attention(block_idx, q) -> out
reset()
"""
import collections
import math
from typing import List
import torch
from torch import Tensor
# FlashInfer is an optional dependency at import time; FlashInferKVCacheManager
# checks this flag and raises RuntimeError on construction when it is False.
try:
    import flashinfer
except ImportError:
    FLASHINFER_AVAILABLE = False
else:
    FLASHINFER_AVAILABLE = True
class FlashInferKVCacheManager:
    """
    Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only).

    Args:
        num_blocks: Number of Transformer blocks (one cache per block).
        max_num_frames: Maximum frames held in the KV window at once
            (scale_frames + sliding_window + headroom). Accepted for API
            compatibility; the patch pool is currently sized as
            scale_frames + sliding_window + 16 and this value is not read.
        tokens_per_frame: Total tokens per frame = patches + specials (e.g. 262).
        num_heads: Number of KV heads (= QO heads; MHA assumed).
        head_dim: Head dimension (64 for ViT-L).
        dtype: Storage dtype (bfloat16 / float16). float32 is mapped to
            bfloat16 unless force_fp32 is set.
        device: CUDA device.
        num_special_tokens: Special tokens per frame: camera + register×N + scale (6).
        scale_frames: Number of always-resident scale frames (8).
        sliding_window: Sliding window size (64).
        max_total_frames: Upper bound on total frames ever processed; used to
            pre-allocate the special page pool (default 2048).
        force_fp32: Store K/V in float32 and compute attention with PyTorch
            scaled_dot_product_attention instead of FlashInfer (accuracy mode).
        fa3: Use FlashInfer's "fa3" backend; page_size is rounded up to the
            next power of two.
    """

    def __init__(
        self,
        num_blocks: int,
        max_num_frames: int,
        tokens_per_frame: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        num_special_tokens: int = 6,
        scale_frames: int = 8,
        sliding_window: int = 64,
        max_total_frames: int = 2048,
        force_fp32: bool = False,
        fa3: bool = False,
    ):
        if not FLASHINFER_AVAILABLE:
            raise RuntimeError("FlashInfer is not available. Please install flashinfer.")

        self.num_blocks = num_blocks
        self.num_special_tokens = num_special_tokens  # 6
        self.patches_per_frame = tokens_per_frame - num_special_tokens  # 256 / 999 / ...
        # Validate before page_size is derived from patches_per_frame.
        assert self.patches_per_frame > 0, (
            f"tokens_per_frame={tokens_per_frame} <= num_special_tokens={num_special_tokens}"
        )

        # Use exact page_size = patches_per_frame to eliminate zero-padded slots.
        # FA2 (backend="fa2") supports non-power-of-2 page sizes.
        # FA3 (sm90) requires power-of-2 page sizes, so round up when fa3=True,
        # e.g. 999 → 1024 (25 zero-padded slots per patch page).
        p = self.patches_per_frame
        self.page_size = (1 << (p - 1).bit_length()) if fa3 else p
        assert self.page_size > 0

        self.scale_frames = scale_frames  # 8
        self.sliding_window = sliding_window  # 64
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.tokens_per_frame = tokens_per_frame

        # force_fp32: bypass the FlashInfer kernel (fp16/bf16 only), gather paged
        # K/V into dense tensors and use F.scaled_dot_product_attention in fp32
        # for accuracy comparison. Storage dtype is fp32 in this mode as well.
        self.force_fp32 = force_fp32
        if force_fp32:
            self.dtype = torch.float32
        else:
            # The FlashInfer path cannot store fp32; fall back to bfloat16.
            self.dtype = torch.bfloat16 if dtype == torch.float32 else dtype
        self.device = device

        # ── Page pool sizing ─────────────────────────────────────────────────
        # Patch: scale + window + 16 headroom (pages recycled → fixed count).
        # The headroom also absorbs frames appended while eviction is deferred.
        self.max_patch_pages = scale_frames + sliding_window + 16  # e.g. 88
        # Special: enough for max_total_frames × num_special_tokens, plus 16 headroom.
        max_special_pages = (
            math.ceil(max_total_frames * num_special_tokens / self.page_size) + 16
        )
        self.max_num_pages = self.max_patch_pages + max_special_pages

        # ── Physical paged KV caches ─────────────────────────────────────────
        # Allocated in self.dtype (the resolved storage dtype), NOT the raw
        # `dtype` argument: in force_fp32 mode storage must really be fp32.
        self.kv_caches: List[Tensor] = self._allocate_kv_caches()

        # ── Per-block state ──────────────────────────────────────────────────
        self._init_block_state()

        # Deferred eviction support for flow-based keyframe selection.
        # When True, evict_frames() becomes a no-op; the caller must later call
        # execute_deferred_eviction() or rollback_last_frame().
        self._defer_eviction: bool = False

        # ── FlashInfer wrapper ───────────────────────────────────────────────
        # plan() is called once per frame step (block_idx == 0).
        # run() is called per layer, reusing the same aux structures.
        # backend: "fa2" (default) or "fa3" (SM90/H100, requires power-of-2 page_size).
        # FA2 supports non-power-of-2 page sizes and avoids a FA3 NaN bug seen in
        # FlashInfer 0.2.5 at 518×378 resolution.
        _fi_backend = "fa3" if fa3 else "fa2"
        self.workspace_buffer = torch.zeros(
            128 * 1024 * 1024, dtype=torch.uint8, device=device
        )
        self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer,
            kv_layout="NHD",
            backend=_fi_backend,
        )
        # plan() inputs: indices/indptr are built fresh each step; qo_indptr is
        # fixed because every query is exactly one frame (tokens_per_frame).
        self._qo_indptr = torch.tensor(
            [0, tokens_per_frame], dtype=torch.int32, device=device
        )

    # =========================================================================
    # Construction helpers
    # =========================================================================

    def _allocate_kv_caches(self) -> List[Tensor]:
        """Allocate one zero-initialised paged cache per block.

        Shape per block: [max_num_pages, 2, page_size, H, D] (NHD; dim 1: 0=K, 1=V),
        stored in ``self.dtype``. Zero init matters: slots past patches_per_frame
        in patch pages (fa3 mode) are never written.
        """
        return [
            torch.zeros(
                self.max_num_pages, 2, self.page_size, self.num_heads, self.head_dim,
                dtype=self.dtype, device=self.device,
            )
            for _ in range(self.num_blocks)
        ]

    def _init_block_state(self) -> None:
        """Create fresh per-block page bookkeeping: every page free, no frames."""
        n_blocks = self.num_blocks
        # Patch pages (IDs 0 .. max_patch_pages-1)
        self.scale_patch_pages: List[collections.deque] = [
            collections.deque() for _ in range(n_blocks)
        ]
        self.live_window_patch_pages: List[collections.deque] = [
            collections.deque() for _ in range(n_blocks)
        ]
        self.free_patch_pages: List[List[int]] = [
            list(range(self.max_patch_pages)) for _ in range(n_blocks)
        ]
        # Special pages (IDs max_patch_pages .. max_num_pages-1)
        self.all_special_pages: List[List[int]] = [[] for _ in range(n_blocks)]
        self.free_special_pages: List[List[int]] = [
            list(range(self.max_patch_pages, self.max_num_pages)) for _ in range(n_blocks)
        ]
        self.special_token_count: List[int] = [0] * n_blocks
        # Frame counter per block (determines scale vs window routing)
        self.frame_count: List[int] = [0] * n_blocks

    # =========================================================================
    # Public API (drop-in compatible with previous FlashInferKVCacheManager)
    # =========================================================================

    def append_frame(self, block_idx: int, k: Tensor, v: Tensor) -> None:
        """
        Append one frame's K/V tensors to the two-stream cache.

        Token layout must be: [camera, reg0, ..., regN, scale, patch0, ..., patchP-1]
        i.e. specials come first (matching stream.py's patch_start_idx convention).

        Args:
            block_idx: Block/layer index (0 … num_blocks-1).
            k: [tokens_per_frame, H, D] NHD layout.
            v: [tokens_per_frame, H, D] NHD layout.
        """
        n = self.num_special_tokens  # 6
        sp_k = k[:n].to(self.dtype)  # [6, H, D]
        patch_k = k[n:].to(self.dtype)  # [patches_per_frame, H, D]
        sp_v = v[:n].to(self.dtype)
        patch_v = v[n:].to(self.dtype)
        assert patch_k.shape[0] == self.patches_per_frame, (
            f"block {block_idx}: expected {self.patches_per_frame} patch tokens, "
            f"got {patch_k.shape[0]} (tokens_per_frame={k.shape[0]})"
        )
        self._write_patch_page(block_idx, patch_k, patch_v)
        self._write_special_tokens(block_idx, sp_k, sp_v)
        self.frame_count[block_idx] += 1

    def evict_frames(
        self,
        block_idx: int,
        scale_frames: int,
        sliding_window: int,
        cross_frame_special: bool = True,
        include_scale_frames: bool = True,
        camera_only: bool = False,
        num_register_tokens: int = 4,
    ) -> None:
        """
        Evict old window patch pages (recycle to free list).

        Special pages are NEVER evicted. Scale pages are NEVER evicted.
        Only live_window_patch_pages beyond `sliding_window` are recycled.
        `scale_frames`, `cross_frame_special`, `include_scale_frames`,
        `camera_only` and `num_register_tokens` are accepted for interface
        compatibility and are not read here.

        When ``_defer_eviction`` is True, this method is a no-op. The caller
        is expected to later call ``execute_deferred_eviction()`` (keep frame)
        or ``rollback_last_frame()`` (discard frame).
        """
        if self._defer_eviction:
            return
        self._evict_window(block_idx, sliding_window)

    def execute_deferred_eviction(
        self,
        block_idx: int,
        scale_frames: int,
        sliding_window: int,
        **kwargs,
    ) -> None:
        """Run the eviction that was skipped while ``_defer_eviction`` was True."""
        self._evict_window(block_idx, sliding_window)

    def _evict_window(self, block_idx: int, sliding_window: int) -> None:
        """Recycle the oldest window patch pages until at most `sliding_window` remain."""
        window = self.live_window_patch_pages[block_idx]
        free = self.free_patch_pages[block_idx]
        while len(window) > sliding_window:
            free.append(window.popleft())

    def rollback_last_frame(self, block_idx: int) -> None:
        """Undo the most recent ``append_frame()`` for *block_idx*.

        This reverses all three sub-operations of ``append_frame``:
        patch page allocation, special-token write, and frame_count increment.
        It must be called **before** any eviction for that frame (i.e. while
        ``_defer_eviction`` is True or before ``evict_frames`` is called);
        pages evicted in the meantime are not restored.
        """
        assert self.frame_count[block_idx] > 0, (
            f"block {block_idx}: cannot rollback, frame_count is 0"
        )

        # 1) Undo patch page ── pop from whichever deque it was routed to.
        # append routes frame #c (1-based) to the scale stream iff c <= scale_frames.
        if self.frame_count[block_idx] > self.scale_frames:
            page_id = self.live_window_patch_pages[block_idx].pop()
        else:
            page_id = self.scale_patch_pages[block_idx].pop()
        self.free_patch_pages[block_idx].append(page_id)

        # 2) Undo special tokens; release tail pages no longer needed.
        n = self.num_special_tokens
        new_count = self.special_token_count[block_idx] - n
        assert new_count >= 0, (
            f"block {block_idx}: special_token_count underflow "
            f"({self.special_token_count[block_idx]} - {n})"
        )
        new_num_pages = math.ceil(new_count / self.page_size) if new_count > 0 else 0
        while len(self.all_special_pages[block_idx]) > new_num_pages:
            freed = self.all_special_pages[block_idx].pop()
            self.free_special_pages[block_idx].append(freed)
        self.special_token_count[block_idx] = new_count

        # 3) Decrement frame count
        self.frame_count[block_idx] -= 1

    def _gather_kv(self, block_idx: int):
        """
        Gather all visible K and V tokens from the paged cache into dense tensors.

        Used by force_fp32 mode to bypass the FlashInfer kernel (which only
        supports fp16/bf16) and instead run F.scaled_dot_product_attention in fp32.
        Mirrors FlashInfer's view: every page but the last is taken as full.

        Returns:
            k_flat: [kv_len, H, D] — all visible K tokens concatenated
            v_flat: [kv_len, H, D] — all visible V tokens concatenated
        """
        visible = self.build_visible_page_table(block_idx)
        last_len = self.compute_last_page_len(block_idx)
        full = self.page_size
        cache = self.kv_caches[block_idx]
        parts_k, parts_v = [], []
        for i, pid in enumerate(visible):
            n = last_len if i == len(visible) - 1 else full
            parts_k.append(cache[pid, 0, :n])  # [n, H, D]
            parts_v.append(cache[pid, 1, :n])
        k_flat = torch.cat(parts_k, dim=0)  # [kv_len, H, D]
        v_flat = torch.cat(parts_v, dim=0)
        return k_flat, v_flat

    def compute_attention(self, block_idx: int, q: Tensor) -> Tensor:
        """
        Compute cross-frame attention using FlashInfer BatchPrefillWithPagedKVCacheWrapper.

        When self.force_fp32 is True, gathers all visible K/V into dense tensors
        and uses F.scaled_dot_product_attention in fp32 instead of FlashInfer;
        the result is cast back to q.dtype. Otherwise the result is in the
        storage dtype (self.dtype).

        plan() is called once per frame step (when block_idx == 0).
        All layers at the same step share the same visible page structure,
        so the plan is reused by calling run() with each layer's kv_cache.

        Args:
            block_idx: Block/layer index.
            q: [q_len, H, D] NHD layout (q_len = tokens_per_frame = 262).

        Returns:
            out: [q_len, H, D]
        """
        if self.frame_count[block_idx] == 0:
            # No KV present yet (should not occur in normal usage after append_frame)
            return torch.zeros_like(q)

        if self.force_fp32:
            # ── fp32 gather+SDPA path ─────────────────────────────────────────
            # q_len, H, D → 1, H, q_len, D (SDPA expects B, H, S, D layout)
            import torch.nn.functional as F_nn

            k_flat, v_flat = self._gather_kv(block_idx)
            q_b = q.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, q_len, D]
            k_b = k_flat.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, kv_len, D]
            v_b = v_flat.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, kv_len, D]
            out = F_nn.scaled_dot_product_attention(q_b, k_b, v_b)
            return out.squeeze(0).permute(1, 0, 2).to(q.dtype)  # [q_len, H, D]

        if block_idx == 0:
            # ── Plan once per frame step ──────────────────────────────────────
            # Built from block 0's state; valid for subsequent run() calls on
            # other blocks only if they hold identical page structures.
            visible = self.build_visible_page_table(0)
            last_len = self.compute_last_page_len(0)
            assert visible, "visible page table is empty after append_frame"
            assert 1 <= last_len <= self.page_size, (
                f"block 0: last_page_len={last_len} out of [1, {self.page_size}]"
            )
            paged_kv_indices = torch.tensor(visible, dtype=torch.int32, device=self.device)
            paged_kv_indptr = torch.tensor([0, len(visible)], dtype=torch.int32, device=self.device)
            paged_kv_last_page_len = torch.tensor([last_len], dtype=torch.int32, device=self.device)
            self.prefill_wrapper.plan(
                self._qo_indptr,
                paged_kv_indptr,
                paged_kv_indices,
                paged_kv_last_page_len,
                num_qo_heads=self.num_heads,
                num_kv_heads=self.num_heads,
                head_dim_qk=self.head_dim,
                page_size=self.page_size,
                causal=False,  # custom page ordering; no causal mask
                pos_encoding_mode="NONE",  # RoPE applied externally before append
                q_data_type=self.dtype,
            )

        # ── Run attention for this layer ──────────────────────────────────────
        # Cast q to storage dtype (LayerNorm may upcast to float32 under autocast).
        return self.prefill_wrapper.run(
            q=q.to(self.dtype).contiguous(),
            paged_kv_cache=self.kv_caches[block_idx],
        )  # → [q_len, H, D]

    def reset(self) -> None:
        """Reset all per-block state for a new sequence (cache tensors are kept)."""
        for i in range(self.num_blocks):
            self.scale_patch_pages[i].clear()
            self.live_window_patch_pages[i].clear()
            self.all_special_pages[i].clear()
            self.free_patch_pages[i] = list(range(self.max_patch_pages))
            self.free_special_pages[i] = list(range(self.max_patch_pages, self.max_num_pages))
            self.special_token_count[i] = 0
            self.frame_count[i] = 0

    # =========================================================================
    # Helper methods
    # =========================================================================

    def build_visible_page_table(self, block_idx: int) -> List[int]:
        """
        Return page IDs in strict order: scale → window → special.

        Placing special pages last means only the final page may be partially
        full, so paged_kv_last_page_len = compute_last_page_len() is sufficient
        without a custom attention mask.

        NOTE(review): with fa3=True, patch pages have page_size > patches_per_frame,
        and interior pages appear to be treated as full, so their zero-padded
        slots would be attended as zero-K/zero-V tokens — confirm impact.
        """
        return [
            *self.scale_patch_pages[block_idx],
            *self.live_window_patch_pages[block_idx],
            *self.all_special_pages[block_idx],
        ]

    def compute_last_page_len(self, block_idx: int) -> int:
        """
        Valid token count in the last page of the visible sequence.

        - No special pages → last page is a patch page.
          Returns patches_per_frame (real tokens written), which may be
          < page_size when page_size was rounded up to a power of 2.
        - Special tail partial → special_token_count % page_size.
        - Special tail exactly full → page_size.
        """
        if not self.all_special_pages[block_idx]:
            # Positions patches_per_frame..page_size-1 are zero padding; report
            # the true valid count so attention doesn't read past real tokens.
            return self.patches_per_frame
        tail = self.special_token_count[block_idx] % self.page_size
        return self.page_size if tail == 0 else tail

    # ── Internal write helpers ────────────────────────────────────────────────

    def _write_patch_page(self, block_idx: int, patch_k: Tensor, patch_v: Tensor) -> int:
        """
        Allocate one free patch page and write patches_per_frame patch tokens.

        Direct tensor assignment to kv_caches[block_idx][page_id, 0/1] avoids
        the Python→C++/CUDA dispatch overhead of flashinfer.page.append_paged_kv_cache.
        Slots 0..patches_per_frame-1 are written; when page_size was rounded up
        (fa3), the remaining slots keep their zero initialisation.

        Routes to scale_patch_pages if still filling the scale quota,
        otherwise to live_window_patch_pages.

        Returns:
            page_id: Physical page index used.
        """
        assert self.free_patch_pages[block_idx], (
            f"block {block_idx}: patch page pool exhausted — "
            f"scale={len(self.scale_patch_pages[block_idx])}, "
            f"window={len(self.live_window_patch_pages[block_idx])}, "
            f"free={len(self.free_patch_pages[block_idx])}"
        )
        page_id = self.free_patch_pages[block_idx].pop()

        n_patch = self.patches_per_frame
        self.kv_caches[block_idx][page_id, 0, :n_patch] = patch_k  # K
        self.kv_caches[block_idx][page_id, 1, :n_patch] = patch_v  # V

        if len(self.scale_patch_pages[block_idx]) < self.scale_frames:
            self.scale_patch_pages[block_idx].append(page_id)
        else:
            self.live_window_patch_pages[block_idx].append(page_id)
        return page_id

    def _write_special_tokens(self, block_idx: int, sp_k: Tensor, sp_v: Tensor) -> None:
        """
        Append num_special_tokens special tokens to the special stream.

        Direct slice assignment to kv_caches[block_idx][tail_page, 0/1,
        tail_offset:tail_offset+write_n] avoids the dispatch overhead of
        flashinfer.page.append_paged_kv_cache.

        Handles page-boundary crossing: if the tokens straddle two pages, two
        slice writes are performed and a new special page is allocated.
        """
        remaining = self.num_special_tokens
        written = 0
        while remaining > 0:
            tail_offset = self.special_token_count[block_idx] % self.page_size
            if tail_offset == 0:
                # Current tail page is full (or no page exists) — allocate a new one
                assert self.free_special_pages[block_idx], (
                    f"block {block_idx}: special page pool exhausted at "
                    f"special_token_count={self.special_token_count[block_idx]}. "
                    f"Increase max_total_frames."
                )
                new_page = self.free_special_pages[block_idx].pop()
                self.all_special_pages[block_idx].append(new_page)

            tail_page = self.all_special_pages[block_idx][-1]
            write_n = min(remaining, self.page_size - tail_offset)  # free slots in tail page
            end = tail_offset + write_n
            self.kv_caches[block_idx][tail_page, 0, tail_offset:end] = sp_k[written:written + write_n]
            self.kv_caches[block_idx][tail_page, 1, tail_offset:end] = sp_v[written:written + write_n]

            self.special_token_count[block_idx] += write_n
            written += write_n
            remaining -= write_n

    # ── Legacy property (used by stream.py) ──────────────────────────────────

    @property
    def num_frames(self) -> int:
        """Number of frames appended to block 0 (representative)."""
        return self.frame_count[0] if self.frame_count else 0
# =============================================================================
# Sanity check
# =============================================================================
def _sanity_check():
"""
Minimal smoke test.
Run with: python -c "from lingbot_map.layers.flashinfer_cache import _sanity_check; _sanity_check()"
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not torch.cuda.is_available():
print("[sanity_check] CUDA not available — skipping.")
return
tokens_per_frame = 262 # 256 patch + 6 special (224×224)
num_special = 6
patches_per_frame = tokens_per_frame - num_special # 256
page_size = patches_per_frame # 256
mgr = FlashInferKVCacheManager(
num_blocks = 2,
max_num_frames = 88,
tokens_per_frame = tokens_per_frame,
num_heads = 16,
head_dim = 64,
dtype = torch.bfloat16,
device = device,
num_special_tokens = num_special,
scale_frames = 8,
sliding_window = 64,
max_total_frames = 200,
)
def make_kv():
k = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
v = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
return k, v
def make_q():
return torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
for block in range(2):
for t in range(100):
k, v = make_kv()
mgr.append_frame(block, k, v)
mgr.evict_frames(block, scale_frames=8, sliding_window=64)
# ── Page count checks ───────────────────────────────────────────────
n_scale = len(mgr.scale_patch_pages[block])
n_window = len(mgr.live_window_patch_pages[block])
n_spec = len(mgr.all_special_pages[block])
sp_count = mgr.special_token_count[block]
assert n_scale == 8, f"block {block}: scale pages = {n_scale}, expected 8"
assert n_window == 64, f"block {block}: window pages = {n_window}, expected 64"
# 100 frames × 6 specials = 600 tokens; ceil(600/256) = 3 pages
expected_spec_pages = math.ceil(100 * num_special / page_size)
assert n_spec == expected_spec_pages, (
f"block {block}: special pages = {n_spec}, expected {expected_spec_pages}"
)
assert sp_count == 100 * num_special, (
f"block {block}: special_token_count = {sp_count}, expected {100*num_special}"
)
# ── last_page_len ────────────────────────────────────────────────────
last_len = mgr.compute_last_page_len(block)
tail = sp_count % page_size
expected_len = page_size if tail == 0 else tail
assert last_len == expected_len, f"block {block}: last_len={last_len}, expected={expected_len}"
# ── visible page table order ─────────────────────────────────────────
visible = mgr.build_visible_page_table(block)
assert len(visible) == n_scale + n_window + n_spec, "visible page count mismatch"
for pid in visible[:n_scale + n_window]:
assert pid < mgr.max_patch_pages, f"patch page {pid} out of patch range"
for pid in visible[n_scale + n_window:]:
assert pid >= mgr.max_patch_pages, f"special page {pid} not in special range"
# ── forward pass: plan() once for block 0, run() for both blocks ─────
if block == 1:
# Simulate the actual calling pattern: plan on block 0, run on both
q0 = make_q()
out0 = mgr.compute_attention(0, q0) # triggers plan()
q1 = make_q()
out1 = mgr.compute_attention(1, q1) # reuses plan, different kv_cache
assert out0.shape == (tokens_per_frame, 16, 64)
assert out1.shape == (tokens_per_frame, 16, 64)
print(f"[block {block}] PASS: scale={n_scale}, window={n_window}, "
f"special_pages={n_spec}, special_tokens={sp_count}, "
f"last_page_len={last_len}")
mgr.reset()
assert mgr.frame_count[0] == 0
print("\n[sanity_check] All assertions passed.")
if __name__ == "__main__":
_sanity_check()