diff --git a/.gitignore b/.gitignore index 4df2626..cbcbae1 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ demo_render/ CLAUDE.md .claude/ .agents/ +skyseg.onnx \ No newline at end of file diff --git a/README.md b/README.md index cc78637..b507a28 100644 --- a/README.md +++ b/README.md @@ -115,13 +115,30 @@ python demo.py --model_path /path/to/checkpoint.pt \ ``` -### With Sky Masking +### Sky Masking + +Sky masking uses an ONNX sky segmentation model to filter out sky points from the reconstructed point cloud, which improves visualization quality for outdoor scenes. + +**Setup:** + +```bash +# Install onnxruntime (required) +pip install onnxruntime # CPU +# or +pip install onnxruntime-gpu # GPU (faster for large image sets) +``` + +The sky segmentation model (`skyseg.onnx`) will be automatically downloaded from [HuggingFace](https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx) on first use. + +**Usage:** ```bash python demo.py --model_path /path/to/checkpoint.pt \ --image_folder /path/to/images/ --mask_sky ``` +Sky masks are cached in `_sky_masks/` so subsequent runs skip regeneration. + ### Without FlashInfer (SDPA fallback) ```bash diff --git a/demo.py b/demo.py index 338321e..14dbe44 100644 --- a/demo.py +++ b/demo.py @@ -138,9 +138,8 @@ _BATCHED_NDIMS = { "world_points_conf": 4, "extrinsic": 4, "intrinsic": 4, - "chunk_sim3_scales": 2, - "chunk_sim3_poses": 4, - "chunk_se3_poses": 4, + "chunk_scales": 2, + "chunk_transforms": 4, "images": 5, } @@ -256,14 +255,13 @@ def main(): # Windowed options parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)") parser.add_argument("--overlap_size", type=int, default=16, help="Overlap between windows") - parser.add_argument("--sim3", action="store_true", default=True, help="Use Sim(3) alignment between windows") - parser.add_argument("--no_sim3", dest="sim3", action="store_false", help="Disable Sim(3), use SE(3) instead") + # Visualization parser.add_argument("--port", type=int, default=8080) - parser.add_argument("--conf_threshold", type=float, default=1.0) + parser.add_argument("--conf_threshold", type=float, default=1.5) parser.add_argument("--downsample_factor", type=int, default=10) - parser.add_argument("--point_size", type=float, default=0.005) + parser.add_argument("--point_size", type=float, default=0.0007) parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points") args = parser.parse_args() @@ -314,8 +312,6 @@ def main(): window_size=args.window_size, overlap_size=args.overlap_size, num_scale_frames=args.num_scale_frames, - sim3=args.sim3, - se3=not args.sim3, ) t_infer = time.time() - t0 @@ -330,7 +326,7 @@ def main(): viewer = PointCloudViewer( pred_dict=prepare_for_visualization(predictions, images_cpu), port=args.port, - init_conf_threshold=args.conf_threshold, + vis_threshold=args.conf_threshold, downsample_factor=args.downsample_factor, point_size=args.point_size, mask_sky=args.mask_sky, diff --git a/lingbot_map/layers/flashinfer_cache.py b/lingbot_map/layers/flashinfer_cache.py index 1660f97..d8e7b13 100644 --- a/lingbot_map/layers/flashinfer_cache.py +++ b/lingbot_map/layers/flashinfer_cache.py @@ -169,6 +169,11 @@ class FlashInferKVCacheManager: # Frame counter per block (determines scale vs window routing) self.frame_count: List[int] = [0] * num_blocks + # Deferred eviction support for flow-based keyframe selection. + # When True, evict_frames() becomes a no-op; caller must later call + # execute_deferred_eviction() or rollback_last_frame(). + self._defer_eviction: bool = False + # ── FlashInfer wrapper ─────────────────────────────────────────────── # plan() is called once per frame step (block_idx == 0). # run() is called per layer, reusing the same aux structures. @@ -237,11 +242,64 @@ class FlashInferKVCacheManager: Special pages are NEVER evicted. Scale pages are NEVER evicted. Only live_window_patch_pages beyond `sliding_window` are recycled. + + When ``_defer_eviction`` is True, this method is a no-op. The caller + is expected to later call ``execute_deferred_eviction()`` (keep frame) + or ``rollback_last_frame()`` (discard frame). """ + if self._defer_eviction: + return while len(self.live_window_patch_pages[block_idx]) > sliding_window: old_page = self.live_window_patch_pages[block_idx].popleft() self.free_patch_pages[block_idx].append(old_page) + def execute_deferred_eviction( + self, + block_idx: int, + scale_frames: int, + sliding_window: int, + **kwargs, + ) -> None: + """Run the eviction that was skipped while ``_defer_eviction`` was True.""" + while len(self.live_window_patch_pages[block_idx]) > sliding_window: + old_page = self.live_window_patch_pages[block_idx].popleft() + self.free_patch_pages[block_idx].append(old_page) + + def rollback_last_frame(self, block_idx: int) -> None: + """Undo the most recent ``append_frame()`` for *block_idx*. + + This reverses all three sub-operations of ``append_frame``: + patch page allocation, special-token write, and frame_count increment. + It must be called **before** any eviction for that frame (i.e. while + ``_defer_eviction`` is True or before ``evict_frames`` is called). + """ + assert self.frame_count[block_idx] > 0, ( + f"block {block_idx}: cannot rollback, frame_count is 0" + ) + + # 1) Undo patch page ── pop from whichever deque it was routed to. + if self.frame_count[block_idx] > self.scale_frames: + page_id = self.live_window_patch_pages[block_idx].pop() + else: + page_id = self.scale_patch_pages[block_idx].pop() + self.free_patch_pages[block_idx].append(page_id) + + # 2) Undo special tokens + n = self.num_special_tokens + new_count = self.special_token_count[block_idx] - n + assert new_count >= 0, ( + f"block {block_idx}: special_token_count underflow " + f"({self.special_token_count[block_idx]} - {n})" + ) + new_num_pages = math.ceil(new_count / self.page_size) if new_count > 0 else 0 + while len(self.all_special_pages[block_idx]) > new_num_pages: + freed = self.all_special_pages[block_idx].pop() + self.free_special_pages[block_idx].append(freed) + self.special_token_count[block_idx] = new_count + + # 3) Decrement frame count + self.frame_count[block_idx] -= 1 + def _gather_kv(self, block_idx: int): """ Gather all visible K and V tokens from the paged cache into dense tensors. diff --git a/lingbot_map/vis/point_cloud_viewer.py b/lingbot_map/vis/point_cloud_viewer.py index b90f281..ed19434 100644 --- a/lingbot_map/vis/point_cloud_viewer.py +++ b/lingbot_map/vis/point_cloud_viewer.py @@ -1486,7 +1486,7 @@ class PointCloudViewer: aspect=aspect, wxyz=q, position=t, - scale=0.1, + scale=0.03, color=camera_color_rgb, )