update viser

This commit is contained in:
LinZhuoChen
2026-04-16 12:27:39 +08:00
parent 422600cf85
commit 42de2badd2
5 changed files with 84 additions and 12 deletions

View File

@@ -169,6 +169,11 @@ class FlashInferKVCacheManager:
# Frame counter per block (determines scale vs window routing)
self.frame_count: List[int] = [0] * num_blocks
# Deferred eviction support for flow-based keyframe selection.
# When True, evict_frames() becomes a no-op; caller must later call
# execute_deferred_eviction() or rollback_last_frame().
self._defer_eviction: bool = False
# ── FlashInfer wrapper ───────────────────────────────────────────────
# plan() is called once per frame step (block_idx == 0).
# run() is called per layer, reusing the same aux structures.
@@ -237,11 +242,64 @@ class FlashInferKVCacheManager:
Special pages are NEVER evicted.
Scale pages are NEVER evicted.
Only live_window_patch_pages beyond `sliding_window` are recycled.
When ``_defer_eviction`` is True, this method is a no-op. The caller
is expected to later call ``execute_deferred_eviction()`` (keep frame)
or ``rollback_last_frame()`` (discard frame).
"""
if self._defer_eviction:
return
while len(self.live_window_patch_pages[block_idx]) > sliding_window:
old_page = self.live_window_patch_pages[block_idx].popleft()
self.free_patch_pages[block_idx].append(old_page)
def execute_deferred_eviction(
self,
block_idx: int,
scale_frames: int,
sliding_window: int,
**kwargs,
) -> None:
"""Run the eviction that was skipped while ``_defer_eviction`` was True."""
while len(self.live_window_patch_pages[block_idx]) > sliding_window:
old_page = self.live_window_patch_pages[block_idx].popleft()
self.free_patch_pages[block_idx].append(old_page)
def rollback_last_frame(self, block_idx: int) -> None:
"""Undo the most recent ``append_frame()`` for *block_idx*.
This reverses all three sub-operations of ``append_frame``:
patch page allocation, special-token write, and frame_count increment.
It must be called **before** any eviction for that frame (i.e. while
``_defer_eviction`` is True or before ``evict_frames`` is called).
"""
assert self.frame_count[block_idx] > 0, (
f"block {block_idx}: cannot rollback, frame_count is 0"
)
# 1) Undo patch page ── pop from whichever deque it was routed to.
if self.frame_count[block_idx] > self.scale_frames:
page_id = self.live_window_patch_pages[block_idx].pop()
else:
page_id = self.scale_patch_pages[block_idx].pop()
self.free_patch_pages[block_idx].append(page_id)
# 2) Undo special tokens
n = self.num_special_tokens
new_count = self.special_token_count[block_idx] - n
assert new_count >= 0, (
f"block {block_idx}: special_token_count underflow "
f"({self.special_token_count[block_idx]} - {n})"
)
new_num_pages = math.ceil(new_count / self.page_size) if new_count > 0 else 0
while len(self.all_special_pages[block_idx]) > new_num_pages:
freed = self.all_special_pages[block_idx].pop()
self.free_special_pages[block_idx].append(freed)
self.special_token_count[block_idx] = new_count
# 3) Decrement frame count
self.frame_count[block_idx] -= 1
def _gather_kv(self, block_idx: int):
"""
Gather all visible K and V tokens from the paged cache into dense tensors.