commit f9b3ae457a4659bbb6016dabb6e37b6ea8a34854 Author: LinZhuoChen Date: Thu Apr 16 09:51:30 2026 +0800 first commit diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..a9884d6 Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4df2626 --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +__pycache__/ +*.pyc +*.pyo +*.egg-info/ +dist/ +build/ +*.so +.eggs/ +demo_render/ +CLAUDE.md +.claude/ +.agents/ diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..e395ca3 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. 
More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. 
Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. 
Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. 
+ +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. 
To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..854e146 --- /dev/null +++ b/README.md @@ -0,0 +1,141 @@ +

+# LingBot-Map: Geometric Context Transformer for Streaming 3D Reconstruction

+ +--- + +# Quick Start + +## Installation + +**1. Create conda environment** + +```bash +conda create -n lingbot-map python=3.10 -y +conda activate lingbot-map +``` + +**2. Install PyTorch (CUDA 12.8)** + +```bash +pip install torch==2.9.1 torchvision==0.24.1 --index-url https://download.pytorch.org/whl/cu128 +``` + +> For other CUDA versions, see [PyTorch Get Started](https://pytorch.org/get-started/locally/). + +**3. Install lingbot-map** + +```bash +pip install -e . +``` + +**4. Install FlashInfer (recommended)** + +FlashInfer provides paged KV cache attention for efficient streaming inference: + +```bash +# CUDA 12.8 + PyTorch 2.9 +pip install flashinfer-python -i https://flashinfer.ai/whl/cu128/torch2.9/ +``` + +> For other CUDA/PyTorch combinations, see [FlashInfer installation](https://docs.flashinfer.ai/installation.html). +> If FlashInfer is not installed, the model falls back to SDPA (PyTorch native attention) via `--use_sdpa`. + +**5. Visualization dependencies (optional)** + +```bash +pip install -e ".[vis]" +``` + +# Demo + +## Streaming Inference from Images + +```bash +python demo.py --model_path /path/to/checkpoint.pt \ + --image_folder /path/to/images/ +``` + +## Streaming Inference from Video + +```bash +python demo.py --model_path /path/to/checkpoint.pt \ + --video_path video.mp4 --fps 10 +``` + +## Streaming with Keyframe Interval + +Use `--keyframe_interval` to reduce KV cache memory by only keeping every N-th frame as a keyframe. Non-keyframe frames still produce predictions but are not stored in the cache. This is useful for long sequences +that exceed 320 frames. + +```bash +python demo.py --model_path /path/to/checkpoint.pt \ + --image_folder /path/to/images/ --keyframe_interval 6 +``` + +## Windowed Inference (for long sequences, >3000 frames) +```bash +python demo.py --model_path /path/to/checkpoint.pt \ + --video_path video.mp4 --fps 10 \ + --mode windowed --window_size 64 +``` + + +## With Sky Masking + +```bash +python demo.py --model_path /path/to/checkpoint.pt \ + --image_folder /path/to/images/ --mask_sky +``` + +## Without FlashInfer (SDPA fallback) + +```bash +python demo.py --model_path /path/to/checkpoint.pt \ + --image_folder /path/to/images/ --use_sdpa +``` + +# Model Download + +| Model Name | Huggingface Repository | Description | +| :--- | :--- | :--- | +| lingbot-map | [robbyant/lingbot-map](https://huggingface.co/robbyant/lingbot-map) | Base model checkpoint (4.63 GB) | + + +# License + +This project is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC 4.0). See the [LICENSE](LICENSE.txt) file for details. + +# Citation + +```bibtex +@article{lingbot-map2026, + title={}, + author={}, + journal={arXiv preprint arXiv:}, + year={2026} +} +``` + +# Acknowledgments + +This work builds upon several excellent open-source projects: + +- [VGGT](https://github.com/facebookresearch/vggt) +- [DINOv2](https://github.com/facebookresearch/dinov2) +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) + +--- \ No newline at end of file diff --git a/assets/teaser.png b/assets/teaser.png new file mode 100644 index 0000000..1f1e3c6 Binary files /dev/null and b/assets/teaser.png differ diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..338321e --- /dev/null +++ b/demo.py @@ -0,0 +1,346 @@ +"""LingBot-MAP demo: streaming 3D reconstruction from images or video. 
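+
+Outputs include per-frame camera extrinsics (converted from world-to-camera to
+camera-to-world), intrinsics, depth maps with confidences, and world points;
+results can be explored in the optional viser-based 3D point-cloud viewer.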
+ +Usage: + # Streaming inference (frame-by-frame with KV cache) + python examples/demo.py --model_path /path/to/checkpoint.pt \ + --image_folder /path/to/images/ + + # Streaming inference with keyframe KV caching + python examples/demo.py --model_path /path/to/checkpoint.pt \ + --image_folder /path/to/images/ --mode streaming --keyframe_interval 6 + + # Windowed inference (for very long sequences, >500 frames) + python examples/demo.py --model_path /path/to/checkpoint.pt \ + --video_path video.mp4 --fps 10 --mode windowed --window_size 64 + + # From video with custom FPS sampling + python examples/demo.py --model_path /path/to/checkpoint.pt \ + --video_path video.mp4 --fps 10 +""" + +import argparse +import glob +import os +import time + +import cv2 +import numpy as np +import torch +from tqdm.auto import tqdm + +from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri +from lingbot_map.utils.geometry import closed_form_inverse_se3_general +from lingbot_map.utils.load_fn import load_and_preprocess_images + + +# ============================================================================= +# Image loading +# ============================================================================= + +def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png", + first_k=None, stride=1, image_size=518, patch_size=14, num_workers=8): + """Load images from folder or video and preprocess into a tensor.""" + if video_path is not None: + video_name = os.path.splitext(os.path.basename(video_path))[0] + out_dir = os.path.join(os.path.dirname(video_path), f"{video_name}_frames") + os.makedirs(out_dir, exist_ok=True) + cap = cv2.VideoCapture(video_path) + src_fps = cap.get(cv2.CAP_PROP_FPS) or 30 + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + interval = max(1, round(src_fps / fps)) + idx, saved = 0, [] + pbar = tqdm(total=total_frames, desc="Extracting frames", unit="frame") + while True: + ret, frame = cap.read() + if not ret: + break + if idx % interval == 0: + path = os.path.join(out_dir, f"{len(saved):06d}.jpg") + cv2.imwrite(path, frame) + saved.append(path) + idx += 1 + pbar.update(1) + pbar.close() + cap.release() + paths = saved + print(f"Extracted {len(paths)} frames from video ({total_frames} total, interval={interval})") + else: + exts = image_ext.split(",") + paths = [] + for ext in exts: + paths.extend(glob.glob(os.path.join(image_folder, f"*{ext}"))) + paths = sorted(paths) + + if stride > 1: + paths = paths[::stride] + if first_k is not None and first_k > 0: + paths = paths[:first_k] + + print(f"Loading {len(paths)} images...") + images = load_and_preprocess_images( + paths, + mode="crop", + image_size=image_size, + patch_size=patch_size, + ) + h, w = images.shape[-2:] + print(f"Preprocessed images to {w}x{h} using canonical crop mode") + return images, paths + + +# ============================================================================= +# Model loading +# ============================================================================= + +def load_model(args, device): + """Load GCTStream model from checkpoint.""" + if getattr(args, "mode", "streaming") == "windowed": + from lingbot_map.models.gct_stream_window import GCTStream + else: + from lingbot_map.models.gct_stream import GCTStream + + print("Building model...") + model = GCTStream( + img_size=args.image_size, + patch_size=args.patch_size, + enable_3d_rope=args.enable_3d_rope, + max_frame_num=args.max_frame_num, + kv_cache_sliding_window=args.kv_cache_sliding_window, + 
kv_cache_scale_frames=args.kv_cache_scale_frames, + kv_cache_cross_frame_special=True, + kv_cache_include_scale_frames=True, + use_sdpa=args.use_sdpa, + ) + + if args.model_path: + print(f"Loading checkpoint: {args.model_path}") + ckpt = torch.load(args.model_path, map_location=device, weights_only=False) + state_dict = ckpt.get("model", ckpt) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + if missing: + print(f" Missing keys: {len(missing)}") + if unexpected: + print(f" Unexpected keys: {len(unexpected)}") + print(" Checkpoint loaded.") + + return model.to(device).eval() + + +# ============================================================================= +# Post-processing +# ============================================================================= + +_BATCHED_NDIMS = { + "pose_enc": 3, + "depth": 5, + "depth_conf": 4, + "world_points": 5, + "world_points_conf": 4, + "extrinsic": 4, + "intrinsic": 4, + "chunk_sim3_scales": 2, + "chunk_sim3_poses": 4, + "chunk_se3_poses": 4, + "images": 5, +} + + +def _squeeze_single_batch(key, value): + """Drop the leading batch dimension for single-sequence demo outputs.""" + batched_ndim = _BATCHED_NDIMS.get(key) + if batched_ndim is None or not hasattr(value, "ndim"): + return value + if value.ndim == batched_ndim and value.shape[0] == 1: + return value[0] + return value + + +def postprocess(predictions, images): + """Convert pose encoding to extrinsics (c2w) and move to CPU.""" + extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) + + # Convert w2c to c2w + extrinsic_4x4 = torch.zeros((*extrinsic.shape[:-2], 4, 4), device=extrinsic.device, dtype=extrinsic.dtype) + extrinsic_4x4[..., :3, :4] = extrinsic + extrinsic_4x4[..., 3, 3] = 1.0 + extrinsic_4x4 = closed_form_inverse_se3_general(extrinsic_4x4) + extrinsic = extrinsic_4x4[..., :3, :4] + + predictions["extrinsic"] = extrinsic + predictions["intrinsic"] = intrinsic + predictions.pop("pose_enc_list", None) + predictions.pop("images", None) + + print("Moving results to CPU...") + for k in list(predictions.keys()): + if isinstance(predictions[k], torch.Tensor): + predictions[k] = _squeeze_single_batch( + k, predictions[k].to("cpu", non_blocking=True) + ) + images_cpu = images.to("cpu", non_blocking=True) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + return predictions, images_cpu + + +def prepare_for_visualization(predictions, images=None): + """Convert predictions to the unbatched NumPy format used by vis code.""" + vis_predictions = {} + for k, v in predictions.items(): + if isinstance(v, torch.Tensor): + v = _squeeze_single_batch(k, v.detach().cpu()) + vis_predictions[k] = v.numpy() + elif isinstance(v, np.ndarray): + vis_predictions[k] = _squeeze_single_batch(k, v) + else: + vis_predictions[k] = v + + if images is None: + images = predictions.get("images") + + if isinstance(images, torch.Tensor): + images = images.detach().cpu() + if isinstance(images, np.ndarray): + images = _squeeze_single_batch("images", images) + elif isinstance(images, torch.Tensor): + images = _squeeze_single_batch("images", images).numpy() + + if isinstance(images, torch.Tensor): + images = images.numpy() + + if images is not None: + vis_predictions["images"] = images + + return vis_predictions + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + parser = 
argparse.ArgumentParser(description="LingBot-MAP: Streaming 3D Reconstruction Demo") + + # Input + parser.add_argument("--image_folder", type=str, default=None) + parser.add_argument("--video_path", type=str, default=None) + parser.add_argument("--fps", type=int, default=10) + parser.add_argument("--first_k", type=int, default=None) + parser.add_argument("--stride", type=int, default=1) + + # Model + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--image_size", type=int, default=518) + parser.add_argument("--patch_size", type=int, default=14) + + # Inference mode + parser.add_argument("--mode", type=str, default="streaming", choices=["streaming", "windowed"], + help="streaming: frame-by-frame with KV cache; windowed: overlapping windows for long sequences") + + # Streaming options + parser.add_argument("--enable_3d_rope", action="store_true", default=True) + parser.add_argument("--max_frame_num", type=int, default=1024) + parser.add_argument("--num_scale_frames", type=int, default=8) + parser.add_argument( + "--keyframe_interval", + type=int, + default=1, + help="Streaming only. Every N-th frame after scale frames is kept as a keyframe. 1 = every frame.", + ) + parser.add_argument("--kv_cache_sliding_window", type=int, default=64) + parser.add_argument("--kv_cache_scale_frames", type=int, default=8) + parser.add_argument("--use_sdpa", action="store_true", default=False, + help="Use SDPA backend (no flashinfer needed). Default: FlashInfer") + + # Windowed options + parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)") + parser.add_argument("--overlap_size", type=int, default=16, help="Overlap between windows") + parser.add_argument("--sim3", action="store_true", default=True, help="Use Sim(3) alignment between windows") + parser.add_argument("--no_sim3", dest="sim3", action="store_false", help="Disable Sim(3), use SE(3) instead") + + # Visualization + parser.add_argument("--port", type=int, default=8080) + parser.add_argument("--conf_threshold", type=float, default=1.0) + parser.add_argument("--downsample_factor", type=int, default=10) + parser.add_argument("--point_size", type=float, default=0.005) + parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points") + + args = parser.parse_args() + assert args.image_folder or args.video_path, \ + "Provide --image_folder or --video_path" + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # ── Load images & model ────────────────────────────────────────────────── + t0 = time.time() + images, paths = load_images( + image_folder=args.image_folder, video_path=args.video_path, + fps=args.fps, first_k=args.first_k, stride=args.stride, + image_size=args.image_size, patch_size=args.patch_size, + ) + model = load_model(args, device) + print(f"Total load time: {time.time() - t0:.1f}s") + + images = images.to(device) + num_frames = images.shape[0] + print(f"Input: {num_frames} frames, shape {tuple(images.shape)}") + print(f"Mode: {args.mode}") + + if args.mode != "streaming" and args.keyframe_interval != 1: + print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.") + args.keyframe_interval = 1 + elif args.mode == "streaming" and args.keyframe_interval > 1: + print( + f"Keyframe streaming enabled: interval={args.keyframe_interval} " + f"(after the first {args.num_scale_frames} scale frames)." 
+ ) + + # ── Inference ──────────────────────────────────────────────────────────── + dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 + print(f"Running {args.mode} inference (dtype={dtype})...") + t0 = time.time() + + with torch.no_grad(), torch.amp.autocast("cuda", dtype=dtype): + if args.mode == "streaming": + predictions = model.inference_streaming( + images, + num_scale_frames=args.num_scale_frames, + keyframe_interval=args.keyframe_interval, + ) + else: # windowed + predictions = model.inference_windowed( + images, + window_size=args.window_size, + overlap_size=args.overlap_size, + num_scale_frames=args.num_scale_frames, + sim3=args.sim3, + se3=not args.sim3, + ) + + t_infer = time.time() - t0 + print(f"Inference done: {t_infer:.1f}s ({num_frames / t_infer:.1f} FPS)") + + # ── Post-process ───────────────────────────────────────────────────────── + predictions, images_cpu = postprocess(predictions, images) + + # ── Visualize ──────────────────────────────────────────────────────────── + try: + from lingbot_map.vis import PointCloudViewer + viewer = PointCloudViewer( + pred_dict=prepare_for_visualization(predictions, images_cpu), + port=args.port, + init_conf_threshold=args.conf_threshold, + downsample_factor=args.downsample_factor, + point_size=args.point_size, + mask_sky=args.mask_sky, + ) + print(f"3D viewer at http://localhost:{args.port}") + viewer.run() + except ImportError: + print("viser not installed. Install with: pip install lingbot-map[vis]") + print(f"Predictions contain keys: {list(predictions.keys())}") + + +if __name__ == "__main__": + main() diff --git a/docs/.DS_Store b/docs/.DS_Store new file mode 100644 index 0000000..490e0b5 Binary files /dev/null and b/docs/.DS_Store differ diff --git a/lingbot-map_paper.pdf b/lingbot-map_paper.pdf new file mode 100644 index 0000000..b99fb73 Binary files /dev/null and b/lingbot-map_paper.pdf differ diff --git a/lingbot_map/.DS_Store b/lingbot_map/.DS_Store new file mode 100644 index 0000000..3fbd3c4 Binary files /dev/null and b/lingbot_map/.DS_Store differ diff --git a/lingbot_map/__init__.py b/lingbot_map/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lingbot_map/aggregator/__init__.py b/lingbot_map/aggregator/__init__.py new file mode 100644 index 0000000..f20529e --- /dev/null +++ b/lingbot_map/aggregator/__init__.py @@ -0,0 +1,2 @@ +from .stream import AggregatorStream +from .base import AggregatorBase diff --git a/lingbot_map/aggregator/base.py b/lingbot_map/aggregator/base.py new file mode 100644 index 0000000..54712d5 --- /dev/null +++ b/lingbot_map/aggregator/base.py @@ -0,0 +1,608 @@ +""" +AggregatorBase - Base class for all Aggregator implementations. + +Provides shared functionality: +- Patch embedding (DINOv2) +- Special tokens (camera, register, scale) +- Block building +- Common forward pass structure + +Subclasses implement mode-specific attention logic. 
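+
+Minimal usage sketch (illustrative only; the concrete streaming subclass and its
+full set of constructor arguments live in lingbot_map.aggregator.stream):
+
+    aggregator = AggregatorStream(img_size=518, patch_size=14)
+    # images: [B, S, 3, H, W] in [0, 1]
+    output_list, patch_start_idx = aggregator(images)
+    # each entry of output_list is [B, S, P, 2*C] (frame + global features concatenated)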
+""" + +import logging +import torch +import torch.nn as nn +from abc import ABC, abstractmethod +from typing import Optional, Tuple, List + +from lingbot_map.layers import PatchEmbed +from lingbot_map.layers.block import Block +from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter +from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2 + +logger = logging.getLogger(__name__) + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +def slice_expand_and_flatten(token, B, S, first_num_frame=1): + """ + Helper function to slice, expand and flatten tokens. + + Args: + token: Token tensor [1, 2, N, C] where first index is for first frames + B: Batch size + S: Sequence length + first_num_frame: Number of frames to use first token for + + Returns: + Flattened tokens [B*S, N, C] + """ + # token shape: [1, 2, N, C] + # Expand to [B, S, N, C] + if first_num_frame > 1: + # Use first token for first first_num_frame frames, second for rest + token_first = token[:, :1].expand(B, first_num_frame, -1, -1) # [B, first_num_frame, N, C] + token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1) # [B, S-first_num_frame, N, C] + token_expanded = torch.cat([token_first, token_rest], dim=1) # [B, S, N, C] + else: + # Use first token for first frame, second for rest + token_first = token[:, :1].expand(B, 1, -1, -1) # [B, 1, N, C] + token_rest = token[:, 1:].expand(B, S - 1, -1, -1) # [B, S-1, N, C] + token_expanded = torch.cat([token_first, token_rest], dim=1) # [B, S, N, C] + + # Flatten to [B*S, N, C] + return token_expanded.reshape(B * S, -1, token.shape[-1]) + + +class AggregatorBase(nn.Module, ABC): + """ + Base class for all Aggregator implementations. + + Handles shared components: + - Patch embedding (DINOv2 or conv) + - Special tokens (camera, register, optionally scale) + - Block creation (frame + global) + - RoPE (2D rotary position embeddings) + - Common forward pass scaffolding + + Subclasses must implement: + - _process_global_attention(): Mode-specific cross-frame attention logic + """ + + def __init__( + self, + # Architecture parameters + img_size=518, + patch_size=14, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + num_register_tokens=4, + # Block configuration + block_fn=Block, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + qk_norm=True, + init_values=0.01, + # Patch embedding + patch_embed="dinov2_vitl14_reg", + pretrained_path=None, + # Attention pattern + aa_order=["frame", "global"], + aa_block_size=1, + # RoPE + rope_freq=100, + disable_global_rope=False, + # Gradient checkpointing + use_reentrant: bool = False, + use_gradient_checkpoint: bool = True, + ): + super().__init__() + + # Store configuration + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.depth = depth + self.num_heads = num_heads + self.mlp_ratio = mlp_ratio + self.num_register_tokens = num_register_tokens + self.aa_order = aa_order + self.aa_block_size = aa_block_size + self.disable_global_rope = disable_global_rope + self.use_reentrant = use_reentrant + self.use_gradient_checkpoint = use_gradient_checkpoint + self.pretrained_path = pretrained_path + self.enable_ulysses_cp = False # CP disabled + + print("pretrained_path:", self.pretrained_path) + + # Validate depth + if self.depth % self.aa_block_size != 0: + raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})") + self.aa_block_num = self.depth // self.aa_block_size + + # Build 
patch embedding + self._build_patch_embed( + patch_embed=patch_embed, + img_size=img_size, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + embed_dim=embed_dim, + pretrained_path=pretrained_path + ) + + # Initialize RoPE + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + self.position_getter = PositionGetter() if self.rope is not None else None + + # Build blocks (frame + global) + self._build_blocks( + block_fn=block_fn, + depth=depth, + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + ) + + # Setup special tokens (camera, register, optionally scale) + self._setup_special_tokens() + + # Register normalization constants + for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)): + self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False) + + # Initialize from DINO checkpoint if available + if hasattr(self, '_dino_checkpoint') and self._dino_checkpoint is not None: + self._init_blocks_from_dino(self._dino_checkpoint) + del self._dino_checkpoint # Free memory + + def _build_patch_embed( + self, + patch_embed: str, + img_size: int, + patch_size: int, + num_register_tokens: int, + embed_dim: int, + pretrained_path: str, + interpolate_antialias=True, + interpolate_offset=0.0, + block_chunks=0, + init_values=1.0, + ): + """ + Build patch embedding layer. + + Supports: + - "conv": Simple convolutional patch embedding + - "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2) + """ + if "conv" in patch_embed: + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=3, + embed_dim=embed_dim + ) + self._dino_checkpoint = None + + else: + vit_models = { + "dinov2_vitl14_reg": vit_large, + "dinov2_vitb14_reg": vit_base, + "dinov2_vits14_reg": vit_small, + "dinov2_vitg2_reg": vit_giant2, + } + + if patch_embed not in vit_models: + raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}") + + self.patch_embed = vit_models[patch_embed]( + img_size=img_size, + patch_size=patch_size, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + block_chunks=block_chunks, + init_values=init_values, + ) + + # Load pretrained weights + try: + ckpt = torch.load(pretrained_path) + del ckpt['pos_embed'] + logger.info("Loading pretrained weights for DINOv2") + missing, unexpected = self.patch_embed.load_state_dict(ckpt, strict=False) + logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}") + + # Store checkpoint for block initialization + self._dino_checkpoint = ckpt + except Exception as e: + logger.warning(f"Failed to load pretrained weights: {e}") + self._dino_checkpoint = None + + # Disable gradients for mask token + if hasattr(self.patch_embed, "mask_token"): + self.patch_embed.mask_token.requires_grad_(False) + + @abstractmethod + def _build_blocks( + self, + block_fn, + depth: int, + embed_dim: int, + num_heads: int, + mlp_ratio: float, + qkv_bias: bool, + proj_bias: bool, + ffn_bias: bool, + init_values: float, + qk_norm: bool, + ): + """ + Build frame_blocks and global_blocks. + + Subclasses implement mode-specific block creation. 
+ + Must create: + - self.frame_blocks: nn.ModuleList of frame attention blocks + - self.global_blocks: nn.ModuleList of global attention blocks + """ + pass + + @abstractmethod + def _setup_special_tokens(self): + """ + Setup camera token, register tokens, and optionally scale token. + + Subclasses implement mode-specific token initialization. + + Must create: + - self.camera_token + - self.register_token + - self.scale_token (optional, for causal mode) + - self.patch_start_idx + - self.num_special_tokens + """ + pass + + def _init_blocks_from_dino(self, dino_ckpt: dict): + """ + Initialize frame_blocks and global_blocks from DINOv2 pretrained weights. + + Args: + dino_ckpt: Checkpoint dictionary from DINOv2 model + """ + logger.info("Initializing blocks from DINOv2 pretrained weights") + + # Extract block keys + dino_block_keys = [k for k in dino_ckpt.keys() if k.startswith('blocks.')] + if not dino_block_keys: + logger.warning("No 'blocks' found in DINO checkpoint") + return + + # Get block indices + block_indices = set() + for key in dino_block_keys: + parts = key.split('.') + if len(parts) > 1 and parts[1].isdigit(): + block_indices.add(int(parts[1])) + + num_dino_blocks = len(block_indices) + print(f"Found {num_dino_blocks} blocks in DINO checkpoint") + + # Initialize frame_blocks + for i, block in enumerate(self.frame_blocks): + dino_block_idx = i % num_dino_blocks + block_state_dict = {} + prefix = f'blocks.{dino_block_idx}.' + for key, value in dino_ckpt.items(): + if key.startswith(prefix): + new_key = key[len(prefix):] + block_state_dict[new_key] = value + + if block_state_dict: + missing, unexpected = block.load_state_dict(block_state_dict, strict=False) + if i == 0: # Only log for first block to avoid spam + print(f"Frame block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}") + + # Initialize global_blocks + for i, block in enumerate(self.global_blocks): + dino_block_idx = i % num_dino_blocks + block_state_dict = {} + prefix = f'blocks.{dino_block_idx}.' + for key, value in dino_ckpt.items(): + if key.startswith(prefix): + new_key = key[len(prefix):] + block_state_dict[new_key] = value + + if block_state_dict: + missing, unexpected = block.load_state_dict(block_state_dict, strict=False) + if i == 0: # Only log for first block to avoid spam + print(f"Global block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}") + + logger.info("Successfully initialized blocks from DINOv2 weights") + + def _embed_images( + self, + images: torch.Tensor, + num_frame_for_scale: Optional[int] = None, + ) -> Tuple[torch.Tensor, int, int, int, int, int]: + """ + Embed images and prepare for attention processing. 
+ + Handles: + - Image normalization + - Patch embedding + - Special token concatenation + - Position embedding + + Args: + images: Input images [B, S, 3, H, W] in range [0, 1] + num_frame_for_scale: Number of frames for scale estimation (passed to special tokens) + + Returns: + (tokens, B, S, S, P, C): + tokens: Embedded tokens [B*S, P, C] + B: Batch size + S: Sequence length + S: Same as above (no CP slicing) + P: Number of tokens per frame (patches + special tokens) + C: Embedding dimension + """ + B, S, C_in, H, W = images.shape + + if C_in != 3: + raise ValueError(f"Expected 3 input channels, got {C_in}") + + # Normalize images + images = (images - self._resnet_mean) / self._resnet_std + + # No CP slicing: S_local == S_global + S_local = S + S_global = S + + # Reshape for patch embedding [B*S, C, H, W] + images = images.view(B * S, C_in, H, W) + + # Patch embedding + patch_tokens = self.patch_embed(images) + if isinstance(patch_tokens, dict): + patch_tokens = patch_tokens["x_norm_patchtokens"] + + _, P_patch, C = patch_tokens.shape + + # Prepare special tokens + special_tokens = self._prepare_special_tokens( + B, S_local, S_global, C, + num_frame_for_scale=num_frame_for_scale + ) + + # Concatenate special tokens + patch tokens + tokens = torch.cat([special_tokens, patch_tokens], dim=1) + + _, P, C = tokens.shape + + return tokens, B, S_local, S_global, P, C + + @abstractmethod + def _prepare_special_tokens(self, B: int, S_local: int, S_global: int, C: int, **kwargs) -> torch.Tensor: + """ + Prepare special tokens (camera, register, optionally scale). + + Subclasses implement mode-specific token preparation. + + Args: + B: Batch size + S_local: Local sequence length + S_global: Global sequence length + C: Embedding dimension + **kwargs: Mode-specific parameters (e.g., num_frame_for_scale for causal mode) + + Returns: + Special tokens [B*S, N_special, C] + """ + pass + + def _get_positions(self, B: int, S: int, H: int, W: int, device) -> Optional[torch.Tensor]: + """ + Get 2D position embeddings for RoPE. + + Args: + B: Batch size + S: Sequence length + H: Image height + W: Image width + device: Device to create positions on + + Returns: + Position tensor [B*S, P, 2] or None if rope is disabled + """ + if self.rope is None: + return None + + # Get patch positions + pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device) + + # Add offset for patch tokens (skip special tokens at pos=0) + if self.patch_start_idx > 0: + pos = pos + 1 + pos_special = torch.zeros(B * S, self.patch_start_idx, 2, dtype=pos.dtype, device=device) + pos = torch.cat([pos_special, pos], dim=1) + + return pos + + def _process_frame_attention( + self, + tokens: torch.Tensor, + B: int, + S: int, + P: int, + C: int, + frame_idx: int, + pos: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]: + """ + Process frame attention blocks. + + Frame attention operates independently per frame (no cross-frame communication). + Tokens stay in shape [B*S, P, C]. 
+ + Args: + tokens: Input tokens [B*S, P, C] + B: Batch size + S: Sequence length + P: Tokens per frame + C: Embedding dimension + frame_idx: Current frame block index + pos: Position embeddings [B*S, P, 2] + + Returns: + (tokens, frame_idx, intermediates): + tokens: Output tokens [B*S, P, C] + frame_idx: Updated frame block index + intermediates: List of intermediate outputs [B, S, P, C] + """ + # Ensure correct shape + if tokens.shape != (B * S, P, C): + tokens = tokens.view(B * S, P, C) + + if pos is not None and pos.shape != (B * S, P, 2): + pos = pos.view(B * S, P, 2) + + intermediates = [] + + # Process blocks + for i in range(self.aa_block_size): + if self.training and self.use_gradient_checkpoint: + from torch.utils.checkpoint import checkpoint + tokens = checkpoint( + self.frame_blocks[frame_idx], + tokens, + pos, + False, # enable_ulysses_cp (always False) + use_reentrant=self.use_reentrant + ) + else: + tokens = self.frame_blocks[frame_idx](tokens, pos=pos, enable_ulysses_cp=False) + + frame_idx += 1 + intermediates.append(tokens.view(B, S, P, C)) + + return tokens, frame_idx, intermediates + + @abstractmethod + def _process_global_attention( + self, + tokens: torch.Tensor, + B: int, + S_local: int, + S_global: int, + P: int, + C: int, + global_idx: int, + pos: Optional[torch.Tensor] = None, + **kwargs + ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]: + """ + Process global (cross-frame) attention blocks. + + Subclasses implement mode-specific attention logic. + + Args: + tokens: Input tokens + B: Batch size + S_local: Local sequence length + S_global: Global sequence length + P: Tokens per frame + C: Embedding dimension + global_idx: Current global block index + pos: Position embeddings + **kwargs: Mode-specific parameters + + Returns: + (tokens, global_idx, intermediates): + tokens: Output tokens + global_idx: Updated global block index + intermediates: List of intermediate outputs + """ + pass + + def forward( + self, + images: torch.Tensor, + selected_idx: Optional[List[int]] = None, + # Mode-specific parameters + num_frame_for_scale: Optional[int] = None, + sliding_window_size: Optional[int] = None, + num_frame_per_block: int = 1, + ) -> Tuple[List[torch.Tensor], int]: + """ + Forward pass. 
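+
+        Runs alternating frame attention and global (cross-frame) attention
+        blocks over the embedded tokens, following `self.aa_order`, and
+        concatenates the frame/global intermediates along the channel dimension.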
+ + Args: + images: Input images [B, S, 3, H, W] in range [0, 1] + selected_idx: Which block indices to output (None = all) + num_frame_for_scale: Number of frames for scale estimation (causal mode) + sliding_window_size: Sliding window size in blocks (causal mode) + num_frame_per_block: Number of frames per processing block (causal mode) + + Returns: + (output_list, patch_start_idx): + output_list: List of block outputs [B, S, P, 2C] + patch_start_idx: Index where patch tokens start + """ + B, S_input, _, H, W = images.shape + + # Embed images + tokens, B, S_local, S_global, P, C = self._embed_images( + images, + num_frame_for_scale=num_frame_for_scale, + ) + + # Get position embeddings + pos_local = self._get_positions(B, S_local, H, W, device=images.device) + pos_global = self._get_positions(B, S_global, H, W, device=images.device) + + # Alternating attention + frame_idx = 0 + global_idx = 0 + output_list = [] + + for block_group_idx in range(self.aa_block_num): + for attn_type in self.aa_order: + if attn_type == "frame": + tokens, frame_idx, frame_intermediates = self._process_frame_attention( + tokens, B, S_local, P, C, frame_idx, pos=pos_local + ) + elif attn_type == "global": + tokens, global_idx, global_intermediates = self._process_global_attention( + tokens, B, S_local, S_global, P, C, global_idx, + pos=pos_global, + num_frame_for_scale=num_frame_for_scale, + sliding_window_size=sliding_window_size, + num_frame_per_block=num_frame_per_block, + image_height=H, + image_width=W, + ) + else: + raise ValueError(f"Unknown attention type: {attn_type}") + + # Collect outputs + if selected_idx is None or block_group_idx in selected_idx: + for i in range(len(frame_intermediates)): + # Concatenate frame and global intermediates [B, S, P, 2C] + concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1) + output_list.append(concat_inter) + + return output_list, self.patch_start_idx diff --git a/lingbot_map/aggregator/stream.py b/lingbot_map/aggregator/stream.py new file mode 100644 index 0000000..e442160 --- /dev/null +++ b/lingbot_map/aggregator/stream.py @@ -0,0 +1,531 @@ +""" +AggregatorStream - Streaming causal aggregator with FlashInfer KV cache. + +Provides: +- Temporal causal attention +- Sliding window support +- Scale token for scale estimation frames +- Streaming inference with FlashInfer paged KV cache +""" + +import logging +import torch +import torch.nn as nn +from typing import Optional, Tuple, List + +from lingbot_map.layers.block import Block, FlashInferBlock, SDPABlock +from lingbot_map.layers.rope import WanRotaryPosEmbed +from lingbot_map.aggregator.base import AggregatorBase, slice_expand_and_flatten + +logger = logging.getLogger(__name__) + + +class AggregatorStream(AggregatorBase): + """ + Streaming causal aggregator with FlashInfer paged KV cache. 
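+
+    Note: call `clean_kv_cache()` before starting a new sequence so the paged
+    KV cache and the internal frame counter are reset.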
+ + Features: + - Temporal causal attention (each frame only attends to past frames) + - Sliding window support to limit attention scope + - Scale token for scale estimation frames + - Streaming inference with FlashInfer KV cache + """ + + def __init__( + self, + # Causal-specific parameters + sliding_window_size: int = -1, + num_frame_for_scale: int = 1, + num_random_frames: int = 0, + attend_to_special_tokens: bool = False, + attend_to_scale_frames: bool = False, + enable_3d_rope: bool = False, + max_frame_num: int = 1024, + # KV cache parameters + kv_cache_sliding_window: int = 64, + kv_cache_scale_frames: int = 8, + kv_cache_cross_frame_special: bool = True, + kv_cache_include_scale_frames: bool = True, + kv_cache_camera_only: bool = False, + # Base class parameters via **kwargs + **kwargs + ): + """ + Initialize AggregatorStream. + + Args: + sliding_window_size: Sliding window size in blocks (-1 for full causal) + num_frame_for_scale: Number of scale estimation frames + num_random_frames: Number of random frames for long-range dependencies + attend_to_special_tokens: Enable cross-frame special token attention + attend_to_scale_frames: Include scale frames in attention + enable_3d_rope: Enable 3D RoPE for temporal dimension in KV cache + max_frame_num: Maximum number of frames for 3D RoPE + kv_cache_sliding_window: Sliding window size for KV cache eviction + kv_cache_scale_frames: Number of scale frames to keep in KV cache + kv_cache_cross_frame_special: Keep special tokens from evicted frames + kv_cache_include_scale_frames: Include scale frames in KV cache + kv_cache_camera_only: Only keep camera tokens from evicted frames + **kwargs: Base class parameters + """ + self.sliding_window_size = sliding_window_size + self.num_frame_for_scale = num_frame_for_scale + self.num_random_frames = num_random_frames + self.attend_to_special_tokens = attend_to_special_tokens + self.attend_to_scale_frames = attend_to_scale_frames + self.enable_3d_rope = enable_3d_rope + self.max_frame_num = max_frame_num + # KV cache parameters + self.kv_cache_sliding_window = kv_cache_sliding_window + self.kv_cache_scale_frames = kv_cache_scale_frames + self.kv_cache_cross_frame_special = kv_cache_cross_frame_special + self.kv_cache_include_scale_frames = kv_cache_include_scale_frames + self.kv_cache_camera_only = kv_cache_camera_only + + # Pop kwargs that are passed but not needed by base class + kwargs.pop('enable_stream_inference', None) + use_flashinfer = kwargs.pop('use_flashinfer', True) + kwargs.pop('use_flexflash', None) + use_sdpa = kwargs.pop('use_sdpa', False) + + # Backend selection: SDPA (no extra deps) or FlashInfer (paged KV cache) + self.use_sdpa = use_sdpa + self.use_flashinfer = not use_sdpa # FlashInfer is default unless SDPA requested + + # Call parent __init__ + super().__init__(**kwargs) + + # Initialize KV cache + self._init_kv_cache() + + # Initialize 3D RoPE if enabled + if self.enable_3d_rope: + self._init_3d_rope() + + def _build_blocks( + self, + block_fn, + depth: int, + embed_dim: int, + num_heads: int, + mlp_ratio: float, + qkv_bias: bool, + proj_bias: bool, + ffn_bias: bool, + init_values: float, + qk_norm: bool, + ): + """Build frame and global blocks for streaming causal mode.""" + block_params = dict( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + init_values=init_values, + qk_norm=qk_norm, + ) + + # Frame blocks: Standard Block + RoPE + self.frame_blocks = nn.ModuleList([ + block_fn(**block_params, 
rope=self.rope) + for _ in range(depth) + ]) + + # Global blocks: FlashInferBlock (default) or SDPABlock (fallback) + GlobalBlockCls = SDPABlock if self.use_sdpa else FlashInferBlock + self.global_blocks = nn.ModuleList([ + GlobalBlockCls( + **block_params, + rope=self.rope if not self.disable_global_rope else None, + kv_cache_sliding_window=self.kv_cache_sliding_window, + kv_cache_scale_frames=self.kv_cache_scale_frames, + kv_cache_cross_frame_special=self.kv_cache_cross_frame_special, + kv_cache_include_scale_frames=self.kv_cache_include_scale_frames, + kv_cache_camera_only=self.kv_cache_camera_only, + ) + for _ in range(depth) + ]) + + def _setup_special_tokens(self): + """Setup camera, register, and scale tokens for causal mode.""" + # Camera token + self.camera_token = nn.Parameter( + torch.randn(1, 2, 1, self.embed_dim) + ) + + # Register tokens + if self.num_register_tokens > 0: + self.register_token = nn.Parameter( + torch.randn(1, 2, self.num_register_tokens, self.embed_dim) + ) + + # Scale token (causal mode specific) + self.scale_token = nn.Parameter( + torch.ones(1, 2, 1, self.embed_dim) + ) + + # Initialize + nn.init.normal_(self.camera_token, std=1e-6) + if self.num_register_tokens > 0: + nn.init.normal_(self.register_token, std=1e-6) + nn.init.normal_(self.scale_token, std=1e-6) + + # Token indexing (includes scale token) + self.patch_start_idx = 1 + self.num_register_tokens + 1 # camera + register + scale + self.num_special_tokens = 1 + self.num_register_tokens + 1 + + def _init_kv_cache(self): + """Initialize KV cache for streaming inference.""" + self.kv_cache_manager = None # FlashInfer (lazy-initialized) + self.kv_cache = {} # Dict-based cache for SDPA + self.total_frames_processed = 0 + self._cached_pos3d = None + + if self.use_sdpa: + # Dict-based KV cache for SDPA + if hasattr(self, 'depth'): + for i in range(self.depth): + self.kv_cache[f"k_{i}"] = None + self.kv_cache[f"v_{i}"] = None + self.kv_cache[f"k_{i}_special"] = None + self.kv_cache[f"v_{i}_special"] = None + logger.info(f"SDPA KV cache initialized with {self.depth} blocks") + else: + logger.info("FlashInfer KV cache will be lazily initialized on first forward") + + def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None): + """Lazily initialize FlashInferKVCacheManager on first use. + + Args: + device: Device for cache tensors. + dtype: Data type for cache tensors. + tokens_per_frame: Actual number of tokens per frame (patches + specials). + If None, falls back to assuming square images of self.img_size. 
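+
+        Note:
+            When derived from self.img_size, the per-frame token count is
+            (img_size // patch_size) ** 2 + num_special_tokens, i.e. the patch
+            tokens plus the camera, register, and scale tokens. As a purely
+            illustrative example, img_size=518, patch_size=14 and 6 special
+            tokens give 37 * 37 + 6 = 1375 tokens per frame.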
+ """ + if self.kv_cache_manager is None: + from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager + num_heads = self.embed_dim // 64 # head_dim = 64 for ViT-L + head_dim = 64 + if tokens_per_frame is None: + tokens_per_frame = (self.img_size // self.patch_size) ** 2 + self.num_special_tokens + # max_num_frames: scale + window + headroom + max_num_frames = self.kv_cache_scale_frames + self.kv_cache_sliding_window + 16 + self.kv_cache_manager = FlashInferKVCacheManager( + num_blocks=self.depth, + max_num_frames=max_num_frames, + tokens_per_frame=tokens_per_frame, + num_heads=num_heads, + head_dim=head_dim, + dtype=dtype, + device=device, + num_special_tokens=self.num_special_tokens, + scale_frames=self.kv_cache_scale_frames, + sliding_window=self.kv_cache_sliding_window, + max_total_frames=self.max_frame_num + 100, + force_fp32=getattr(self, 'kv_cache_force_fp32', False), + fa3=getattr(self, 'kv_cache_fa3', False), + ) + logger.info( + f"FlashInfer KV cache manager initialized: {self.depth} blocks, " + f"max_frames={max_num_frames}, tokens_per_frame={tokens_per_frame}" + ) + return self.kv_cache_manager + + def clean_kv_cache(self): + """Clean KV cache (call this when starting a new sequence).""" + if self.kv_cache_manager is not None: + self.kv_cache_manager.reset() + if self.kv_cache: + for key in list(self.kv_cache.keys()): + if key == "_skip_append": + self.kv_cache[key] = False + else: + self.kv_cache[key] = None + self.total_frames_processed = 0 + self._cached_pos3d = None + logger.info("KV cache cleaned") + + def _init_3d_rope(self): + """Initialize 3D RoPE for streaming inference.""" + if not self.enable_3d_rope: + self.rope3d = None + return + + num_heads = 16 + head_dim = self.embed_dim // num_heads + + self.rope3d = WanRotaryPosEmbed( + attention_head_dim=head_dim, + patch_size=(1, self.patch_size, self.patch_size), + max_seq_len=self.max_frame_num, + ) + logger.info(f"3D RoPE initialized for max {self.max_frame_num} frames, head_dim={head_dim}") + + def _get_3d_positions_streaming(self, num_frames, H, W, device, f_start, f_end): + """ + Generate 3D RoPE positions for streaming mode with correct global frame indices. + + Args: + num_frames: Number of frames in current batch + H, W: Image height and width + device: Device to create positions on + f_start: Global start frame index + f_end: Global end frame index + + Returns: + pos3d: [1, 1, num_frames * P, head_dim//2] complex tensor + """ + if self.rope3d is None: + return None + + pph = H // self.patch_size + ppw = W // self.patch_size + + pos3d = self.rope3d( + ppf=num_frames, + pph=pph, + ppw=ppw, + patch_start_idx=self.num_special_tokens, + device=device, + f_start=f_start, + f_end=f_end + ) + return pos3d + + def _prepare_special_tokens( + self, + B: int, + S_local: int, + S_global: int, + C: int, + num_frame_for_scale: Optional[int] = None, + ) -> torch.Tensor: + """ + Prepare camera, register, and scale tokens. 
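+
+        In streaming mode the learned tokens are expanded with
+        slice_expand_and_flatten over the true global frame count
+        (frames already held in the KV cache plus the current chunk)
+        and then sliced back to the current S_global frames, so the
+        scale token is assigned to the first num_frame_for_scale frames
+        of the whole sequence rather than of every chunk.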
+ + Args: + B: Batch size + S_local: Local sequence length + S_global: Global sequence length + C: Embedding dimension + num_frame_for_scale: Number of frames for scale estimation + + Returns: + Special tokens [B*S_global, N_special, C] + """ + # Get effective num_frame_for_scale + scale_frames = self.num_frame_for_scale if num_frame_for_scale is None else num_frame_for_scale + + # Check cache state for both backends + has_flashinfer_cache = self.kv_cache_manager is not None and self.kv_cache_manager.num_frames > 0 + has_sdpa_cache = self.kv_cache is not None and self.kv_cache.get("k_0") is not None + + # Determine if we're in causal inference mode based on KV cache state + causal_inference = True + + if causal_inference and has_flashinfer_cache: + S_cached = self.kv_cache_manager.num_frames + S_true = S_cached + S_global + elif causal_inference and has_sdpa_cache: + _, _, S_cached, _, _ = self.kv_cache["k_0"].shape + S_true = S_cached + S_global + else: + S_true = S_global + + # Expand tokens based on mode + if causal_inference and S_true > S_global: + # Streaming mode: expand with S_true, then slice to get current frames + effective_scale_frames = min(scale_frames, S_true) + + camera_token_full = slice_expand_and_flatten(self.camera_token, B, S_true) + camera_token = camera_token_full[-S_global:, :, :] + + register_token_full = slice_expand_and_flatten(self.register_token, B, S_true) + register_token = register_token_full[-S_global:, :, :] + scale_token_full = slice_expand_and_flatten( + self.scale_token, B, S_true, first_num_frame=effective_scale_frames + ) + scale_token = scale_token_full[-S_global:, :, :] + else: + # Batch mode or first inference: expand directly + effective_scale_frames = min(scale_frames, S_global) + + camera_token = slice_expand_and_flatten(self.camera_token, B, S_global) + register_token = slice_expand_and_flatten(self.register_token, B, S_global) + scale_token = slice_expand_and_flatten( + self.scale_token, B, S_global, first_num_frame=effective_scale_frames + ) + + special_tokens = torch.cat([camera_token, register_token, scale_token], dim=1) + + # Verify shape + expected_shape = (B * S_global, self.num_special_tokens, C) + assert special_tokens.shape == expected_shape, \ + f"Expected {expected_shape}, got {special_tokens.shape}" + + return special_tokens + + def _process_global_attention( + self, + tokens: torch.Tensor, + B: int, + S_local: int, + S_global: int, + P: int, + C: int, + global_idx: int, + pos: Optional[torch.Tensor] = None, + # Mode-specific parameters + num_frame_for_scale: Optional[int] = None, + sliding_window_size: Optional[int] = None, + num_frame_per_block: int = 1, + **kwargs, + ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]: + """ + Process causal global attention via FlashInfer streaming path. 
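+
+        This is a thin wrapper around _process_causal_stream: it only pulls
+        image_height / image_width out of kwargs (needed when 3D RoPE is
+        enabled) and forwards all remaining arguments unchanged.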
+ + Args: + tokens: Input tokens + B: Batch size + S_local: Local sequence length + S_global: Global sequence length + P: Tokens per frame + C: Embedding dimension + global_idx: Current global block index + pos: Position embeddings + num_frame_for_scale: Number of frames for scale estimation + sliding_window_size: Sliding window size in blocks + num_frame_per_block: Number of frames per processing block + + Returns: + (tokens, global_idx, intermediates) + """ + # Extract image dimensions from kwargs for 3D RoPE + image_height = kwargs.get('image_height', self.img_size) + image_width = kwargs.get('image_width', self.img_size) + + return self._process_causal_stream( + tokens, B, S_local, S_global, P, C, global_idx, pos, + num_frame_per_block, sliding_window_size, num_frame_for_scale, + image_height=image_height, image_width=image_width + ) + + def _process_causal_stream( + self, + tokens: torch.Tensor, + B: int, + S_local: int, + S_global: int, + P: int, + C: int, + global_idx: int, + pos: Optional[torch.Tensor] = None, + num_frame_per_block: int = 1, + sliding_window_size: Optional[int] = None, + num_frame_for_scale: Optional[int] = None, + image_height: Optional[int] = None, + image_width: Optional[int] = None, + ): + """ + Causal attention for streaming inference using FlashInfer KV cache. + + Args: + tokens: Input tokens [B*S_local, P, C] + B: Batch size + S_local: Local sequence length + S_global: Global sequence length + P: Number of patches per frame (includes special tokens) + C: Channel dimension + global_idx: Starting block index + pos: Position embeddings [B*S_global, P, 2] + num_frame_per_block: Number of frames per block + sliding_window_size: Sliding window size in blocks + num_frame_for_scale: Number of scale frames + image_height: Image height for 3D RoPE calculation + image_width: Image width for 3D RoPE calculation + + Returns: + (tokens, global_idx, intermediates): Updated tokens, next block index, intermediate outputs + """ + # Get effective parameters + scale_frames = num_frame_for_scale if num_frame_for_scale is not None else self.num_frame_for_scale + + # Reshape tokens: [B*S_local, P, C] -> [B, S_local*P, C] + if tokens.shape != (B, S_local * P, C): + tokens = tokens.view(B, S_local, P, C).view(B, S_local * P, C) + + # Calculate number of frames for block mask + num_frames = S_global + num_patches = P - self.num_special_tokens + + # Check if this is the first block group + is_first_block_group = (global_idx < self.aa_block_size) + + if self.enable_3d_rope and self.rope3d is not None: + if is_first_block_group: + f_start = self.total_frames_processed + f_end = self.total_frames_processed + S_global + + H = image_height if image_height is not None else self.img_size + W = image_width if image_width is not None else self.img_size + pos3d = self._get_3d_positions_streaming( + S_global, H, W, tokens.device, f_start, f_end + ) + self._cached_pos3d = pos3d + else: + pos3d = self._cached_pos3d + pos = pos3d + else: + # Reshape pos: [B*S_global, P, 2] -> [B, S_global*P, 2] + if pos is not None and pos.shape != (B, S_global * P, 2): + pos = pos.view(B, S_global, P, 2).view(B, S_global * P, 2) + + intermediates = [] + + # Process blocks with KV cache + for _ in range(self.aa_block_size): + num_patches = P - self.num_special_tokens + if self.use_sdpa: + # SDPA: dict-based KV cache + tokens = self.global_blocks[global_idx]( + tokens, + pos=pos, + enable_ulysses_cp=False, + num_patches=num_patches, + num_special=self.num_special_tokens, + num_frames=num_frames, + 
enable_3d_rope=self.enable_3d_rope, + kv_cache=self.kv_cache, + global_idx=global_idx, + num_frame_per_block=num_frame_per_block, + num_frame_for_scale=scale_frames, + num_register_tokens=self.num_register_tokens, + ) + else: + # FlashInfer: paged KV cache manager + manager = self._get_flashinfer_manager(tokens.device, tokens.dtype, tokens_per_frame=P) + tokens = self.global_blocks[global_idx]( + tokens, + pos=pos, + enable_ulysses_cp=False, + num_patches=num_patches, + num_special=self.num_special_tokens, + num_frames=num_frames, + enable_3d_rope=self.enable_3d_rope, + kv_cache=manager, + global_idx=global_idx, + num_frame_per_block=num_frame_per_block, + num_frame_for_scale=scale_frames, + num_register_tokens=self.num_register_tokens, + ) + + global_idx += 1 + intermediates.append(tokens.view(B, S_local, P, C)) + + # Update total frames processed counter only on the first block group + if is_first_block_group and not (isinstance(self.kv_cache, dict) and self.kv_cache.get("_skip_append", False)): + self.total_frames_processed += S_global + + return tokens, global_idx, intermediates diff --git a/lingbot_map/heads/__init__.py b/lingbot_map/heads/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lingbot_map/heads/camera_head.py b/lingbot_map/heads/camera_head.py new file mode 100644 index 0000000..1f97711 --- /dev/null +++ b/lingbot_map/heads/camera_head.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from lingbot_map.layers import Mlp +from lingbot_map.layers.block import Block +from lingbot_map.layers.block import CameraBlock +from lingbot_map.heads.head_act import activate_pose +from lingbot_map.layers.rope import WanRotaryPosEmbed +from functools import partial +from torch.utils.checkpoint import checkpoint + + +class CameraHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. + enable_ulysses_cp=False, + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + + self.enable_ulysses_cp = enable_ulysses_cp + + # Build the trunk using a sequence of transformer blocks. + self.trunk = nn.Sequential( + *[ + Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. 
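+        # It seeds iteration 0 of the refinement loop in trunk_fn; later
+        # iterations embed the previous (detached) pose estimate instead.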
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. + self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) + + # Adaptive layer normalization without affine parameters. + self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0) + + def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C]. + num_iterations (int): Number of refinement iterations. + + Returns: + list: List of activated camera encodings from each iteration. + """ + B, S, C = pose_tokens.shape # S is expected to be 1. + pred_pose_enc = None + pred_pose_enc_list = [] + + for _ in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + # Apply trunk blocks with enable_ulysses_cp + for block in self.trunk: + pose_tokens_modulated = block(pose_tokens_modulated, enable_ulysses_cp=self.enable_ulysses_cp) + # Compute the delta update for the pose encoding. + pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. + activated_pose = activate_pose( + pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act + ) + pred_pose_enc_list.append(activated_pose) + + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. 
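+
+    Computes x * (1 + scale) + shift (adaptive layer-norm style modulation).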
+ """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift + + +class CameraCausalHead(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. + """ + + def __init__( + self, + dim_in: int = 2048, + trunk_depth: int = 4, + pose_encoding_type: str = "absT_quaR_FoV", + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + trans_act: str = "linear", + quat_act: str = "linear", + fl_act: str = "relu", # Field of view activations: ensures FOV values are positive. + num_iterations = 4, + elementwise_attn_output_gate: bool = False, + sliding_window_size: int = -1, + attend_to_scale_frames: bool = False, + num_random_frames: int = 0, + enable_ulysses_cp: bool = False, + attn_class: str = "flexflashattn_varlen", + # KV cache parameters + kv_cache_sliding_window: int = 64, + kv_cache_scale_frames: int = 8, + kv_cache_cross_frame_special: bool = True, + kv_cache_include_scale_frames: bool = True, + kv_cache_camera_only: bool = False, + # 3D RoPE parameters + enable_3d_rope: bool = False, + max_frame_num: int = 1024, + rope_theta: float = 10000.0, + ): + super().__init__() + + if pose_encoding_type == "absT_quaR_FoV": + self.target_dim = 9 + else: + raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}") + + self.trans_act = trans_act + self.quat_act = quat_act + self.fl_act = fl_act + self.trunk_depth = trunk_depth + self.sliding_window_size = sliding_window_size + self.enable_ulysses_cp = enable_ulysses_cp + self.num_heads = num_heads + + # 3D RoPE for temporal position encoding + self.enable_3d_rope = enable_3d_rope + if enable_3d_rope: + head_dim = dim_in // num_heads + # For camera head: each frame has 1 token (frame_seqlen=1) + # patch_size is (max_frames, h=1, w=1) for 3D RoPE + # fhw_dim=None lets auto-calculation: h_dim=w_dim=2*(head_dim//6), t_dim=remainder + self.rope3d = WanRotaryPosEmbed( + attention_head_dim=head_dim, + patch_size=(max_frame_num, 1, 1), + theta=rope_theta, + fhw_dim=[40, 44, 44], # Auto-calculate dimension allocation + ) + else: + self.rope3d = None + + # Build the trunk using a sequence of transformer blocks. + self.trunk = nn.Sequential( + *[ + CameraBlock(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values, elementwise_attn_output_gate=elementwise_attn_output_gate, sliding_window_size=sliding_window_size, attend_to_scale_frames=attend_to_scale_frames, num_random_frames=num_random_frames, kv_cache_sliding_window=kv_cache_sliding_window, kv_cache_scale_frames=kv_cache_scale_frames, kv_cache_cross_frame_special=kv_cache_cross_frame_special, kv_cache_include_scale_frames=kv_cache_include_scale_frames, kv_cache_camera_only=kv_cache_camera_only) + for _ in range(trunk_depth) + ] + ) + + # Normalizations for camera token and trunk output. + self.token_norm = nn.LayerNorm(dim_in) + self.trunk_norm = nn.LayerNorm(dim_in) + + # Learnable empty camera pose token. + self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim)) + self.embed_pose = nn.Linear(self.target_dim, dim_in) + + # Module for producing modulation parameters: shift, scale, and a gate. + self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True)) + + # Adaptive layer normalization without affine parameters. 
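+        # The affine parameters come instead from poseLN_modulation above,
+        # which produces the per-iteration shift, scale, and gate used in trunk_fn.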
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6) + self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0) + + self.num_iterations = num_iterations + + self.kv_cache = None + self.pos_cache = None + self.frame_idx = 0 + self.cp_size = 1 + + ## Get cp size if enable ulysses cp + if self.enable_ulysses_cp: + from torchtitan.distributed.sequence_parallel import ( + init_sequence_parallel, + get_ulysses_sequence_parallel_rank, + get_ulysses_sequence_parallel_world_size, + ) + + self.cp_size = get_ulysses_sequence_parallel_world_size() + + + + def clean_kv_cache(self): + del self.kv_cache + self.kv_cache = None + self.frame_idx = 0 + + def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list: + """ + Forward pass to predict camera parameters. + + Args: + aggregated_tokens_list (list): List of token tensors from the network; + the last tensor is used for prediction. + num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4. + sliding_window_size (int, optional): Override the sliding window size for this forward pass. + If None, use the default self.sliding_window_size. + + Returns: + list: A list of predicted camera encodings (post-activation) from each iteration. + """ + # Use passed sliding_window_size if provided, otherwise use default + effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size + + # Use tokens from the last block for camera prediction. + tokens = aggregated_tokens_list[-1] + + # Extract the camera tokens + pose_tokens = tokens[:, :, 0] + pose_tokens = self.token_norm(pose_tokens) + + if causal_inference: + if self.kv_cache is None: + self.kv_cache = [] + for i in range(self.num_iterations): + self.kv_cache.append({"_skip_append": False}) + for j in range(self.trunk_depth): + self.kv_cache[i][f"k_{j}"] = None + self.kv_cache[i][f"v_{j}"] = None + + pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size) + return pred_pose_enc_list + + def trunk_fn(self, pose_tokens: torch.Tensor, mask=None, num_iterations: int=4, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None) -> list: + """ + Iteratively refine camera pose predictions. + + Args: + pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C]. + num_iterations (int): Number of refinement iterations. + sliding_window_size (int, optional): Sliding window size to use. + + Returns: + list: List of activated camera encodings from each iteration. 
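+
+        Streaming sketch (illustrative only; assumes per-frame aggregator
+        tokens of shape [B, 1, N, C] with the camera token at index 0):
+
+            head.clean_kv_cache()
+            for tokens_t in per_frame_tokens:
+                poses_t = head([tokens_t], causal_inference=True,
+                               num_frame_per_block=1)[-1]
+
+        Each call appends S frames to the per-iteration KV caches and
+        advances self.frame_idx, which keeps the 3D RoPE frame indices
+        globally consistent when enable_3d_rope is set.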
+ """ + B, S, C = pose_tokens.shape + pred_pose_enc = None + pred_pose_enc_list = [] + + # Check if this is the first call (processing scale frames) + # Scale frames should use batch mode attention for numerical consistency + is_scale_frames = (self.kv_cache is not None and self.frame_idx == 0) + + # Generate 3D RoPE positions if enabled + pos3d = None + if self.rope3d is not None: + # For camera tokens: shape [B, S, C] where each frame has 1 token + # Position for frame f is (f, 0, 0) - temporal varies, spatial fixed + + # In streaming mode with KV cache, use frame_idx to track global position + # Otherwise, generate positions from 0 + if self.kv_cache is not None: + f_start = self.frame_idx + f_end = self.frame_idx + S + else: + f_start = 0 + f_end = None # Will use ppf as frame count + + pos3d = self.rope3d( + ppf=S * self.cp_size, # Total frames (with CP) + pph=1, # height = 1 (camera token) + ppw=1, # width = 1 (camera token) + patch_start_idx=0, # No special tokens before + device=pose_tokens.device, + f_start=f_start, + f_end=f_end, + ) # Returns [1, 1, S*cp_size, head_dim//2] complex + + for i in range(num_iterations): + # Use a learned empty pose for the first iteration. + if pred_pose_enc is None: + module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1)) + else: + # Detach the previous prediction to avoid backprop through time. + pred_pose_enc = pred_pose_enc.detach() + module_input = self.embed_pose(pred_pose_enc) + + # Generate modulation parameters and split them into shift, scale, and gate components. + shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1) + + # Adaptive layer normalization and modulation. + pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa) + pose_tokens_modulated = pose_tokens_modulated + pose_tokens + + for idx in range(self.trunk_depth): + pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, pos=pos3d, video_mask=mask, num_frames=S*self.cp_size, frame_seqlen=1, kv_cache=self.kv_cache[i] if self.kv_cache is not None else None, global_idx=idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=sliding_window_size, enable_ulysses_cp=self.enable_ulysses_cp, enable_3d_rope=self.enable_3d_rope, is_scale_frames=is_scale_frames) + # Compute the delta update for the pose encoding. + pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated)) + + if pred_pose_enc is None: + pred_pose_enc = pred_pose_enc_delta + else: + pred_pose_enc = pred_pose_enc + pred_pose_enc_delta + + # Apply final activation functions for translation, quaternion, and field-of-view. + activated_pose = activate_pose( + pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act + ) + pred_pose_enc_list.append(activated_pose) + + # Update frame_idx for streaming mode (KV cache) + if self.kv_cache is not None: + self.frame_idx += S + + return pred_pose_enc_list + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + Modulate the input tensor using scaling and shifting parameters. 
+ """ + # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19 + return x * (1 + scale) + shift + + + + +class CameraDecoder(nn.Module): + def __init__( + self, + in_dim, + out_dim, + dec_embed_dim=512, + depth=5, + dec_num_heads=8, + mlp_ratio=4, + rope=None, + need_project=True, + use_checkpoint=False, + ): + super().__init__() + + self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity() + self.use_checkpoint = use_checkpoint + + self.blocks = nn.ModuleList([ + Block( + dim=dec_embed_dim, + num_heads=dec_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=True, + proj_bias=True, + ffn_bias=True, + drop_path=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + ffn_layer=Mlp, + init_values=None, + qk_norm=False, + # attn_class=MemEffAttentionRope, + rope=rope + ) for _ in range(depth)]) + + self.linear_out = nn.Linear(dec_embed_dim, out_dim) + + def forward(self, hidden, xpos=None): + hidden = self.projects(hidden) + B, V, P, C = hidden.shape + hidden = hidden.view(hidden.shape[0]*hidden.shape[1], hidden.shape[2], hidden.shape[3]) + for i, blk in enumerate(self.blocks): + if self.use_checkpoint and self.training: + hidden = checkpoint(blk, hidden, pos=xpos, use_reentrant=False) + else: + hidden = blk(hidden, pos=xpos) + out = self.linear_out(hidden).view(B, V, P, -1) + + return out diff --git a/lingbot_map/heads/dpt_head.py b/lingbot_map/heads/dpt_head.py new file mode 100644 index 0000000..ab11f91 --- /dev/null +++ b/lingbot_map/heads/dpt_head.py @@ -0,0 +1,679 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +# Inspired by https://github.com/DepthAnything/Depth-Anything-V2 + + +import os +from typing import List, Dict, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from .head_act import activate_head +from .utils import create_uv_grid, position_grid_to_embed + + +class DPTHead(nn.Module): + """ + DPT Head for dense prediction tasks. + + This implementation follows the architecture described in "Vision Transformers for Dense Prediction" + (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer + backbone and produces dense predictions by fusing multi-scale features. + + Args: + dim_in (int): Input dimension (channels). + patch_size (int, optional): Patch size. Default is 14. + output_dim (int, optional): Number of output channels. Default is 4. + activation (str, optional): Activation type. Default is "inv_log". + conf_activation (str, optional): Confidence activation type. Default is "expp1". + features (int, optional): Feature channels for intermediate representations. Default is 256. + out_channels (List[int], optional): Output channels for each intermediate layer. + intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT. + pos_embed (bool, optional): Whether to use positional embedding. Default is True. + feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False. + down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1. 
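+
+    Example (illustrative shapes only; the token list must cover every index
+    in intermediate_layer_idx and images are expected in [0, 1]):
+
+        head = DPTHead(dim_in=2048, patch_size=14, output_dim=4)
+        # aggregated_tokens_list: 4 tensors of shape [B, S, N_tokens, 2048]
+        # images: [B, S, 3, H, W]
+        preds, conf = head(aggregated_tokens_list, images, patch_start_idx=5)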
+ """ + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "inv_log", + conf_activation: str = "expp1", + features: int = 256, + out_channels: List[int] = [256, 512, 1024, 1024], + intermediate_layer_idx: List[int] = [0, 1, 2, 3], + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + ) -> None: + super(DPTHead, self).__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.feature_only = feature_only + self.down_ratio = down_ratio + self.intermediate_layer_idx = intermediate_layer_idx + + self.norm = nn.LayerNorm(dim_in) + + # Projection layers for each output channel from tokens. + self.projects = nn.ModuleList( + [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] + ) + + # Resize layers for upsampling feature maps. + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1 + ), + ] + ) + + self.scratch = _make_scratch(out_channels, features, expand=False) + + # Attach additional modules to scratch. + self.scratch.stem_transpose = None + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + head_features_1 = features + head_features_2 = 32 + + if feature_only: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1) + else: + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + conv2_in_channels = head_features_1 // 2 + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + def forward( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_chunk_size: int = 8, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass through the DPT head, supports processing by chunking frames. + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1]. + patch_start_idx (int): Starting index for patch tokens in the token sequence. + Used to separate patch tokens from other tokens (e.g., camera or register tokens). + frames_chunk_size (int, optional): Number of frames to process in each chunk. + If None or larger than S, all frames are processed at once. Default: 8. 
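+
+        Note:
+            Chunking only bounds peak activation memory: every operation in
+            this head is applied per frame, so the chunk outputs concatenated
+            along the sequence dimension should match processing all S frames
+            in a single pass.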
+ + Returns: + Tensor or Tuple[Tensor, Tensor]: + - If feature_only=True: Feature maps with shape [B, S, C, H, W] + - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W] + """ + B, _, _, H, W = images.shape + + S = aggregated_tokens_list[0].shape[1] + + # If frames_chunk_size is not specified or greater than S, process all frames at once + if frames_chunk_size is None or frames_chunk_size >= S: + return self._forward_impl(aggregated_tokens_list, images, patch_start_idx) + + # Otherwise, process frames in chunks to manage memory usage + assert frames_chunk_size > 0 + + # Process frames in batches + all_preds = [] + all_conf = [] + + for frames_start_idx in range(0, S, frames_chunk_size): + frames_end_idx = min(frames_start_idx + frames_chunk_size, S) + + # Process batch of frames + if self.feature_only: + chunk_output = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_output) + else: + chunk_preds, chunk_conf = self._forward_impl( + aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx + ) + all_preds.append(chunk_preds) + all_conf.append(chunk_conf) + + # Concatenate results along the sequence dimension + if self.feature_only: + return torch.cat(all_preds, dim=1) + else: + return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1) + + def _forward_impl( + self, + aggregated_tokens_list: List[torch.Tensor], + images: torch.Tensor, + patch_start_idx: int, + frames_start_idx: int = None, + frames_end_idx: int = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Implementation of the forward pass through the DPT head. + + This method processes a specific chunk of frames from the sequence. + + Args: + aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers. + images (Tensor): Input images with shape [B, S, 3, H, W]. + patch_start_idx (int): Starting index for patch tokens. + frames_start_idx (int, optional): Starting index for frames to process. + frames_end_idx (int, optional): Ending index for frames to process. + + Returns: + Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence). + """ + + B, _, _, H, W = images.shape + + patch_h, patch_w = H // self.patch_size, W // self.patch_size + + out = [] + dpt_idx = 0 + + for layer_idx in self.intermediate_layer_idx: + x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:] + + + + if frames_start_idx is not None and frames_end_idx is not None: + x = x[:, frames_start_idx:frames_end_idx] + + B, S = x.shape[0], x.shape[1] + + x = x.reshape(B * S, -1, x.shape[-1]) + + x = self.norm(x) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[dpt_idx](x) + if self.pos_embed: + x = self._apply_pos_embed(x, W, H) + x = self.resize_layers[dpt_idx](x) + + out.append(x) + dpt_idx += 1 + + # Fuse features from multiple layers. + out = self.scratch_forward(out) + # Interpolate fused output to match target image resolution. 
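+        # Target size is (patch_h * patch_size, patch_w * patch_size), optionally
+        # reduced by down_ratio (e.g. down_ratio=2 halves the output resolution).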
+ out = custom_interpolate( + out, + (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), + mode="bilinear", + align_corners=True, + ) + + if self.pos_embed: + out = self._apply_pos_embed(out, W, H) + + if self.feature_only: + return out.view(B, S, *out.shape[1:]) + + out = self.scratch.output_conv2(out) + preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation) + + preds = preds.view(B, S, *preds.shape[1:]) + conf = conf.view(B, S, *conf.shape[1:]) + return preds, conf + + def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """ + Apply positional embedding to tensor x. + """ + patch_w = x.shape[-1] + patch_h = x.shape[-2] + pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pos_embed = position_grid_to_embed(pos_embed, x.shape[1]) + pos_embed = pos_embed * ratio + pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pos_embed + + def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor: + """ + Forward pass through the fusion blocks. + + Args: + features (List[Tensor]): List of feature maps from different layers. + + Returns: + Tensor: Fused feature map. + """ + layer_1, layer_2, layer_3, layer_4 = features + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + del layer_4_rn, layer_4 + + out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:]) + del layer_3_rn, layer_3 + + out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:]) + del layer_2_rn, layer_2 + + out = self.scratch.refinenet1(out, layer_1_rn) + del layer_1_rn, layer_1 + + out = self.scratch.output_conv1(out) + return out + + +################################################################################ +# Modules +################################################################################ + + +def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module: + return FeatureFusionBlock( + features, + nn.ReLU(inplace=True), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module: + scratch = nn.Module() + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, 
features, activation, bn, groups=1): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.norm1 = None + self.norm2 = None + + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None, + has_residual=True, + groups=1, + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + self.groups = groups + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups + ) + + if has_residual: + self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.has_residual = has_residual + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups) + + self.skip_add = nn.quantized.FloatFunctional() + self.size = size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if self.has_residual: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + output = self.out_conv(output) + + return output + + +def custom_interpolate( + x: torch.Tensor, + size: Tuple[int, int] = None, + scale_factor: float = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate. 
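+
+    When the requested output would exceed roughly 1.6e9 elements
+    (size[0] * size[1] * batch * channels), the input is split into chunks
+    along the batch dimension, each chunk is interpolated separately, and the
+    results are concatenated; otherwise this falls through to a plain
+    nn.functional.interpolate call.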
+ """ + if size is None: + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + + input_elements = size[0] * size[1] * x.shape[0] * x.shape[1] + + if input_elements > INT_MAX: + chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0) + interpolated_chunks = [ + nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks + ] + x = torch.cat(interpolated_chunks, dim=0) + return x.contiguous() + else: + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) + +class DPTHead_Update(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead_Update, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block_slam(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block_slam(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block_slam(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block_slam(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w, return_intermediate=True): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + out = self.scratch.output_conv1(path_1) + out = 
F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + if return_intermediate: + return out, path_1, path_2, path_3, path_4 + else: + out = self.scratch.output_conv2(out) + return out + +def _make_fusion_block_slam(features, use_bn, size=None): + return FeatureFusionBlock_slam( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class FeatureFusionBlock_slam(nn.Module): + """Feature fusion block. + """ + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_slam, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output \ No newline at end of file diff --git a/lingbot_map/heads/head_act.py b/lingbot_map/heads/head_act.py new file mode 100644 index 0000000..2dedfcf --- /dev/null +++ b/lingbot_map/heads/head_act.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"): + """ + Activate pose parameters with specified activation functions. + + Args: + pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length] + trans_act: Activation type for translation component + quat_act: Activation type for quaternion component + fl_act: Activation type for focal length component + + Returns: + Activated pose parameters tensor + """ + T = pred_pose_enc[..., :3] + quat = pred_pose_enc[..., 3:7] + fl = pred_pose_enc[..., 7:] # or fov + + T = base_pose_act(T, trans_act) + quat = base_pose_act(quat, quat_act) + fl = base_pose_act(fl, fl_act) # or fov + + pred_pose_enc = torch.cat([T, quat, fl], dim=-1) + + return pred_pose_enc + + +def base_pose_act(pose_enc, act_type="linear"): + """ + Apply basic activation function to pose parameters. 
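+
+    The "inv_log" option applies sign(x) * (exp(|x|) - 1), the inverse of a
+    signed log transform, via inverse_log_transform defined below.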
+ + Args: + pose_enc: Tensor containing encoded pose parameters + act_type: Activation type ("linear", "inv_log", "exp", "relu") + + Returns: + Activated pose parameters + """ + if act_type == "linear": + return pose_enc + elif act_type == "inv_log": + return inverse_log_transform(pose_enc) + elif act_type == "exp": + return torch.exp(pose_enc) + elif act_type == "relu": + return F.relu(pose_enc) + else: + raise ValueError(f"Unknown act_type: {act_type}") + + +def activate_head(out, activation="norm_exp", conf_activation="expp1"): + """ + Process network output to extract 3D points and confidence values. + + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + xyz = fmap[:, :, :, :-1] + conf = fmap[:, :, :, -1] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "inv_log": + pts3d = inverse_log_transform(xyz) + elif activation == "xy_inv_log": + xy, z = xyz.split([2, 1], dim=-1) + z = inverse_log_transform(z) + pts3d = torch.cat([xy * z, z], dim=-1) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +def inverse_log_transform(y): + """ + Apply inverse log transform: sign(y) * (exp(|y|) - 1) + + Args: + y: Input tensor + + Returns: + Transformed tensor + """ + return torch.sign(y) * (torch.expm1(torch.abs(y))) diff --git a/lingbot_map/heads/utils.py b/lingbot_map/heads/utils.py new file mode 100644 index 0000000..533fc8a --- /dev/null +++ b/lingbot_map/heads/utils.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
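+
+# These helpers build the 2D sinusoidal positional embedding that
+# DPTHead._apply_pos_embed adds to its fused feature maps.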
+ +import torch +import torch.nn as nn + + +def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + device = pos.device + omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. 
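+
+    Example (illustrative, square 2x2 grid where aspect_ratio defaults to 1.0):
+
+        uv = create_uv_grid(2, 2)   # shape (2, 2, 2)
+        # span_x = span_y = 1/sqrt(2) ~= 0.707; each axis is further scaled
+        # by (N - 1) / N = 0.5, so corners land at about (+-0.354, +-0.354).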
+ """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid diff --git a/lingbot_map/layers/__init__.py b/lingbot_map/layers/__init__.py new file mode 100644 index 0000000..75bac30 --- /dev/null +++ b/lingbot_map/layers/__init__.py @@ -0,0 +1,5 @@ +from lingbot_map.layers.mlp import Mlp +from lingbot_map.layers.patch_embed import PatchEmbed +from lingbot_map.layers.block import Block +from lingbot_map.layers.swiglu_ffn import SwiGLUFFN as SwiGLUFFNFused +from lingbot_map.layers.attention import Attention as MemEffAttention diff --git a/lingbot_map/layers/attention.py b/lingbot_map/layers/attention.py new file mode 100644 index 0000000..4a4a0d3 --- /dev/null +++ b/lingbot_map/layers/attention.py @@ -0,0 +1,766 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import math +import warnings +import torch + +from torch import Tensor +from torch import nn +import torch.nn.functional as F + +from lingbot_map.layers.rope import apply_rotary_emb + +from einops import rearrange + +# FlashInfer imports (optional - for paged attention) +try: + import flashinfer + FLASHINFER_AVAILABLE = True +except ImportError: + FLASHINFER_AVAILABLE = False + print("flashinfer not available") + +try: + from torchtitan.distributed.sequence_parallel import ( + gather_seq_scatter_heads, + gather_heads_scatter_seq, + pad_tensor, + slice_input_tensor_scale_grad, + gather_outputs, + ) +except ImportError: + print("torchtitan not available for ulysses cp") + +def gather_seq_scatter_heads_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_dim: int, head_dim: int): + """Gather sequence dimension and scatter head dimension for Q, K, V tensors.""" + q = gather_seq_scatter_heads(q, seq_dim, head_dim) + k = gather_seq_scatter_heads(k, seq_dim, head_dim) + v = gather_seq_scatter_heads(v, seq_dim, head_dim) + return q, k, v + +from typing_extensions import List +from typing import Optional, Tuple + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads 
= num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if enable_ulysses_cp: + q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if self.fused_attn: + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + if enable_ulysses_cp: + x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1) + + x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CausalAttention(nn.Module): + """ + Causal self-attention module with KV cache support for streaming inference. + Used by CasualBlockCamera in camera_head.py. + """ + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + elementwise_attn_output_gate=False, + # KV cache eviction parameters (matching build_attn_mask) + kv_cache_sliding_window: int =64, + kv_cache_scale_frames: int = 8, + kv_cache_cross_frame_special: bool = True, + kv_cache_include_scale_frames: bool = True, + kv_cache_camera_only: bool = False, # If True, only cache camera token (no scale token) + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + self.gate_proj = nn.Linear(dim, dim, bias=True) if elementwise_attn_output_gate else None + + # Store KV cache eviction parameters + self.kv_cache_sliding_window = kv_cache_sliding_window + self.kv_cache_scale_frames = kv_cache_scale_frames + self.kv_cache_cross_frame_special = kv_cache_cross_frame_special + self.kv_cache_include_scale_frames = kv_cache_include_scale_frames + self.kv_cache_camera_only = kv_cache_camera_only + + def forward(self, x: Tensor, block_mask=None, pos=None, pos_kv=None, frame_seqlen=None, video_mask=None, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=1, num_frame_for_scale=-1, enable_3d_rope=False, sliding_window_size=-1, 
attend_to_scale_frames=False, num_random_frames=0, attend_to_special_tokens=False, num_register_tokens=4, enable_ulysses_cp=False, is_scale_frames=False) -> Tensor: + B, N, C = x.shape + + # Calculate special token indices + camera_token_idx = 0 + scale_token_idx = camera_token_idx + num_register_tokens + 1 # camera + register tokens + scale + + # [3, B, num_heads, N, head_dim] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + if self.gate_proj is not None: + gate_score = self.gate_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + if kv_cache is None: + q, k = self.q_norm(q), self.k_norm(k) + if enable_ulysses_cp: + q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1) + N = q.shape[2] # Update N after gather + if self.rope is not None and not enable_3d_rope: + q = self.rope(q, pos) + k = self.rope(k, pos) + elif enable_3d_rope and pos is not None: + q = apply_rotary_emb(q, pos) + k = apply_rotary_emb(k, pos) + + with torch.no_grad(): + block_mask = block_mask.squeeze()[:q.shape[2], :k.shape[2]] + if block_mask.dim() == 2: + block_mask = block_mask.unsqueeze(0).unsqueeze(0) # [1, 1, N, N] + block_mask = block_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1]) + + video_mask = video_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if video_mask is not None else torch.ones_like(block_mask, device=block_mask.device) # [1, 1, N, N] + video_mask = video_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1]) + + mask = block_mask | ~video_mask + + # Apply sliding window mask if sliding_window_size > 0 + # sliding_window_size is in units of num_frame_per_block + if sliding_window_size > 0 and frame_seqlen is not None: + # Create sliding window mask: each frame can only attend to frames within the window + num_frames = N // frame_seqlen + sliding_mask = torch.zeros_like(mask, dtype=torch.bool) + + for i in range(num_frames): + q_start = i * frame_seqlen + q_end = (i + 1) * frame_seqlen + # Calculate the window start: sliding_window_size is in units of num_frame_per_block + # So the actual window size in frames is sliding_window_size * num_frame_per_block + window_size_in_frames = sliding_window_size * num_frame_per_block + window_start_frame = max(0, i - window_size_in_frames + 1) + k_start = window_start_frame * frame_seqlen + k_end = (i + 1) * frame_seqlen # Can attend up to current frame (causal) + sliding_mask[:, :, q_start:q_end, k_start:k_end] = True + + # Combine with existing mask: both masks need to allow attention + mask = mask & sliding_mask + + # If attend_to_scale_frames is True, also allow attention to first num_frame_for_scale frames + if num_frame_for_scale > 0: + for i in range(num_frames): + q_start = i * frame_seqlen + q_end = (i + 1) * frame_seqlen + # Allow attending to first num_frame_for_scale frames (directly set to True, not depending on block_mask) + mask[:, :, q_start:q_end, :num_frame_for_scale * frame_seqlen] = True + + ## global attention for the first num_frame_for_scale frames + if num_frame_for_scale > 0: + mask[:, :, :num_frame_for_scale * frame_seqlen, :num_frame_for_scale * frame_seqlen] = True + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=mask + ) + else: + # Apply RoPE to current k before caching + q, k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None and not enable_3d_rope: + q = self.rope(q, pos) + k = self.rope(k, pos) + 
elif enable_3d_rope and pos is not None: + q = apply_rotary_emb(q, pos) + k = apply_rotary_emb(k, pos) + + # Check if we should skip appending to cache (non-keyframe in keyframe mode) + skip_append = kv_cache.get("_skip_append", False) + + k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) + v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) + + if not skip_append: + # KEYFRAME: store in cache (original behavior) + if kv_cache[f"k_{global_idx}"] is None: + kv_cache[f"k_{global_idx}"] = k_reshaped + kv_cache[f"v_{global_idx}"] = v_reshaped + else: + num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3] + k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) + v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) + kv_cache[f"k_{global_idx}"] = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2) + kv_cache[f"v_{global_idx}"] = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2) + + # Apply sliding window eviction BEFORE attention to match causal_3drope behavior + # This ensures current frame only attends to frames within the sliding window + self._apply_kv_cache_eviction_causal(kv_cache, global_idx, camera_token_idx, scale_token_idx) + + # Retrieve full k, v from cache (already RoPE-applied, already evicted) + k = kv_cache[f"k_{global_idx}"].clone() + v = kv_cache[f"v_{global_idx}"].clone() + else: + # NON-KEYFRAME: attend to [cached + current] without storing in cache + if kv_cache[f"k_{global_idx}"] is not None: + k = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2) + v = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2) + else: + k = k_reshaped + v = v_reshaped + a, b, c, d, e = k.shape + + k = k.reshape(a, b, c*d, e) + v = v.reshape(a, b, c*d, e) + + # Prepend special tokens (camera + scale) from evicted frames if they exist + if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None: + special_k = kv_cache[f"k_{global_idx}_special"] # [B, H, num_evicted_frames, 2, D] + special_v = kv_cache[f"v_{global_idx}_special"] + sa, sb, sc, sd, se = special_k.shape + special_k = special_k.reshape(sa, sb, sc * sd, se) # [B, H, num_evicted*2, D] + special_v = special_v.reshape(sa, sb, sc * sd, se) + + # Prepend special tokens (older frames first) + k = torch.cat([special_k, k], dim=2) + v = torch.cat([special_v, v], dim=2) + + # Note: k from cache is already RoPE-applied, no need to apply again + + if self.fused_attn: + # Use mask-based SDPA to ensure same kernel as batch mode + # The causal constraint is enforced by KV cache contents, not by mask + mask = torch.ones(B, 1, q.shape[2], k.shape[2], dtype=torch.bool, device=q.device) + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=mask, + ) + + if self.gate_proj is not None: + x = x * torch.sigmoid(gate_score) + if enable_ulysses_cp: + x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1) + # Use actual dimensions from attention output, not original input C + # x shape: [B, H, seq_len, head_dim] -> [B, seq_len, H*head_dim] + x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _apply_kv_cache_eviction_causal(self, kv_cache, global_idx, camera_token_idx, scale_token_idx): + """ + Apply sliding window eviction to KV cache 
BEFORE attention. + + This ensures current frame only attends to frames within the sliding window, + matching the behavior of causal_3drope's attention mask. + """ + sliding_window_frames = self.kv_cache_sliding_window + scale_frames = self.kv_cache_scale_frames + + if kv_cache[f"k_{global_idx}"].shape[3] > 1: + num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2] + + if num_cached_frames > sliding_window_frames + scale_frames: + evict_start = scale_frames + evict_end = num_cached_frames - sliding_window_frames + + if evict_end > evict_start: + evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :] + evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :] + + if self.kv_cache_cross_frame_special: + if self.kv_cache_camera_only: + # Only keep camera token + new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone() + new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone() + else: + # Keep ALL special tokens (camera + register + scale) to match attention_mask behavior + # Special tokens are in range [camera_token_idx, scale_token_idx+1) + new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone() + new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone() + + if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None: + kv_cache[f"k_{global_idx}_special"] = new_special_k + kv_cache[f"v_{global_idx}_special"] = new_special_v + else: + kv_cache[f"k_{global_idx}_special"] = torch.cat( + [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2) + kv_cache[f"v_{global_idx}_special"] = torch.cat( + [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2) + + if self.kv_cache_include_scale_frames: + kv_cache[f"k_{global_idx}"] = torch.cat([ + kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :], + kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :] + ], dim=2) + kv_cache[f"v_{global_idx}"] = torch.cat([ + kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :], + kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :] + ], dim=2) + else: + kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :] + kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :] + + +class FlashInferAttention(Attention): + """ + FlashInfer variant of the GCT attention layer. + Uses FlashInferKVCacheManager for paged KV cache storage and + FlashInfer attention kernels (BatchPrefillWithPagedKVCacheWrapper). + Supports the same optimized token layout and KV cache streaming inference. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, + rope=None, + # KV cache eviction parameters + kv_cache_sliding_window: int = 64, + kv_cache_scale_frames: int = 8, + kv_cache_cross_frame_special: bool = True, + kv_cache_include_scale_frames: bool = True, + kv_cache_camera_only: bool = False, + ) -> None: + if not FLASHINFER_AVAILABLE: + raise RuntimeError("FlashInfer is not available. 
Please install flashinfer.") + + super().__init__( + dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + # Store KV cache eviction parameters + self.kv_cache_sliding_window = kv_cache_sliding_window + self.kv_cache_scale_frames = kv_cache_scale_frames + self.kv_cache_cross_frame_special = kv_cache_cross_frame_special + self.kv_cache_include_scale_frames = kv_cache_include_scale_frames + self.kv_cache_camera_only = kv_cache_camera_only + + def prepare_qkv(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple: + """Fused pre-attention ops for single-frame streaming (Phase 2). + + Computes q/k/v from x, applies q_norm/k_norm/RoPE, and converts to + [tpf, H, D] format ready for append_frame + compute_attention. + + Extracted as a method so torch.compile can capture all pre-attn ops as one + CUDA graph (qkv linear -> reshape -> unbind -> q_norm -> k_norm -> RoPE x2 -> + squeeze/permute/contiguous x3). + """ + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim] + q, k = self.q_norm(q), self.k_norm(k) + if self.rope is not None and not enable_3d_rope: + q = self.rope(q, pos) + k = self.rope(k, pos) + elif self.rope is not None: # enable_3d_rope=True + q = apply_rotary_emb(q, pos) + k = apply_rotary_emb(k, pos) + # Convert to [tpf, H, D] format for FlashInfer (B=1 in streaming mode) + q_nhd = q.squeeze(0).permute(1, 0, 2).contiguous() + k_nhd = k.squeeze(0).permute(1, 0, 2).contiguous() + v_nhd = v.squeeze(0).permute(1, 0, 2).contiguous() + return q_nhd, k_nhd, v_nhd + + def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, + num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False, + # KV cache parameters (kv_cache is a FlashInferKVCacheManager or None) + kv_cache=None, global_idx=0, num_frame_per_block=1, + num_frame_for_scale=-1, num_register_tokens=4) -> Tensor: + """ + Forward pass with FlashInfer paged KV cache and attention. 
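+
+        Dispatch: when kv_cache is not a FlashInferKVCacheManager, the layer
+        runs plain SDPA (batch mode). In streaming mode, num_frames > 1 selects
+        the Phase 1 multi-frame scale pass (SDPA plus per-frame cache appends),
+        while a single frame goes through the Phase 2 FlashInfer paged-attention
+        path.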
+ + Args: + x: Input tensor [B, N, C] + kv_cache: FlashInferKVCacheManager instance or None (batch mode) + global_idx: Block index for per-block cache access + """ + from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager + + B, N, C = x.shape + + # Detect if using optimized layout + using_optimized_layout = (num_patches is not None and num_special is not None + and num_frames is not None) + + # ========== Batch Mode (no KV cache manager) ========== + if not isinstance(kv_cache, FlashInferKVCacheManager): + # [3, B, num_heads, N, head_dim] + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # Each: [B, num_heads, N, head_dim] + q, k = self.q_norm(q), self.k_norm(k) + + if enable_ulysses_cp: + if using_optimized_layout: + boundary = num_frames * num_patches + q_patch, k_patch, v_patch = q[:, :, :boundary, :], k[:, :, :boundary, :], v[:, :, :boundary, :] + q_special, k_special, v_special = q[:, :, boundary:, :], k[:, :, boundary:, :], v[:, :, boundary:, :] + q_patch, k_patch, v_patch = gather_seq_scatter_heads_qkv( + q_patch, k_patch, v_patch, seq_dim=2, head_dim=1 + ) + q_special, k_special, v_special = gather_seq_scatter_heads_qkv( + q_special, k_special, v_special, seq_dim=2, head_dim=1 + ) + q = torch.cat([q_patch, q_special], dim=2) + k = torch.cat([k_patch, k_special], dim=2) + v = torch.cat([v_patch, v_special], dim=2) + else: + q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1) + + if self.rope is not None and not enable_3d_rope: + q = self.rope(q, pos) + k = self.rope(k, pos) + elif self.rope is not None and enable_3d_rope: + q = apply_rotary_emb(q, pos) + k = apply_rotary_emb(k, pos) + + # Batch mode: use SDPA for numerical consistency with SDPA variant + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + + if enable_ulysses_cp: + if using_optimized_layout: + seq_global = x.shape[2] + seq_local = num_frames * (num_patches + num_special) + cp_size = seq_global // seq_local + boundary_global = num_frames * cp_size * num_patches + x_patch = x[:, :, :boundary_global, :] + x_special = x[:, :, boundary_global:, :] + x_patch = gather_heads_scatter_seq(x_patch, seq_dim=2, head_dim=1) + x_special = gather_heads_scatter_seq(x_special, seq_dim=2, head_dim=1) + x = torch.cat([x_patch, x_special], dim=2) + else: + x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1) + + x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim) + + # ========== Streaming Mode (with FlashInferKVCacheManager) ========== + else: + manager = kv_cache # FlashInferKVCacheManager + + # Phase 1 (scale frames): num_frames > 1 — multi-frame batch + # Phase 2 (streaming): num_frames == 1 — single frame + is_multi_frame = (num_frames is not None and num_frames > 1) + + if is_multi_frame: + # Phase 1: compute full self-attention via SDPA (all frames attend to each other), + # then append each frame's K/V to the paged cache one at a time. 
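+                # Shape walk-through (illustrative numbers): with B=1, eight
+                # scale frames and tokens_per_frame=262, q/k/v are each
+                # [1, H, 8*262, D]; after squeeze(0).permute(1, 0, 2), k_all and
+                # v_all are [8*262, H, D], and every append_frame call below
+                # writes one contiguous [262, H, D] slice into the paged cache.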
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + # Apply RoPE before caching (RoPE baked into K before append) + if self.rope is not None and not enable_3d_rope: + q = self.rope(q, pos) + k = self.rope(k, pos) + elif self.rope is not None and enable_3d_rope: + q = apply_rotary_emb(q, pos) + k = apply_rotary_emb(k, pos) + + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim) + + # Append each frame's K/V to the paged cache individually. + tpf = manager.tokens_per_frame + k_all = k.squeeze(0).permute(1, 0, 2) # [num_frames*tpf, H, D] + v_all = v.squeeze(0).permute(1, 0, 2) + for f_idx in range(num_frames): + s = f_idx * tpf + manager.append_frame(global_idx, k_all[s:s+tpf].contiguous(), v_all[s:s+tpf].contiguous()) + manager.evict_frames( + block_idx=global_idx, + scale_frames=self.kv_cache_scale_frames, + sliding_window=self.kv_cache_sliding_window, + cross_frame_special=self.kv_cache_cross_frame_special, + include_scale_frames=self.kv_cache_include_scale_frames, + camera_only=self.kv_cache_camera_only, + num_register_tokens=num_register_tokens, + ) + else: + # Phase 2: single-frame streaming via FlashInfer paged attention. + q_nhd, k_nhd, v_nhd = self.prepare_qkv(x, pos=pos, enable_3d_rope=enable_3d_rope) + + # 1. Append to paged cache + manager.append_frame(global_idx, k_nhd, v_nhd) + + # 2. Apply sliding window eviction + manager.evict_frames( + block_idx=global_idx, + scale_frames=self.kv_cache_scale_frames, + sliding_window=self.kv_cache_sliding_window, + cross_frame_special=self.kv_cache_cross_frame_special, + include_scale_frames=self.kv_cache_include_scale_frames, + camera_only=self.kv_cache_camera_only, + num_register_tokens=num_register_tokens, + ) + + # 3. Compute attention via FlashInfer BatchPrefillWithPagedKVCacheWrapper + x = manager.compute_attention(global_idx, q_nhd) + + # Convert back: [tpf, H, D] -> [B, tpf, C]. + x = x.reshape(B, q_nhd.shape[0], self.num_heads * self.head_dim) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SDPAAttention(Attention): + """ + SDPA variant for streaming inference. + Uses F.scaled_dot_product_attention with dict-based KV cache. + No FlashInfer dependency required — works on any CUDA GPU. 
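+
+    Minimal streaming sketch (illustrative; the cache is a plain dict keyed
+    per block index and initialised to None):
+
+        kv_cache = {"k_0": None, "v_0": None}
+        out = attn(x, pos=pos, kv_cache=kv_cache, global_idx=0,
+                   num_frame_per_block=1, num_register_tokens=4)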
+ """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, + rope=None, + kv_cache_sliding_window: int = 64, + kv_cache_scale_frames: int = 8, + kv_cache_cross_frame_special: bool = True, + kv_cache_include_scale_frames: bool = True, + kv_cache_camera_only: bool = False, + ) -> None: + super().__init__( + dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, + attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer, + qk_norm=qk_norm, fused_attn=fused_attn, rope=rope, + ) + self.kv_cache_sliding_window = kv_cache_sliding_window + self.kv_cache_scale_frames = kv_cache_scale_frames + self.kv_cache_cross_frame_special = kv_cache_cross_frame_special + self.kv_cache_include_scale_frames = kv_cache_include_scale_frames + self.kv_cache_camera_only = kv_cache_camera_only + + def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, + num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False, + kv_cache=None, global_idx=0, num_frame_per_block=1, + num_frame_for_scale=-1, num_register_tokens=4) -> Tensor: + B, N, C = x.shape + using_optimized_layout = (num_patches is not None and num_special is not None + and num_frames is not None) + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + # ========== Batch Mode (no KV cache) ========== + if kv_cache is None: + if self.rope is not None and not enable_3d_rope: + q = self.rope(q, pos) + k = self.rope(k, pos) + elif self.rope is not None and enable_3d_rope: + q = apply_rotary_emb(q, pos) + k = apply_rotary_emb(k, pos) + + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim) + + # ========== Streaming Mode (with KV cache dict) ========== + else: + if self.rope is not None and not enable_3d_rope: + q = self.rope(q, pos) + k = self.rope(k, pos) + elif self.rope is not None and enable_3d_rope: + q = apply_rotary_emb(q, pos) + k = apply_rotary_emb(k, pos) + + camera_token_idx = 0 + scale_token_idx = camera_token_idx + num_register_tokens + 1 + + if kv_cache[f"k_{global_idx}"] is None: + kv_cache[f"k_{global_idx}"] = k.view(B, self.num_heads, num_frame_per_block, + N // num_frame_per_block, self.head_dim) + kv_cache[f"v_{global_idx}"] = v.view(B, self.num_heads, num_frame_per_block, + N // num_frame_per_block, self.head_dim) + else: + num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3] + kv_cache[f"k_{global_idx}"] = torch.cat(( + kv_cache[f"k_{global_idx}"], + k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) + ), dim=2) + kv_cache[f"v_{global_idx}"] = torch.cat(( + kv_cache[f"v_{global_idx}"], + v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim) + ), dim=2) + + self._apply_kv_cache_eviction( + kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens + ) + + k_cached = kv_cache[f"k_{global_idx}"].clone() + v_cached = kv_cache[f"v_{global_idx}"].clone() + a, b, c, d, e = k_cached.shape + k_full = k_cached.reshape(a, b, c * d, e) + v_full = v_cached.reshape(a, b, c * d, e) + + if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None: + 
special_k = kv_cache[f"k_{global_idx}_special"] + special_v = kv_cache[f"v_{global_idx}_special"] + sa, sb, sc, sd, se = special_k.shape + k_full = torch.cat([special_k.reshape(sa, sb, sc * sd, se), k_full], dim=2) + v_full = torch.cat([special_v.reshape(sa, sb, sc * sd, se), v_full], dim=2) + + q_seq_len = q.shape[2] + x = F.scaled_dot_product_attention( + q, k_full, v_full, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + x = x.transpose(1, 2).reshape(B, q_seq_len, self.num_heads * self.head_dim) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _apply_kv_cache_eviction(self, kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens): + """Apply sliding window eviction to KV cache.""" + sliding_window_frames = self.kv_cache_sliding_window + scale_frames = self.kv_cache_scale_frames + + if kv_cache[f"k_{global_idx}"].shape[3] > 1: + num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2] + if num_cached_frames > sliding_window_frames + scale_frames: + evict_start = scale_frames + evict_end = num_cached_frames - sliding_window_frames + if evict_end > evict_start: + evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :] + evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :] + + if self.kv_cache_cross_frame_special: + if self.kv_cache_camera_only: + new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone() + new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone() + else: + new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone() + new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone() + + if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None: + kv_cache[f"k_{global_idx}_special"] = new_special_k + kv_cache[f"v_{global_idx}_special"] = new_special_v + else: + kv_cache[f"k_{global_idx}_special"] = torch.cat( + [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2) + kv_cache[f"v_{global_idx}_special"] = torch.cat( + [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2) + + if self.kv_cache_include_scale_frames: + kv_cache[f"k_{global_idx}"] = torch.cat([ + kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :], + kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :] + ], dim=2) + kv_cache[f"v_{global_idx}"] = torch.cat([ + kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :], + kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :] + ], dim=2) + else: + kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :] + kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :] diff --git a/lingbot_map/layers/block.py b/lingbot_map/layers/block.py new file mode 100644 index 0000000..0cd0b6b --- /dev/null +++ b/lingbot_map/layers/block.py @@ -0,0 +1,514 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings +import math + +import torch +from torch import nn, Tensor + +from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention +from functools import lru_cache, partial +from torch.nn.attention.flex_attention import BlockMask, create_mask +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + fused_attn=fused_attn, + rope=rope, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, + num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor: + def attn_residual_func(x: Tensor, pos=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp, + num_patches=num_patches, num_special=num_special, num_frames=num_frames, + enable_3d_rope=enable_3d_rope)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio + ) + x = drop_add_residual_stochastic_depth( + x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, 
device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +class FlashInferBlock(nn.Module): + """ + FlashInfer variant of causal block for GCT. + Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels). + Supports optimized token layout and KV cache streaming inference. + """ + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + rope=None, + kv_cache_sliding_window: int = 64, + kv_cache_scale_frames: int = 8, + kv_cache_cross_frame_special: bool = True, + kv_cache_include_scale_frames: bool = True, + kv_cache_camera_only: bool = False, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = FlashInferAttention( + dim=dim, + num_heads=num_heads, + qk_norm=qk_norm, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + rope=rope, + kv_cache_sliding_window=kv_cache_sliding_window, + kv_cache_scale_frames=kv_cache_scale_frames, + kv_cache_cross_frame_special=kv_cache_cross_frame_special, + kv_cache_include_scale_frames=kv_cache_include_scale_frames, + kv_cache_camera_only=kv_cache_camera_only, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple: + """Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit. 
+ + Extracted as a named method so torch.compile can capture norm1 + qkv-linear + + reshape + q_norm + k_norm + RoPE + format as a single CUDA graph. + + Returns: + (q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim], + ready for manager.append_frame + manager.compute_attention. + """ + return self.attn.prepare_qkv(self.norm1(x), pos=pos, enable_3d_rope=enable_3d_rope) + + def forward( + self, + x: Tensor, + pos=None, + enable_ulysses_cp=False, + num_patches=None, + num_special=None, + num_frames=None, + enable_3d_rope=False, + kv_cache=None, + global_idx=0, + num_frame_per_block=1, + num_frame_for_scale=-1, + num_register_tokens=4, + ) -> Tensor: + # Phase 2 (streaming): single-frame FlashInfer paged attention. + # Handle inline so attn_pre (norm1+prepare_qkv) can be compiled as one CUDA graph. + is_streaming = (kv_cache is not None and (num_frames is None or num_frames <= 1)) + if is_streaming: + manager = kv_cache + # Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format + q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope) + # Eager: write frame K/V to paged cache + manager.append_frame(global_idx, k_nhd, v_nhd) + # CPU-only: update eviction state (deque ops, no GPU kernel) + manager.evict_frames( + block_idx=global_idx, + scale_frames=self.attn.kv_cache_scale_frames, + sliding_window=self.attn.kv_cache_sliding_window, + cross_frame_special=self.attn.kv_cache_cross_frame_special, + include_scale_frames=self.attn.kv_cache_include_scale_frames, + camera_only=self.attn.kv_cache_camera_only, + num_register_tokens=num_register_tokens, + ) + # Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper + attn_x = manager.compute_attention(global_idx, q_nhd) + # [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output) + attn_x = attn_x.reshape(x.shape[0], q_nhd.shape[0], + self.attn.num_heads * self.attn.head_dim) + # Compiled: output projection + attn_x = self.attn.proj(attn_x) + x = x + self.ls1(attn_x) + else: + # Phase 1 (multi-frame scale pass) or non-streaming training path + x = x + self.ls1(self.attn( + self.norm1(x), + pos=pos, + enable_ulysses_cp=enable_ulysses_cp, + num_patches=num_patches, + num_special=num_special, + num_frames=num_frames, + enable_3d_rope=enable_3d_rope, + kv_cache=kv_cache, + global_idx=global_idx, + num_frame_per_block=num_frame_per_block, + num_frame_for_scale=num_frame_for_scale, + num_register_tokens=num_register_tokens, + )) + x = self.ffn_residual(x) + return x + + def ffn_residual(self, x: Tensor) -> Tensor: + """FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in. + + Includes the residual add (x + ...) so torch.compile captures the entire + ffn branch as one CUDA graph. 
+ """ + return x + self.ls2(self.mlp(self.norm2(x))) + + +class CameraBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + elementwise_attn_output_gate: bool = False, + sliding_window_size: int = -1, + attend_to_scale_frames: bool = False, + num_random_frames: int = 0, + # KV cache parameters + kv_cache_sliding_window: int = 64, + kv_cache_scale_frames: int = 8, + kv_cache_cross_frame_special: bool = True, + kv_cache_include_scale_frames: bool = True, + kv_cache_camera_only: bool = False, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + self.attn = CausalAttention(dim=dim, num_heads=num_heads, + qk_norm=qk_norm, qkv_bias=qkv_bias, + rope=rope, elementwise_attn_output_gate=elementwise_attn_output_gate, + kv_cache_sliding_window=kv_cache_sliding_window, + kv_cache_scale_frames=kv_cache_scale_frames, + kv_cache_cross_frame_special=kv_cache_cross_frame_special, + kv_cache_include_scale_frames=kv_cache_include_scale_frames, + kv_cache_camera_only=kv_cache_camera_only) + + self.sliding_window_size = sliding_window_size + self.attend_to_scale_frames = attend_to_scale_frames + self.num_random_frames = num_random_frames + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + self.masks = {} + + @torch.no_grad() + def _prepare_blockwise_causal_attn_mask(self, + device: torch.device | str, num_frames: int = 21, + frame_seqlen: int = 1560, num_frame_per_block=1 + ) -> BlockMask: + """ + we will divide the token sequence into the following format + [1 latent frame] [1 latent frame] ... 
[1 latent frame] + We use flexattention to construct the attention mask + """ + total_length = num_frames * frame_seqlen + + # we do right padding to get to a multiple of 128 + padded_length = math.ceil(total_length / 128) * 128 - total_length + + ends = torch.zeros(total_length + padded_length, + device=device, dtype=torch.long) + + # Block-wise causal mask will attend to all elements that are before the end of the current chunk + frame_indices = torch.arange( + start=0, + end=total_length, + step=frame_seqlen * num_frame_per_block, + device=device + ) + + for tmp in frame_indices: + ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \ + frame_seqlen * num_frame_per_block + + def attention_mask(b, h, q_idx, kv_idx): + return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) + # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask + + block_mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length, + KV_LEN=total_length + padded_length, device=device) + + return block_mask + + def forward(self, x: Tensor, pos=None, video_mask=None, num_frames=0, frame_seqlen=0, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=8, num_frame_for_scale=-1, sliding_window_size=None, enable_ulysses_cp=False, full_attention=False, enable_3d_rope=False, is_scale_frames=False) -> Tensor: + # Use passed sliding_window_size if provided, otherwise use self.sliding_window_size + effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size + + # Fast path for full attention (camera head) - skip mask computation + if full_attention: + def attn_residual_func(x: Tensor, pos=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos, full_attention=True, enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) + else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x + + mask_block = self._prepare_blockwise_causal_attn_mask( + device=x.device, num_frames=num_frames, frame_seqlen=frame_seqlen, num_frame_per_block=num_frame_per_block) + + + def attn_residual_func(x: Tensor, pos=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos, block_mask=mask_block, frame_seqlen=frame_seqlen, video_mask=video_mask, current_start=current_start, current_end=current_end, kv_cache=kv_cache, global_idx=global_idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size, attend_to_scale_frames=self.attend_to_scale_frames, num_random_frames=self.num_random_frames, + enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope, is_scale_frames=is_scale_frames)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x + + +class SDPABlock(nn.Module): + """ + SDPA variant for streaming inference. Uses F.scaled_dot_product_attention + with dict-based KV cache. No FlashInfer dependency required. 
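+
+    Constructor arguments and the forward signature mirror FlashInferBlock, so
+    the two block types are drop-in alternatives when assembling the backbone;
+    only the kv_cache argument differs (a plain per-block dict here versus a
+    FlashInferKVCacheManager there).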
+ """ + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + rope=None, + kv_cache_sliding_window: int = 64, + kv_cache_scale_frames: int = 8, + kv_cache_cross_frame_special: bool = True, + kv_cache_include_scale_frames: bool = True, + kv_cache_camera_only: bool = False, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = SDPAAttention( + dim=dim, num_heads=num_heads, qk_norm=qk_norm, qkv_bias=qkv_bias, + proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, rope=rope, + kv_cache_sliding_window=kv_cache_sliding_window, + kv_cache_scale_frames=kv_cache_scale_frames, + kv_cache_cross_frame_special=kv_cache_cross_frame_special, + kv_cache_include_scale_frames=kv_cache_include_scale_frames, + kv_cache_camera_only=kv_cache_camera_only, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = ffn_layer(in_features=dim, hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, drop=drop, bias=ffn_bias) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, + num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False, + kv_cache=None, global_idx=0, num_frame_per_block=1, + num_frame_for_scale=-1, num_register_tokens=4) -> Tensor: + def attn_residual_func(x, pos=None): + return self.ls1(self.attn( + self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp, + num_patches=num_patches, num_special=num_special, num_frames=num_frames, + enable_3d_rope=enable_3d_rope, kv_cache=kv_cache, global_idx=global_idx, + num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, + num_register_tokens=num_register_tokens, + )) + + def ffn_residual_func(x): + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos)) + x = x + self.drop_path1(ffn_residual_func(x)) + else: + x = x + attn_residual_func(x, pos=pos) + x = x + ffn_residual_func(x) + return x diff --git a/lingbot_map/layers/drop_path.py b/lingbot_map/layers/drop_path.py new file mode 100644 index 0000000..1d640e0 --- /dev/null +++ b/lingbot_map/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/lingbot_map/layers/flashinfer_cache.py b/lingbot_map/layers/flashinfer_cache.py new file mode 100644 index 0000000..1660f97 --- /dev/null +++ b/lingbot_map/layers/flashinfer_cache.py @@ -0,0 +1,582 @@ +""" +FlashInfer KV Cache Manager — Two-Stream Paged Design. + +Two logical streams sharing one physical page pool per layer: + + Patch stream (recyclable): + - page_size = patches_per_frame (256 for 224×224; 972 for 504×378) + - Exactly 1 patch page per frame + - Scale frames → scale_patch_pages (never evicted, maxlen=scale_frames) + - Recent frames → live_window_patch_pages (evicted when > sliding_window) + + Special stream (append-only, never recycled): + - num_special_tokens (6) special tokens per frame + - Packed continuously: one special page holds floor(page_size/6) frames + e.g. page_size=256 → 42 frames per special page, 4 slots wasted + - Specials written for EVERY frame (including scale + window), not just evicted ones. + +Physical layout per block: + kv_caches[block_idx]: [max_num_pages, 2, page_size, H, D] + Pages 0 .. max_patch_pages-1 : patch page pool (recyclable) + Pages max_patch_pages .. max_pages-1: special page pool (append-only) + dim 1: 0=K 1=V + +Attention computation: + visible = scale_patch_pages + live_window_patch_pages + all_special_pages + Special pages placed LAST → paged_kv_last_page_len naturally describes + the partial special-tail without a custom mask. + + plan() is called ONCE per frame step (when block_idx == 0). + run() is called per layer, reusing the same plan. All layers at the + same frame step have identical page structures (same page IDs in same + positions), so reusing the plan across layers is correct. + +Public API is drop-in compatible with the previous FlashInferKVCacheManager: + append_frame(block_idx, k, v) + evict_frames(block_idx, scale_frames, sliding_window, ...) + compute_attention(block_idx, q) -> out + reset() +""" + +import collections +import math +from typing import List + +import torch +from torch import Tensor + +try: + import flashinfer + FLASHINFER_AVAILABLE = True +except ImportError: + FLASHINFER_AVAILABLE = False + + +class FlashInferKVCacheManager: + """ + Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only). + + Args: + num_blocks: Number of Transformer blocks (one cache per block). + max_num_frames: Maximum frames held in the KV window at once + (scale_frames + sliding_window + headroom). + tokens_per_frame: Total tokens per frame = patches + specials (e.g. 262). + num_heads: Number of KV heads (= QO heads; MHA assumed). + head_dim: Head dimension (64 for ViT-L). 
+ dtype: Storage dtype (bfloat16 / float16). + device: CUDA device. + num_special_tokens: Special tokens per frame: camera + register×N + scale (6). + scale_frames: Number of always-resident scale frames (8). + sliding_window: Sliding window size (64). + max_total_frames: Upper bound on total frames ever processed; used to + pre-allocate the special page pool (default 2048). + """ + + def __init__( + self, + num_blocks: int, + max_num_frames: int, + tokens_per_frame: int, + num_heads: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + num_special_tokens: int = 6, + scale_frames: int = 8, + sliding_window: int = 64, + max_total_frames: int = 2048, + force_fp32: bool = False, + fa3: bool = False, + ): + if not FLASHINFER_AVAILABLE: + raise RuntimeError("FlashInfer is not available. Please install flashinfer.") + + self.num_blocks = num_blocks + self.num_special_tokens = num_special_tokens # 6 + self.patches_per_frame = tokens_per_frame - num_special_tokens # 256 / 999 / ... + # Use exact page_size = patches_per_frame to eliminate zero-padded slots. + # FA2 (backend="fa2") supports non-power-of-2 page sizes. + # FA3 (sm90) requires power-of-2 page sizes; use next_power_of_2 when fa3=True. + p = self.patches_per_frame + if fa3: + # Round up to next power-of-2 for FA3 SM90 kernel requirement. + # e.g. 999 → 1024 (25 zero-padded slots per patch page) + self.page_size = 1 << (p - 1).bit_length() + else: + self.page_size = p # exact: no zero padding in patch pages + self.scale_frames = scale_frames # 8 + self.sliding_window = sliding_window # 64 + self.num_heads = num_heads + self.head_dim = head_dim + self.tokens_per_frame = tokens_per_frame + + assert self.patches_per_frame > 0, ( + f"tokens_per_frame={tokens_per_frame} <= num_special_tokens={num_special_tokens}" + ) + assert self.page_size > 0 + + # force_fp32: bypass FlashInfer FA2 kernel (which only supports fp16/bf16) and + # instead gather paged K/V into a dense tensor and use F.scaled_dot_product_attention + # in fp32 for accuracy comparison. Storage dtype is also kept as fp32 in this mode. + self.force_fp32 = force_fp32 + if force_fp32: + self.dtype = torch.float32 + else: + if dtype == torch.float32: + dtype = torch.bfloat16 + self.dtype = dtype + self.device = device + + # ── Page pool sizing ───────────────────────────────────────────────── + # Patch: scale + window + 16 headroom (pages recycled → fixed count) + max_patch_pages = scale_frames + sliding_window + 16 # e.g. 88 + # Special: enough for max_total_frames × 6 tokens, plus 16 headroom + max_special_pages = ( + math.ceil(max_total_frames * num_special_tokens / self.page_size) + 16 + ) + self.max_patch_pages = max_patch_pages + self.max_num_pages = max_patch_pages + max_special_pages + + # ── Physical paged KV caches ───────────────────────────────────────── + # Shape per block: [max_num_pages, 2, page_size, H, D] (NHD, K=dim0, V=dim1) + self.kv_caches: List[Tensor] = [ + torch.zeros( + self.max_num_pages, 2, self.page_size, num_heads, head_dim, + dtype=dtype, device=device, + ) + for _ in range(num_blocks) + ] + + # ── Per-block state ────────────────────────────────────────────────── + # Patch pages (IDs 0 .. 
max_patch_pages-1) + self.scale_patch_pages: List[collections.deque] = [ + collections.deque() for _ in range(num_blocks) + ] + self.live_window_patch_pages: List[collections.deque] = [ + collections.deque() for _ in range(num_blocks) + ] + self.free_patch_pages: List[List[int]] = [ + list(range(max_patch_pages)) for _ in range(num_blocks) + ] + + # Special pages (IDs max_patch_pages .. max_num_pages-1) + self.all_special_pages: List[List[int]] = [[] for _ in range(num_blocks)] + self.free_special_pages: List[List[int]] = [ + list(range(max_patch_pages, self.max_num_pages)) for _ in range(num_blocks) + ] + self.special_token_count: List[int] = [0] * num_blocks + + # Frame counter per block (determines scale vs window routing) + self.frame_count: List[int] = [0] * num_blocks + + # ── FlashInfer wrapper ─────────────────────────────────────────────── + # plan() is called once per frame step (block_idx == 0). + # run() is called per layer, reusing the same aux structures. + # backend: "fa2" (default) or "fa3" (SM90/H100, requires power-of-2 page_size). + # FA2 supports non-power-of-2 page sizes and avoids a FA3 NaN bug seen in + # FlashInfer 0.2.5 at 518×378 resolution. + _fi_backend = "fa3" if fa3 else "fa2" + self.workspace_buffer = torch.zeros( + 128 * 1024 * 1024, dtype=torch.uint8, device=device + ) + self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + kv_layout="NHD", + backend=_fi_backend, + ) + + # plan() inputs (indices/indptr built fresh each step; qo_indptr is fixed) + self._qo_indptr = torch.tensor( + [0, tokens_per_frame], dtype=torch.int32, device=device + ) + + # ========================================================================= + # Public API (drop-in compatible with previous FlashInferKVCacheManager) + # ========================================================================= + + def append_frame(self, block_idx: int, k: Tensor, v: Tensor) -> None: + """ + Append one frame's K/V tensors to the two-stream cache. + + Token layout must be: [camera, reg0, ..., regN, scale, patch0, ..., patchP-1] + i.e. specials come first (matching stream.py's patch_start_idx convention). + + Args: + block_idx: Block/layer index (0 … num_blocks-1). + k: [tokens_per_frame, H, D] NHD layout. + v: [tokens_per_frame, H, D] NHD layout. + """ + n = self.num_special_tokens # 6 + sp_k = k[:n].to(self.dtype) # [6, H, D] + patch_k = k[n:].to(self.dtype) # [256, H, D] + sp_v = v[:n].to(self.dtype) + patch_v = v[n:].to(self.dtype) + + assert patch_k.shape[0] == self.patches_per_frame, ( + f"block {block_idx}: expected {self.patches_per_frame} patch tokens, " + f"got {patch_k.shape[0]} (tokens_per_frame={k.shape[0]})" + ) + + self._write_patch_page(block_idx, patch_k, patch_v) + self._write_special_tokens(block_idx, sp_k, sp_v) + self.frame_count[block_idx] += 1 + + def evict_frames( + self, + block_idx: int, + scale_frames: int, + sliding_window: int, + cross_frame_special: bool = True, + include_scale_frames: bool = True, + camera_only: bool = False, + num_register_tokens: int = 4, + ) -> None: + """ + Evict old window patch pages (recycle to free list). + + Special pages are NEVER evicted. + Scale pages are NEVER evicted. + Only live_window_patch_pages beyond `sliding_window` are recycled. 
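+
+        Worked example (illustrative): with sliding_window=64, appending the
+        65th window frame leaves 65 entries in live_window_patch_pages, so the
+        oldest patch page is popped and returned to free_patch_pages; that
+        frame's special tokens remain visible through the append-only special
+        pages.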
+ """ + while len(self.live_window_patch_pages[block_idx]) > sliding_window: + old_page = self.live_window_patch_pages[block_idx].popleft() + self.free_patch_pages[block_idx].append(old_page) + + def _gather_kv(self, block_idx: int): + """ + Gather all visible K and V tokens from the paged cache into dense tensors. + + Used by force_fp32 mode to bypass the FlashInfer FA2 kernel (which only + supports fp16/bf16) and instead run F.scaled_dot_product_attention in fp32. + + Returns: + k_flat: [kv_len, H, D] — all visible K tokens concatenated + v_flat: [kv_len, H, D] — all visible V tokens concatenated + """ + visible = self.build_visible_page_table(block_idx) + last_len = self.compute_last_page_len(block_idx) + P = self.page_size + + parts_k, parts_v = [], [] + for i, pid in enumerate(visible): + n = last_len if (i == len(visible) - 1) else P + parts_k.append(self.kv_caches[block_idx][pid, 0, :n]) # [n, H, D] + parts_v.append(self.kv_caches[block_idx][pid, 1, :n]) + + k_flat = torch.cat(parts_k, dim=0) # [kv_len, H, D] + v_flat = torch.cat(parts_v, dim=0) + return k_flat, v_flat + + def compute_attention(self, block_idx: int, q: Tensor) -> Tensor: + """ + Compute cross-frame attention using FlashInfer BatchPrefillWithPagedKVCacheWrapper. + + When self.force_fp32 is True, gathers all visible K/V into dense tensors + and uses F.scaled_dot_product_attention in fp32 instead of the FA2 kernel. + This is used for accuracy comparison since FlashInfer FA2 only supports fp16/bf16. + + plan() is called once per frame step (when block_idx == 0). + All layers at the same step share the same visible page structure, + so the plan is reused by calling run() with each layer's kv_cache. + + Args: + block_idx: Block/layer index. + q: [q_len, H, D] NHD layout (q_len = tokens_per_frame = 262). + + Returns: + out: [q_len, H, D] + """ + if self.frame_count[block_idx] == 0: + # No KV present yet (should not occur in normal usage after append_frame) + return torch.zeros_like(q) + + if self.force_fp32: + # ── fp32 gather+SDPA path ───────────────────────────────────────── + # Gather visible K/V from paged cache and run SDPA in fp32. + # This bypasses the FlashInfer FA2 kernel (fp16/bf16 only) for accuracy. + # q_len, H, D → 1, H, q_len, D (SDPA expects BHsD layout) + import torch.nn.functional as F_nn + k_flat, v_flat = self._gather_kv(block_idx) + q_b = q.float().permute(1, 0, 2).unsqueeze(0) # [1, H, q_len, D] + k_b = k_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D] + v_b = v_flat.float().permute(1, 0, 2).unsqueeze(0) # [1, H, kv_len, D] + out = F_nn.scaled_dot_product_attention(q_b, k_b, v_b) + return out.squeeze(0).permute(1, 0, 2).to(q.dtype) # [q_len, H, D] + + if block_idx == 0: + # ── Plan once per frame step ────────────────────────────────────── + # Build visible page table from block 0's state. + # All blocks have identical page structures, so this plan is valid + # for all subsequent run() calls (block_idx = 1, 2, ...). 
+ visible = self.build_visible_page_table(0) + last_len = self.compute_last_page_len(0) + + assert visible, "visible page table is empty after append_frame" + assert 1 <= last_len <= self.page_size, ( + f"block 0: last_page_len={last_len} out of [1, {self.page_size}]" + ) + + paged_kv_indices = torch.tensor(visible, dtype=torch.int32, device=self.device) + paged_kv_indptr = torch.tensor([0, len(visible)], dtype=torch.int32, device=self.device) + paged_kv_last_page_len = torch.tensor([last_len], dtype=torch.int32, device=self.device) + + self.prefill_wrapper.plan( + self._qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + num_qo_heads = self.num_heads, + num_kv_heads = self.num_heads, + head_dim_qk = self.head_dim, + page_size = self.page_size, + causal = False, # custom page ordering; no causal mask + pos_encoding_mode = "NONE", # RoPE applied externally before append + q_data_type = self.dtype, + ) + + # ── Run attention for this layer ────────────────────────────────────── + # Cast q to storage dtype (LayerNorm may upcast to float32 under autocast). + return self.prefill_wrapper.run( + q = q.to(self.dtype).contiguous(), + paged_kv_cache = self.kv_caches[block_idx], + ) # → [q_len, H, D] + + def reset(self) -> None: + """Reset all per-block state for a new sequence.""" + for i in range(self.num_blocks): + self.scale_patch_pages[i].clear() + self.live_window_patch_pages[i].clear() + self.all_special_pages[i].clear() + self.free_patch_pages[i] = list(range(self.max_patch_pages)) + self.free_special_pages[i] = list(range(self.max_patch_pages, self.max_num_pages)) + self.special_token_count[i] = 0 + self.frame_count[i] = 0 + + # ========================================================================= + # Helper methods + # ========================================================================= + + def build_visible_page_table(self, block_idx: int) -> List[int]: + """ + Return page IDs in strict order: scale → window → special. + + Placing special pages last means only the final page may be partially + full, so paged_kv_last_page_len = compute_last_page_len() is sufficient + without a custom attention mask. + """ + return ( + list(self.scale_patch_pages[block_idx]) + + list(self.live_window_patch_pages[block_idx]) + + list(self.all_special_pages[block_idx]) + ) + + def compute_last_page_len(self, block_idx: int) -> int: + """ + Valid token count in the last page of the visible sequence. + + - No special pages → last page is a patch page. + Returns patches_per_frame (real tokens written), + which may be < page_size when page_size was rounded + up to a power of 2. + - Special tail partial → special_token_count % page_size. + - Special tail exactly full → page_size. + """ + if not self.all_special_pages[block_idx]: + # Last page is a patch page. We wrote patches_per_frame tokens (0..P-1); + # positions P..page_size-1 are zero padding. Tell FlashInfer the true + # valid count so it doesn't read beyond the real tokens. + return self.patches_per_frame + + tail = self.special_token_count[block_idx] % self.page_size + return self.page_size if tail == 0 else tail + + # ── Internal write helpers ──────────────────────────────────────────────── + + def _write_patch_page(self, block_idx: int, patch_k: Tensor, patch_v: Tensor) -> int: + """ + Allocate one free patch page and write patches_per_frame patch tokens. + + Direct tensor assignment to kv_caches[block_idx][page_id, 0/1] avoids + the Python→C++/CUDA dispatch overhead of flashinfer.page.append_paged_kv_cache. 
+ kv_caches layout: [max_num_pages, 2, page_size, H, D] (NHD, K=0, V=1). + patch_k/v fill exactly one full page (patches_per_frame == page_size). + + Routes to scale_patch_pages if still filling scale quota, + otherwise to live_window_patch_pages. + + Returns: + page_id: Physical page index used. + """ + assert self.free_patch_pages[block_idx], ( + f"block {block_idx}: patch page pool exhausted — " + f"scale={len(self.scale_patch_pages[block_idx])}, " + f"window={len(self.live_window_patch_pages[block_idx])}, " + f"free={len(self.free_patch_pages[block_idx])}" + ) + + page_id = self.free_patch_pages[block_idx].pop() + + # Direct slice write: positions 0..patches_per_frame-1. + # When page_size == patches_per_frame (power-of-2 aligned, e.g. 256 for 224×224), + # this is equivalent to a full-page write. When page_size > patches_per_frame + # (rounded up for FA3 alignment, e.g. page_size=1024 for patches_per_frame=999), + # positions patches_per_frame..page_size-1 remain zero (kv_caches is zero-init). + P = self.patches_per_frame + self.kv_caches[block_idx][page_id, 0, :P] = patch_k # K + self.kv_caches[block_idx][page_id, 1, :P] = patch_v # V + + if len(self.scale_patch_pages[block_idx]) < self.scale_frames: + self.scale_patch_pages[block_idx].append(page_id) + else: + self.live_window_patch_pages[block_idx].append(page_id) + + return page_id + + def _write_special_tokens(self, block_idx: int, sp_k: Tensor, sp_v: Tensor) -> None: + """ + Append num_special_tokens (6) special tokens to the special stream. + + Direct tensor slice assignment to kv_caches[block_idx][tail_page, 0/1, + tail_offset : tail_offset+write_n] avoids the Python→C++/CUDA dispatch + overhead of flashinfer.page.append_paged_kv_cache. + + Handles page-boundary crossing: if 6 tokens straddle two pages, performs + two slice writes (rare — page_size=256 >> 6). + """ + remaining = self.num_special_tokens # 6 + written = 0 + + while remaining > 0: + tail_offset = self.special_token_count[block_idx] % self.page_size + + if tail_offset == 0: + # Current tail page is full (or no page exists) — allocate a new one + assert self.free_special_pages[block_idx], ( + f"block {block_idx}: special page pool exhausted at " + f"special_token_count={self.special_token_count[block_idx]}. " + f"Increase max_total_frames." + ) + new_page = self.free_special_pages[block_idx].pop() + self.all_special_pages[block_idx].append(new_page) + + tail_page = self.all_special_pages[block_idx][-1] + space = self.page_size - tail_offset # free slots in tail page + write_n = min(remaining, space) + + # Direct slice write: kv_caches[block_idx][tail_page, 0/1, offset:offset+n] + # shape: [page_size, H, D]; slice [tail_offset:tail_offset+write_n, :, :] + end = tail_offset + write_n + self.kv_caches[block_idx][tail_page, 0, tail_offset:end] = sp_k[written:written + write_n] + self.kv_caches[block_idx][tail_page, 1, tail_offset:end] = sp_v[written:written + write_n] + + self.special_token_count[block_idx] += write_n + written += write_n + remaining -= write_n + + # ── Legacy property (used by stream.py) ────────────────────────────────── + + @property + def num_frames(self) -> int: + """Number of frames appended to block 0 (representative).""" + return self.frame_count[0] if self.frame_count else 0 + + +# ============================================================================= +# Sanity check +# ============================================================================= + +def _sanity_check(): + """ + Minimal smoke test. 
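+    Appends 100 frames to 2 blocks, checks the page-pool bookkeeping (scale /
+    window / special counts and last_page_len), and exercises the plan()/run()
+    attention path for both blocks.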
+ Run with: python -c "from lingbot_map.layers.flashinfer_cache import _sanity_check; _sanity_check()" + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if not torch.cuda.is_available(): + print("[sanity_check] CUDA not available — skipping.") + return + + tokens_per_frame = 262 # 256 patch + 6 special (224×224) + num_special = 6 + patches_per_frame = tokens_per_frame - num_special # 256 + page_size = patches_per_frame # 256 + + mgr = FlashInferKVCacheManager( + num_blocks = 2, + max_num_frames = 88, + tokens_per_frame = tokens_per_frame, + num_heads = 16, + head_dim = 64, + dtype = torch.bfloat16, + device = device, + num_special_tokens = num_special, + scale_frames = 8, + sliding_window = 64, + max_total_frames = 200, + ) + + def make_kv(): + k = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device) + v = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device) + return k, v + + def make_q(): + return torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device) + + for block in range(2): + for t in range(100): + k, v = make_kv() + mgr.append_frame(block, k, v) + mgr.evict_frames(block, scale_frames=8, sliding_window=64) + + # ── Page count checks ─────────────────────────────────────────────── + n_scale = len(mgr.scale_patch_pages[block]) + n_window = len(mgr.live_window_patch_pages[block]) + n_spec = len(mgr.all_special_pages[block]) + sp_count = mgr.special_token_count[block] + + assert n_scale == 8, f"block {block}: scale pages = {n_scale}, expected 8" + assert n_window == 64, f"block {block}: window pages = {n_window}, expected 64" + # 100 frames × 6 specials = 600 tokens; ceil(600/256) = 3 pages + expected_spec_pages = math.ceil(100 * num_special / page_size) + assert n_spec == expected_spec_pages, ( + f"block {block}: special pages = {n_spec}, expected {expected_spec_pages}" + ) + assert sp_count == 100 * num_special, ( + f"block {block}: special_token_count = {sp_count}, expected {100*num_special}" + ) + + # ── last_page_len ──────────────────────────────────────────────────── + last_len = mgr.compute_last_page_len(block) + tail = sp_count % page_size + expected_len = page_size if tail == 0 else tail + assert last_len == expected_len, f"block {block}: last_len={last_len}, expected={expected_len}" + + # ── visible page table order ───────────────────────────────────────── + visible = mgr.build_visible_page_table(block) + assert len(visible) == n_scale + n_window + n_spec, "visible page count mismatch" + for pid in visible[:n_scale + n_window]: + assert pid < mgr.max_patch_pages, f"patch page {pid} out of patch range" + for pid in visible[n_scale + n_window:]: + assert pid >= mgr.max_patch_pages, f"special page {pid} not in special range" + + # ── forward pass: plan() once for block 0, run() for both blocks ───── + if block == 1: + # Simulate the actual calling pattern: plan on block 0, run on both + q0 = make_q() + out0 = mgr.compute_attention(0, q0) # triggers plan() + q1 = make_q() + out1 = mgr.compute_attention(1, q1) # reuses plan, different kv_cache + assert out0.shape == (tokens_per_frame, 16, 64) + assert out1.shape == (tokens_per_frame, 16, 64) + + print(f"[block {block}] PASS: scale={n_scale}, window={n_window}, " + f"special_pages={n_spec}, special_tokens={sp_count}, " + f"last_page_len={last_len}") + + mgr.reset() + assert mgr.frame_count[0] == 0 + print("\n[sanity_check] All assertions passed.") + + +if __name__ == "__main__": + _sanity_check() diff --git 
a/lingbot_map/layers/layer_scale.py b/lingbot_map/layers/layer_scale.py new file mode 100644 index 0000000..4ddfc51 --- /dev/null +++ b/lingbot_map/layers/layer_scale.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/lingbot_map/layers/mlp.py b/lingbot_map/layers/mlp.py new file mode 100644 index 0000000..bbf9432 --- /dev/null +++ b/lingbot_map/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/lingbot_map/layers/patch_embed.py b/lingbot_map/layers/patch_embed.py new file mode 100644 index 0000000..bc19605 --- /dev/null +++ b/lingbot_map/layers/patch_embed.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
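+
+    Example (illustrative sketch with the default settings; assumes `torch`
+    itself is imported in the calling scope):
+
+        >>> embed = PatchEmbed(img_size=224, patch_size=16, embed_dim=768)
+        >>> tokens = embed(torch.rand(2, 3, 224, 224))
+        >>> tuple(tokens.shape)   # (2, 196, 768): 14 x 14 patches, flattened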
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/lingbot_map/layers/rope.py b/lingbot_map/layers/rope.py new file mode 100644 index 0000000..7f44e31 --- /dev/null +++ b/lingbot_map/layers/rope.py @@ -0,0 +1,474 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Tuple + +from typing import List, Optional, Tuple, Union + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. + + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. 
+ """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. 
+ """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) + horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) + + + +def get_1d_rotary_pos_embed( + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, + freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) +): + """ + 计算1D旋转位置编码(RoPE)的频率张量。 + + RoPE的核心思想:使用旋转矩阵来编码位置信息,使得相对位置关系保持不变。 + 公式:对于位置m和维度i,频率为 θ_i = θ^(-2i/d),其中θ是基础频率(默认10000) + + Args: + dim: 特征维度,必须是偶数(因为要成对处理) + pos: 位置索引,可以是整数(自动生成0到pos-1的序列)或位置数组 [S] + theta: 基础频率,控制位置编码的周期性(默认10000) + use_real: 是否返回实数形式(cos和sin分开)还是复数形式 + linear_factor: 线性缩放因子,用于上下文扩展 + ntk_factor: NTK-Aware缩放因子,用于处理更长的序列 + repeat_interleave_real: 当use_real=True时,是否交错重复(用于某些模型架构) + freqs_dtype: 频率张量的数据类型 + + Returns: + 复数形式:[S, D/2] 的复数张量,表示 e^(i*m*θ_j) + 实数形式:两个 [S, D] 的张量(cos和sin) + """ + # 确保维度是偶数(RoPE需要成对处理维度) + assert dim % 2 == 0 + + # 将位置转换为torch张量 + if isinstance(pos, int): + pos = torch.arange(pos) # 生成 [0, 1, 2, ..., pos-1] + if isinstance(pos, np.ndarray): + pos = torch.from_numpy(pos) # [S] + + # 应用NTK缩放(Neural Tangent Kernel,用于处理训练时未见过的长序列) + theta = theta * ntk_factor + + # 步骤1:计算频率 θ_i = 1 / (θ^(2i/d)) + # 其中 i ∈ {0, 2, 4, ..., dim-2}(只取偶数索引,因为成对处理) + # 公式:freq_i = 1 / (theta^(2i/d) * linear_factor) + freqs = ( + 1.0 + / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim)) + / linear_factor + ) # [D/2],每个频率对应一个维度对 + + # 步骤2:计算位置-频率矩阵 + # 使用外积:pos[m] * freqs[i] = m * θ_i + # 结果:每个位置m和每个频率i的组合 + freqs = torch.outer(pos, freqs) # [S, D/2] + + # 步骤3:根据返回格式转换 + if use_real and repeat_interleave_real: + # 方式1:交错重复(用于flux, hunyuan-dit, cogvideox等模型) + # 将每个频率的cos和sin交错排列:[cos_0, cos_0, cos_1, cos_1, ...] 
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
+        return freqs_cos, freqs_sin
+    elif use_real:
+        # Variant 2: concatenated repeat (used by stable audio, allegro, ...)
+        # Concatenate all cos values, then all sin values:
+        # [cos_0, cos_1, ..., cos_n, cos_0, cos_1, ..., cos_n]
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        # Variant 3: complex form (used by lumina, ...)
+        # Euler's formula: e^(i*theta) = cos(theta) + i*sin(theta)
+        # torch.polar(r, theta) returns r * e^(i*theta); with r=1 this is e^(i*freqs)
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64: [S, D/2]
+        return freqs_cis
+
+
+class WanRotaryPosEmbed(nn.Module):
+    """
+    3D rotary position embedding (3D RoPE) module.
+
+    Core idea: extend RoPE to 3D (time, height, width) to provide position
+    encodings for video or 3D data. Each axis (t, h, w) gets its own 1D RoPE,
+    and the three encodings are concatenated.
+
+    For a 3D position (f, h, w) (frame, height, width):
+        - the frame axis uses dim_f feature dimensions
+        - the height axis uses dim_h feature dimensions
+        - the width axis uses dim_w feature dimensions
+    with dim_f + dim_h + dim_w = attention_head_dim.
+    """
+    def __init__(
+        self,
+        attention_head_dim: int,
+        patch_size: Tuple[int, int, int],
+        max_seq_len: int = 1024,
+        theta: float = 10000.0,
+        fhw_dim: Optional[Tuple[int, int, int]] = [20, 22, 22],
+    ):
+        super().__init__()
+
+        self.attention_head_dim = attention_head_dim  # total dimension per attention head
+        self.patch_size = patch_size  # patch size (patch_f, patch_h, patch_w)
+        self.max_seq_len = max_seq_len  # maximum sequence length (for precomputed frequencies)
+
+        # Step 1: split the head dimension across the three axes
+        if fhw_dim is not None:
+            # If an explicit split is given, use it
+            assert attention_head_dim == sum(
+                fhw_dim
+            ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
+            t_dim, h_dim, w_dim = fhw_dim
+        else:
+            # Otherwise split automatically: h and w each take roughly 1/3, t takes the rest
+            # e.g. attention_head_dim=64 gives h_dim=w_dim=20, t_dim=24
+            h_dim = w_dim = 2 * (attention_head_dim // 6)
+            t_dim = attention_head_dim - h_dim - w_dim
+
+        # Save the split so forward() can reuse it
+        self.fhw_dim = (t_dim, h_dim, w_dim)
+
+        # Step 2: precompute frequencies for each axis
+        # Time, height and width each get their own 1D RoPE frequencies
+        freqs = []
+        for dim in [t_dim, h_dim, w_dim]:
+            # Each axis calls the 1D RoPE independently,
+            # returning complex frequencies of shape [max_seq_len, dim//2]
+            freq = get_1d_rotary_pos_embed(
+                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
+            )
+            freqs.append(freq)
+        # Concatenate the three axes along the last dim: [max_seq_len, (t_dim + h_dim + w_dim)//2]
+        self.freqs = torch.cat(freqs, dim=1)
+
+    def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
+        """
+        Generate rotary position encodings for a 3D input (video frames + patches).
+
+        Args:
+        - ppf (int): Number of frames; used when f_end is None
+        - pph (int): Number of patches per frame along the height
+        - ppw (int): Number of patches per frame along the width
+        - patch_start_idx (int): Number of special tokens per frame (placed before the patches)
+        - device: Compute device (CPU/GPU)
+        - f_start (int): Start frame index (for causal mode), default 0
+        - f_end (Optional[int]): End frame index (for causal mode); if None, ppf is used as the frame count
+
+        Returns:
+        - freqs: complex frequency tensor [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim//2]
+
+        Token ordering:
+        [frame0_special_token_0, ..., frame0_special_token_N,
+         frame0_patch_0, ..., frame0_patch_M,
+         frame1_special_token_0, ..., frame1_special_token_N,
+         frame1_patch_0, ..., frame1_patch_M,
+         ...]
+
+        Modes:
+        - Non-causal mode: f_end=None, ppf is used as the frame count, starting from position 0
+        - Causal mode: f_end is not None, frames in [f_start, f_end) are used and ppf is recomputed
+        """
+
+        # Step 1: move the precomputed frequencies to the target device and split them per axis
+        self.freqs = self.freqs.to(device)
+        # Recover the actual dimension split
+        if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
+            t_dim, h_dim, w_dim = self.fhw_dim
+        else:
+            # Automatic split case
+            h_dim = w_dim = 2 * (self.attention_head_dim // 6)
+            t_dim = self.attention_head_dim - h_dim - w_dim
+
+        # Use the correct split sizes (half of each axis dimension, since the frequencies are complex)
+        freqs = self.freqs.split_with_sizes(
+            [
+                t_dim // 2,  # time axis
+                h_dim // 2,  # height axis
+                w_dim // 2,  # width axis
+            ],
+            dim=1,
+        )
+
+        # Causal mode: if f_end is given, recompute ppf and the frame range
+        if f_end is not None:
+            ppf = f_end - f_start
+            frame_slice = slice(f_start, f_end)
+        else:
+            # Non-causal mode: use ppf frames starting from 0
+            frame_slice = slice(0, ppf)
+
+        # Step 2: handle special tokens (if any)
+        ## For other tokens
+        if patch_start_idx > 0:
+            # 2.1 Position encodings for special tokens
+            # Special tokens sit on the diagonal (f, i, i), giving each one a unique position:
+            # camera: (f, 0, 0), register_0: (f, 1, 1), ..., scale: (f, 5, 5)
+            # Shape: (ppf, patch_start_idx, dim)
+            freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim_f) varies with frame
+            freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim_h) height = 0, 1, 2, ...
+            freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim_w) width = 0, 1, 2, ...
+            freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1)  # (ppf, patch_start_idx, dim) concat the three axes
+            freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim)
+
+            # 2.2 Position encodings for image patches
+            # Patches sit at (f, patch_start_idx + h, patch_start_idx + w): h and w are shifted by patch_start_idx
+            # so patches never collide with special-token positions, and h and w are handled symmetrically
+            # Shape: (ppf, pph, ppw, dim)
+            freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_f) frame axis
+            freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_h) height starts at patch_start_idx
+            freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_w) width starts at patch_start_idx
+            freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1)  # (ppf, pph, ppw, dim) concat the three axes
+            freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1)  # (ppf, pph * ppw, dim) flatten the spatial dims
+
+            # Step 3: combine special tokens and patches in the correct order
+            # Within each frame the order is [special tokens, patches]
+            # Concatenate special tokens and patches for each frame along the second dimension
+            # Shape: (ppf, patch_start_idx + pph * ppw, dim)
+            freqs = torch.cat([freqs_special, freqs_patches], dim=1)  # (ppf, patch_start_idx + pph * ppw, dim)
+
+            # Step 4: flatten to the final shape and add batch and head dims
+            # Flatten to get final shape: (ppf * (patch_start_idx + pph * ppw), dim)
+            freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
+            freqs = freqs.unsqueeze(0).unsqueeze(0)  # (1, 1, ppf * (patch_start_idx + pph * ppw), dim) add batch and head dims
+            return freqs
+
+        # No special tokens (patch_start_idx == 0): encode image patches only
+        # All patches sit at (f, 0:pph, 0:ppw)
+        freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_f) frame axis
+        freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_h) height starts at 0
+        freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # (ppf, pph, ppw, dim_w) width starts at 0
+        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)  # (1, 1, ppf * pph * ppw, dim)
+        return freqs
+
+
+def apply_rotary_emb(x, freqs):
+    """
+    Apply rotary position embeddings to the input features.
+
+    Core idea: implement the feature rotation with complex multiplication,
+    which preserves relative position information.
+
+    Math:
+        For a 2D vector [x1, x2], a rotation by angle theta can be written as
+        a complex multiplication:
+        (x1 + i*x2) * e^(i*theta) = (x1 + i*x2) * (cos(theta) + i*sin(theta))
+                                  = (x1*cos(theta) - x2*sin(theta)) + i*(x1*sin(theta) + x2*cos(theta))
+
+        which is equivalent to the rotation matrix:
+        [cos(theta)  -sin(theta)] [x1]
+        [sin(theta)   cos(theta)] [x2]
+
+    Args:
+    - x: input features [batch, heads, seq_len, head_dim]
+    - freqs: rotation frequencies (complex) [1, 1, seq_len, head_dim//2]
+
+    Returns:
+    - x_out: rotated features [batch, heads, seq_len, head_dim]
+
+    Steps:
+    1. Treat every two consecutive features of x as one complex number (real, imag)
+    2. Multiply by the precomputed complex frequencies e^(i*theta)
+    3. Convert back to the real representation
+    """
+    # Step 1: reshape to [..., head_dim//2, 2]; the last dim holds (real, imag)
+    # e.g. [b, h, seq, 64] -> [b, h, seq, 32, 2]
+    x_reshaped = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)
+
+    # Step 2: view as complex numbers [b, h, seq, 32]
+    # each element is real + imag*i
+    x_complex = torch.view_as_complex(x_reshaped)
+
+    # Step 3: the complex multiplication performs the rotation
+    # x_complex * freqs rotates each feature pair by theta;
+    # freqs is already in the form e^(i*theta) = cos(theta) + i*sin(theta)
+    x_rotated = x_complex * freqs
+
+    # Step 4: back to the real representation [b, h, seq, 32, 2]
+    x_real = torch.view_as_real(x_rotated)
+
+    # Step 5: flatten the last two dims [b, h, seq, 64]
+    x_out = x_real.flatten(3)
+
+    # Step 6: cast back to the original dtype
+    return x_out.to(x.dtype)
diff --git a/lingbot_map/layers/swiglu_ffn.py b/lingbot_map/layers/swiglu_ffn.py
new file mode 100644
index 0000000..1dd991e
--- /dev/null
+++ b/lingbot_map/layers/swiglu_ffn.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Callable, Optional
+import warnings
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x12 = self.w12(x)
+        x1, x2 = x12.chunk(2, dim=-1)
+        hidden = F.silu(x1) * x2
+        return self.w3(hidden)
+
+
+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+# try:
+#     if XFORMERS_ENABLED:
+#         from xformers.ops import SwiGLU
+
+#         XFORMERS_AVAILABLE = True
+#         warnings.warn("xFormers is available (SwiGLU)")
+#     else:
+#         warnings.warn("xFormers is disabled (SwiGLU)")
+#         raise ImportError
+# except ImportError:
+SwiGLU = SwiGLUFFN
+XFORMERS_AVAILABLE = False
+
+# warnings.warn("xFormers is not available (SwiGLU)")
+
+
+class SwiGLUFFNFused(SwiGLU):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+        super().__init__(in_features=in_features, hidden_features=hidden_features, out_features=out_features, bias=bias)
diff --git a/lingbot_map/layers/vision_transformer.py b/lingbot_map/layers/vision_transformer.py
new file mode 100644
index 0000000..b6d0373
--- /dev/null
+++ b/lingbot_map/layers/vision_transformer.py @@ -0,0 +1,411 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from torch.nn.init import trunc_normal_ +from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention#, NestedTensorBlock as Block + +# TODO: Check this +# We replace NestedTensorBlock with Block +from .block import Block + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + drop_cls_token=False, + qk_norm=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 if not drop_cls_token else 0 
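+        # num_tokens accounts for the CLS slot in pos_embed; register tokens are
+        # added after the positional embedding in prepare_tokens_with_masks, so
+        # they do not receive an interpolated position embedding.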
+ self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + self.use_reentrant = False # hardcoded to False + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.drop_cls_token = drop_cls_token + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if not drop_cls_token else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=qk_norm, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) if self.cls_token is not None else None + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 if not self.drop_cls_token else self.pos_embed.shape[1] + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + if not self.drop_cls_token: + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + else: + patch_pos_embed = pos_embed + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + 
self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + if not self.drop_cls_token: + x = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + else: + x = patch_pos_embed + return x.to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.cls_token is not None else x + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + if self.training: + x = checkpoint(blk, x, use_reentrant=self.use_reentrant) + else: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=True, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/lingbot_map/models/__init__.py b/lingbot_map/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lingbot_map/models/gct_base.py b/lingbot_map/models/gct_base.py new file mode 100644 index 0000000..9fe6e7f --- /dev/null +++ b/lingbot_map/models/gct_base.py @@ -0,0 +1,359 @@ +""" +GCTBase - Base class for GCT model implementations. 
+ +Provides shared functionality: +- Prediction heads (camera, depth, point) +- Forward pass structure +- Model hub mixin (PyTorchModelHubMixin) +""" + +import logging +import numpy as np +import torch +import torch.nn as nn +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any, List, Union +from huggingface_hub import PyTorchModelHubMixin + +from lingbot_map.heads.dpt_head import DPTHead +from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri +from lingbot_map.utils.geometry import closed_form_inverse_se3 + +logger = logging.getLogger(__name__) + + +class GCTBase(nn.Module, PyTorchModelHubMixin, ABC): + """ + Base class for GCT model implementations. + + Handles shared components: + - Prediction heads (camera, depth, point) + - Forward pass structure + - Input normalization + + Subclasses must implement: + - _build_aggregator(): Create mode-specific aggregator + - _build_camera_head(): Create mode-specific camera head + """ + + def __init__( + self, + # Architecture parameters + img_size: int = 518, + patch_size: int = 14, + embed_dim: int = 1024, + patch_embed: str = 'dinov2_vitl14_reg', + disable_global_rope: bool = False, + # Head configuration + enable_camera: bool = True, + enable_point: bool = True, + enable_local_point: bool = False, + enable_depth: bool = True, + enable_track: bool = False, + # Camera head sliding window + enable_camera_sliding_window: bool = False, + # 3D RoPE + enable_3d_rope: bool = False, + # Context Parallelism (kept for checkpoint compatibility but not used) + enable_ulysses_cp: bool = False, + # Normalization + enable_normalize: bool = False, + # Prediction normalization + pred_normalization: bool = False, + pred_normalization_detach_scale: bool = False, + # Gradient checkpointing + use_gradient_checkpoint: bool = True, + ): + super().__init__() + + # Store configuration + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.patch_embed = patch_embed + self.disable_global_rope = disable_global_rope + + self.enable_ulysses_cp = False # CP disabled in standalone package + self.enable_normalize = enable_normalize + self.pred_normalization = pred_normalization + self.pred_normalization_detach_scale = pred_normalization_detach_scale + self.use_gradient_checkpoint = use_gradient_checkpoint + + # Head flags + self.enable_camera = enable_camera + self.enable_point = enable_point + self.enable_local_point = enable_local_point + self.enable_depth = enable_depth + self.enable_track = enable_track + self.enable_camera_sliding_window = enable_camera_sliding_window + self.enable_3d_rope = enable_3d_rope + + # Build aggregator (subclass-specific) + self.aggregator = self._build_aggregator() + + # Build prediction heads (subclass-specific) + self.camera_head = self._build_camera_head() if enable_camera else None + self.point_head = self._build_point_head() if enable_point else None + self.local_point_head = self._build_local_point_head() if enable_local_point else None + self.depth_head = self._build_depth_head() if enable_depth else None + + @abstractmethod + def _build_aggregator(self) -> nn.Module: + pass + + @abstractmethod + def _build_camera_head(self) -> nn.Module: + pass + + def _build_depth_head(self) -> nn.Module: + return DPTHead( + dim_in=2 * self.embed_dim, + patch_size=self.patch_size, + output_dim=2, + activation="exp", + conf_activation="expp1" + ) + + def _build_point_head(self) -> nn.Module: + return DPTHead( + dim_in=2 * self.embed_dim, + patch_size=self.patch_size, + 
output_dim=4, + activation="inv_log", + conf_activation="expp1" + ) + + def _build_local_point_head(self) -> nn.Module: + return DPTHead( + dim_in=2 * self.embed_dim, + patch_size=self.patch_size, + output_dim=4, + activation="inv_log", + conf_activation="expp1" + ) + + def _normalize_input(self, images: torch.Tensor, query_points=None): + if len(images.shape) == 4: + images = images.unsqueeze(0) + if query_points is not None and len(query_points.shape) == 2: + query_points = query_points.unsqueeze(0) + return images, query_points + + @abstractmethod + def _aggregate_features( + self, + images: torch.Tensor, + num_frame_for_scale: Optional[int] = None, + sliding_window_size: Optional[int] = None, + num_frame_per_block: int = 1, + view_graphs: Optional[torch.Tensor] = None, + causal_graphs: Optional[Union[torch.Tensor, List[np.ndarray]]] = None, + ordered_video: Optional[torch.Tensor] = None, + is_cp_sliced: bool = False, + ) -> tuple: + pass + + def _predict_camera( + self, + aggregated_tokens_list: list, + mask: Optional[torch.Tensor] = None, + causal_inference: bool = False, + num_frame_for_scale: Optional[int] = None, + sliding_window_size: Optional[int] = None, + num_frame_per_block: int = 1, + gather_outputs: bool = True, + ) -> Dict[str, torch.Tensor]: + if self.camera_head is None: + return {} + + aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list] + + camera_sliding_window = sliding_window_size if self.enable_camera_sliding_window else -1 + + with torch.amp.autocast('cuda', enabled=False): + pose_enc_list = self.camera_head( + aggregated_tokens_list_fp32, + mask=mask, + causal_inference=causal_inference, + num_frame_for_scale=num_frame_for_scale if num_frame_for_scale is not None else -1, + sliding_window_size=camera_sliding_window, + num_frame_per_block=num_frame_per_block, + ) + + return { + "pose_enc": pose_enc_list[-1], + "pose_enc_list": pose_enc_list, + } + + def _predict_depth( + self, + aggregated_tokens_list: list, + images: torch.Tensor, + patch_start_idx: int, + gather_outputs: bool = True, + ) -> Dict[str, torch.Tensor]: + if self.depth_head is None: + return {} + + aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list] + images_fp32 = images.float() + + with torch.amp.autocast('cuda', enabled=False): + depth, depth_conf = self.depth_head( + aggregated_tokens_list_fp32, + images=images_fp32, + patch_start_idx=patch_start_idx + ) + + return {"depth": depth, "depth_conf": depth_conf} + + def _predict_points( + self, + aggregated_tokens_list: list, + images: torch.Tensor, + patch_start_idx: int, + gather_outputs: bool = True, + ) -> Dict[str, torch.Tensor]: + if self.point_head is None: + return {} + + aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list] + images_fp32 = images.float() + + with torch.amp.autocast('cuda', enabled=False): + pts3d, pts3d_conf = self.point_head( + aggregated_tokens_list_fp32, + images=images_fp32, + patch_start_idx=patch_start_idx + ) + + return {"world_points": pts3d, "world_points_conf": pts3d_conf} + + def _predict_local_points( + self, + aggregated_tokens_list: list, + images: torch.Tensor, + patch_start_idx: int, + gather_outputs: bool = True, + ) -> Dict[str, torch.Tensor]: + if self.local_point_head is None: + return {} + + aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list] + images_fp32 = images.float() + + with torch.amp.autocast('cuda', enabled=False): + pts3d, pts3d_conf = self.local_point_head( + aggregated_tokens_list_fp32, + images=images_fp32, + 
patch_start_idx=patch_start_idx + ) + + return {"cam_points": pts3d, "cam_points_conf": pts3d_conf} + + def _unproject_depth_to_world( + self, + depth: torch.Tensor, + pose_enc: torch.Tensor, + ) -> torch.Tensor: + B, S, H, W, _ = depth.shape + device = depth.device + dtype = depth.dtype + + image_size_hw = (H, W) + extrinsics, intrinsics = pose_encoding_to_extri_intri( + pose_enc, image_size_hw=image_size_hw, build_intrinsics=True + ) + + extrinsics_flat = extrinsics.view(B * S, 3, 4) + extrinsics_4x4 = torch.zeros(B * S, 4, 4, device=device, dtype=dtype) + extrinsics_4x4[:, :3, :] = extrinsics_flat + extrinsics_4x4[:, 3, 3] = 1.0 + c2w = closed_form_inverse_se3(extrinsics_4x4).view(B, S, 4, 4) + + y_grid, x_grid = torch.meshgrid( + torch.arange(H, device=device, dtype=dtype), + torch.arange(W, device=device, dtype=dtype), + indexing='ij' + ) + pixel_coords = torch.stack([x_grid, y_grid, torch.ones_like(x_grid)], dim=-1) + + intrinsics_inv = torch.inverse(intrinsics) + camera_coords = torch.einsum('bsij,hwj->bshwi', intrinsics_inv, pixel_coords) + camera_points = camera_coords * depth + + ones = torch.ones_like(camera_points[..., :1]) + camera_points_h = torch.cat([camera_points, ones], dim=-1) + world_points_h = torch.einsum('bsij,bshwj->bshwi', c2w, camera_points_h) + + return world_points_h[..., :3] + + def forward( + self, + images: torch.Tensor, + query_points: Optional[torch.Tensor] = None, + num_frame_for_scale: Optional[int] = None, + sliding_window_size: Optional[int] = None, + num_frame_per_block: int = 1, + mask: Optional[torch.Tensor] = None, + causal_inference: bool = False, + ordered_video: Optional[torch.Tensor] = None, + gather_outputs: bool = True, + point_masks: Optional[torch.Tensor] = None, + **kwargs, + ) -> Dict[str, torch.Tensor]: + """ + Forward pass of the GCT model. 
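+        Runs the shared aggregator once over the input clip, then dispatches the
+        aggregated token list to every enabled head (camera, depth, world-point,
+        local-point) and merges their outputs into a single dictionary.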
+ + Args: + images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1] + query_points: Optional query points [N, 2] or [B, N, 2] + + Returns: + Dictionary containing predictions: + - pose_enc: Camera pose encoding [B, S, 9] + - depth: Depth maps [B, S, H, W, 1] + - depth_conf: Depth confidence [B, S, H, W] + - world_points: 3D world coordinates [B, S, H, W, 3] + - world_points_conf: Point confidence [B, S, H, W] + """ + images, query_points = self._normalize_input(images, query_points) + + aggregated_tokens_list, patch_start_idx = self._aggregate_features( + images, + num_frame_for_scale=num_frame_for_scale, + sliding_window_size=sliding_window_size, + num_frame_per_block=num_frame_per_block, + ) + + predictions = {} + + predictions.update(self._predict_camera( + aggregated_tokens_list, + mask=ordered_video, + causal_inference=causal_inference, + num_frame_for_scale=num_frame_for_scale, + sliding_window_size=sliding_window_size, + num_frame_per_block=num_frame_per_block, + gather_outputs=gather_outputs, + )) + + predictions.update(self._predict_depth( + aggregated_tokens_list, images, patch_start_idx, + gather_outputs=gather_outputs, + )) + + predictions.update(self._predict_points( + aggregated_tokens_list, images, patch_start_idx, + gather_outputs=gather_outputs, + )) + + predictions.update(self._predict_local_points( + aggregated_tokens_list, images, patch_start_idx, + gather_outputs=gather_outputs, + )) + + if not self.training: + predictions["images"] = images + + return predictions diff --git a/lingbot_map/models/gct_stream.py b/lingbot_map/models/gct_stream.py new file mode 100644 index 0000000..030a090 --- /dev/null +++ b/lingbot_map/models/gct_stream.py @@ -0,0 +1,444 @@ +""" +GCTStream - Streaming GCT with KV cache for online inference. + +Provides streaming inference functionality: +- Temporal causal attention with KV cache +- Sliding window support +- Efficient frame-by-frame processing +- 3D RoPE support for temporal consistency +""" + +import logging +import torch +import torch.nn as nn +from typing import Optional, Dict, Any, List +from tqdm.auto import tqdm + +from lingbot_map.heads.camera_head import CameraCausalHead +from lingbot_map.models.gct_base import GCTBase +from lingbot_map.aggregator.stream import AggregatorStream + +logger = logging.getLogger(__name__) + + +class GCTStream(GCTBase): + """ + Streaming GCT model with KV cache for efficient online inference. 
+ + Features: + - AggregatorStream with KV cache support (FlashInfer backend) + - CameraCausalHead for pose refinement + - Sliding window attention for memory efficiency + - Frame-by-frame streaming inference + """ + + def __init__( + self, + # Architecture parameters + img_size: int = 518, + patch_size: int = 14, + embed_dim: int = 1024, + patch_embed: str = 'dinov2_vitl14_reg', + pretrained_path: str = '', + disable_global_rope: bool = False, + # Head configuration + enable_camera: bool = True, + enable_point: bool = True, + enable_local_point: bool = False, + enable_depth: bool = True, + enable_track: bool = False, + # Normalization + enable_normalize: bool = False, + # Prediction normalization + pred_normalization: bool = False, + # Stream-specific parameters + sliding_window_size: int = -1, + num_frame_for_scale: int = 1, + num_random_frames: int = 0, + attend_to_special_tokens: bool = False, + attend_to_scale_frames: bool = False, + enable_stream_inference: bool = True, # Default to True for streaming + enable_3d_rope: bool = False, + max_frame_num: int = 1024, + # Camera head 3D RoPE (separate from aggregator 3D RoPE) + enable_camera_3d_rope: bool = False, + camera_rope_theta: float = 10000.0, + # Scale token configuration (kept for checkpoint compat, ignored) + use_scale_token: bool = True, + # KV cache parameters + kv_cache_sliding_window: int = 64, + kv_cache_scale_frames: int = 8, + kv_cache_cross_frame_special: bool = True, + kv_cache_include_scale_frames: bool = True, + kv_cache_camera_only: bool = False, + # Backend selection + use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer + # Gradient checkpointing + use_gradient_checkpoint: bool = True, + ): + """ + Initialize GCTStream. + + Args: + img_size: Input image size + patch_size: Patch size for embedding + embed_dim: Embedding dimension + patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.) 
+ pretrained_path: Path to pretrained DINOv2 weights + disable_global_rope: Disable RoPE in global attention + enable_camera/point/depth/track: Enable prediction heads + enable_normalize: Enable normalization + sliding_window_size: Sliding window size in blocks (-1 for full causal) + num_frame_for_scale: Number of scale estimation frames + num_random_frames: Number of random frames for long-range dependencies + attend_to_special_tokens: Enable cross-frame special token attention + attend_to_scale_frames: Whether to attend to scale frames + enable_stream_inference: Enable streaming inference with KV cache + enable_3d_rope: Enable 3D RoPE for temporal consistency + max_frame_num: Maximum number of frames for 3D RoPE + use_scale_token: Kept for checkpoint compatibility, ignored + kv_cache_sliding_window: Sliding window size for KV cache eviction + kv_cache_scale_frames: Number of scale frames to keep in KV cache + kv_cache_cross_frame_special: Keep special tokens from evicted frames + kv_cache_include_scale_frames: Include scale frames in KV cache + kv_cache_camera_only: Only keep camera tokens from evicted frames + """ + # Store stream-specific parameters before calling super().__init__() + self.pretrained_path = pretrained_path + self.sliding_window_size = sliding_window_size + self.num_frame_for_scale = num_frame_for_scale + self.num_random_frames = num_random_frames + self.attend_to_special_tokens = attend_to_special_tokens + self.attend_to_scale_frames = attend_to_scale_frames + self.enable_stream_inference = enable_stream_inference + self.enable_3d_rope = enable_3d_rope + self.max_frame_num = max_frame_num + # Camera head 3D RoPE settings + self.enable_camera_3d_rope = enable_camera_3d_rope + self.camera_rope_theta = camera_rope_theta + # KV cache parameters + self.kv_cache_sliding_window = kv_cache_sliding_window + self.kv_cache_scale_frames = kv_cache_scale_frames + self.kv_cache_cross_frame_special = kv_cache_cross_frame_special + self.kv_cache_include_scale_frames = kv_cache_include_scale_frames + self.kv_cache_camera_only = kv_cache_camera_only + self.use_sdpa = use_sdpa + + # Call base class __init__ (will call _build_aggregator) + super().__init__( + img_size=img_size, + patch_size=patch_size, + embed_dim=embed_dim, + patch_embed=patch_embed, + disable_global_rope=disable_global_rope, + enable_camera=enable_camera, + enable_point=enable_point, + enable_local_point=enable_local_point, + enable_depth=enable_depth, + enable_track=enable_track, + enable_normalize=enable_normalize, + pred_normalization=pred_normalization, + enable_3d_rope=enable_3d_rope, + use_gradient_checkpoint=use_gradient_checkpoint, + ) + + def _build_aggregator(self) -> nn.Module: + """ + Build streaming aggregator with KV cache support (FlashInfer backend). 
+ + Returns: + AggregatorStream module + """ + return AggregatorStream( + img_size=self.img_size, + patch_size=self.patch_size, + embed_dim=self.embed_dim, + patch_embed=self.patch_embed, + pretrained_path=self.pretrained_path, + disable_global_rope=self.disable_global_rope, + sliding_window_size=self.sliding_window_size, + num_frame_for_scale=self.num_frame_for_scale, + num_random_frames=self.num_random_frames, + attend_to_special_tokens=self.attend_to_special_tokens, + attend_to_scale_frames=self.attend_to_scale_frames, + enable_stream_inference=self.enable_stream_inference, + enable_3d_rope=self.enable_3d_rope, + max_frame_num=self.max_frame_num, + # Backend: FlashInfer (default) or SDPA (fallback) + use_flashinfer=not self.use_sdpa, + use_sdpa=self.use_sdpa, + kv_cache_sliding_window=self.kv_cache_sliding_window, + kv_cache_scale_frames=self.kv_cache_scale_frames, + kv_cache_cross_frame_special=self.kv_cache_cross_frame_special, + kv_cache_include_scale_frames=self.kv_cache_include_scale_frames, + kv_cache_camera_only=self.kv_cache_camera_only, + use_gradient_checkpoint=self.use_gradient_checkpoint, + ) + + def _build_camera_head(self) -> nn.Module: + """ + Build causal camera head for streaming inference. + + Returns: + CameraCausalHead module or None + """ + return CameraCausalHead( + dim_in=2 * self.embed_dim, + sliding_window_size=self.sliding_window_size, + attend_to_scale_frames=self.attend_to_scale_frames, + # KV cache parameters + kv_cache_sliding_window=self.kv_cache_sliding_window, + kv_cache_scale_frames=self.kv_cache_scale_frames, + kv_cache_cross_frame_special=self.kv_cache_cross_frame_special, + kv_cache_include_scale_frames=self.kv_cache_include_scale_frames, + kv_cache_camera_only=self.kv_cache_camera_only, + # Camera head 3D RoPE parameters + enable_3d_rope=self.enable_camera_3d_rope, + max_frame_num=self.max_frame_num, + rope_theta=self.camera_rope_theta, + ) + + def _aggregate_features( + self, + images: torch.Tensor, + num_frame_for_scale: Optional[int] = None, + sliding_window_size: Optional[int] = None, + num_frame_per_block: int = 1, + **kwargs, + ) -> tuple: + """ + Run aggregator to get multi-scale features. + + Args: + images: Input images [B, S, 3, H, W] + num_frame_for_scale: Number of frames for scale estimation + sliding_window_size: Override sliding window size + num_frame_per_block: Number of frames per block + + Returns: + (aggregated_tokens_list, patch_start_idx) + """ + aggregated_tokens_list, patch_start_idx = self.aggregator( + images, + selected_idx=[4, 11, 17, 23], + num_frame_for_scale=num_frame_for_scale, + sliding_window_size=sliding_window_size, + num_frame_per_block=num_frame_per_block, + ) + return aggregated_tokens_list, patch_start_idx + + def clean_kv_cache(self): + """ + Clean KV cache in aggregator. + + Call this method when starting a new video sequence to clear + cached key-value pairs from previous sequences. + """ + if hasattr(self.aggregator, 'clean_kv_cache'): + self.aggregator.clean_kv_cache() + else: + logger.warning("Aggregator does not support KV cache cleaning") + if hasattr(self.camera_head, 'kv_cache'): + self.camera_head.clean_kv_cache() + else: + logger.warning("Camera head does not support KV cache cleaning") + + def _set_skip_append(self, skip: bool): + """Set _skip_append flag on all KV caches (aggregator + camera head). + + When skip=True, attention layers will attend to [cached_kv + current_kv] + but will NOT store the current frame's KV in cache. 
This is used for + non-keyframe processing in keyframe-based streaming inference. + + Args: + skip: If True, subsequent forward passes will not append KV to cache. + """ + if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None: + self.aggregator.kv_cache["_skip_append"] = skip + if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None: + for cache_dict in self.camera_head.kv_cache: + cache_dict["_skip_append"] = skip + + def get_kv_cache_info(self) -> Dict[str, Any]: + """ + Get information about current KV cache state. + + Returns: + Dictionary with cache statistics: + - num_cached_blocks: Number of blocks with cached KV + - cache_memory_mb: Approximate memory usage in MB + """ + if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None: + return {"num_cached_blocks": 0, "cache_memory_mb": 0.0} + + kv_cache = self.aggregator.kv_cache + num_cached = sum(1 for k in kv_cache.keys() if k.startswith('k_') and not k.endswith('_special')) + + # Estimate memory usage + total_elements = 0 + for _, v in kv_cache.items(): + if v is not None and torch.is_tensor(v): + total_elements += v.numel() + + # Assume bfloat16 (2 bytes per element) + cache_memory_mb = (total_elements * 2) / (1024 * 1024) + + return { + "num_cached_blocks": num_cached, + "cache_memory_mb": round(cache_memory_mb, 2) + } + + @torch.no_grad() + def inference_streaming( + self, + images: torch.Tensor, + num_scale_frames: Optional[int] = None, + keyframe_interval: int = 1, + output_device: Optional[torch.device] = None, + ) -> Dict[str, torch.Tensor]: + """ + Streaming inference: process scale frames first, then frame-by-frame. + + This method enables efficient online inference by: + 1. Processing initial scale frames together (bidirectional attention via scale token) + 2. Processing remaining frames one-by-one with KV cache (causal streaming) + + Keyframe mode (keyframe_interval > 1): + - Every keyframe_interval-th frame (after scale frames) is a keyframe + - Keyframes: KV is stored in cache (normal behavior) + - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard) + - All frames produce full predictions regardless of keyframe status + - Reduces KV cache memory growth by ~1/keyframe_interval + + Args: + images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1] + num_scale_frames: Number of initial frames for scale estimation. + If None, uses self.num_frame_for_scale. + keyframe_interval: Every N-th frame (after scale frames) is a keyframe + whose KV persists in cache. 1 = every frame is a + keyframe (default, same as original behavior). + output_device: Device to store output predictions on. If None, keeps on + the same device as the model. Set to torch.device('cpu') + to offload predictions per-frame and avoid GPU OOM on + long sequences. 
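+
+        Example (a minimal usage sketch; the frame tensor, frame count and
+        keyframe interval below are illustrative only):
+            model = GCTStream().cuda().eval()
+            frames = torch.rand(32, 3, 518, 518).cuda()   # [S, 3, H, W] in [0, 1]
+            preds = model.inference_streaming(
+                frames,
+                keyframe_interval=2,                      # cache KV only for every 2nd frame
+                output_device=torch.device("cpu"),        # offload per-frame outputs
+            )
+            pose_enc = preds["pose_enc"]                  # [1, 32, 9]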
+ + Returns: + Dictionary containing predictions for all frames: + - pose_enc: [B, S, 9] + - depth: [B, S, H, W, 1] + - depth_conf: [B, S, H, W] + - world_points: [B, S, H, W, 3] + - world_points_conf: [B, S, H, W] + """ + # Normalize input shape + if len(images.shape) == 4: + images = images.unsqueeze(0) + B, S, C, H, W = images.shape + + # Determine number of scale frames + scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale + scale_frames = min(scale_frames, S) # Cap to available frames + + # Helper to move tensor to output device + def _to_out(t: torch.Tensor) -> torch.Tensor: + if output_device is not None: + return t.to(output_device) + return t + + # Clean KV caches before starting new sequence + self.clean_kv_cache() + + # Phase 1: Process scale frames together + # These frames get bidirectional attention among themselves via scale token + logger.info(f'Processing {scale_frames} scale frames...') + scale_images = images[:, :scale_frames] + scale_output = self.forward( + scale_images, + num_frame_for_scale=scale_frames, + num_frame_per_block=scale_frames, # Process all scale frames as one block + causal_inference=True, + ) + + # Initialize output lists with scale frame predictions (offload if needed) + all_pose_enc = [_to_out(scale_output["pose_enc"])] + all_depth = [_to_out(scale_output["depth"])] if "depth" in scale_output else [] + all_depth_conf = [_to_out(scale_output["depth_conf"])] if "depth_conf" in scale_output else [] + all_world_points = [_to_out(scale_output["world_points"])] if "world_points" in scale_output else [] + all_world_points_conf = [_to_out(scale_output["world_points_conf"])] if "world_points_conf" in scale_output else [] + del scale_output + + # Phase 2: Process remaining frames one-by-one + pbar = tqdm( + range(scale_frames, S), + desc='Streaming inference', + initial=scale_frames, + total=S, + ) + for i in pbar: + frame_image = images[:, i:i+1] + + # Determine if this frame is a keyframe + is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0) + + if not is_keyframe: + self._set_skip_append(True) + + frame_output = self.forward( + frame_image, + num_frame_for_scale=scale_frames, # Keep same for scale token logic + num_frame_per_block=1, # Single frame per block + causal_inference=True, + ) + + if not is_keyframe: + self._set_skip_append(False) + + all_pose_enc.append(_to_out(frame_output["pose_enc"])) + if "depth" in frame_output: + all_depth.append(_to_out(frame_output["depth"])) + if "depth_conf" in frame_output: + all_depth_conf.append(_to_out(frame_output["depth_conf"])) + if "world_points" in frame_output: + all_world_points.append(_to_out(frame_output["world_points"])) + if "world_points_conf" in frame_output: + all_world_points_conf.append(_to_out(frame_output["world_points_conf"])) + del frame_output + + # Free GPU memory before concatenation + if output_device is not None: + # Move images to output device, then free GPU copy + images_out = _to_out(images) + del images + # Clean KV cache (no longer needed after inference) + self.clean_kv_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: + images_out = images + + # Concatenate all predictions along sequence dimension + predictions = { + "pose_enc": torch.cat(all_pose_enc, dim=1), + } + del all_pose_enc + if all_depth: + predictions["depth"] = torch.cat(all_depth, dim=1) + del all_depth + if all_depth_conf: + predictions["depth_conf"] = torch.cat(all_depth_conf, dim=1) + del all_depth_conf + if 
all_world_points: + predictions["world_points"] = torch.cat(all_world_points, dim=1) + del all_world_points + if all_world_points_conf: + predictions["world_points_conf"] = torch.cat(all_world_points_conf, dim=1) + del all_world_points_conf + + # Store images for visualization + predictions["images"] = images_out + + # Apply prediction normalization if enabled + if self.pred_normalization: + predictions = self._normalize_predictions(predictions) + + return predictions diff --git a/lingbot_map/utils/__init__.py b/lingbot_map/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lingbot_map/utils/geometry.py b/lingbot_map/utils/geometry.py new file mode 100644 index 0000000..9270af0 --- /dev/null +++ b/lingbot_map/utils/geometry.py @@ -0,0 +1,774 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import numpy as np +from scipy.spatial.transform import Rotation as R + +from scipy.spatial.transform import Rotation +try: + from lietorch import SE3, Sim3 +except ImportError: + SE3 = Sim3 = None +import torch.nn.functional as F + +try: + from lingbot_map.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion +except ImportError: + apply_distortion = iterative_undistortion = single_undistortion = None + + +def unproject_depth_map_to_point_map( + depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray +) -> np.ndarray: + """ + Unproject a batch of depth maps to 3D world coordinates. + + Args: + depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W) + extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4) + intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3) + + Returns: + np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3) + """ + if isinstance(depth_map, torch.Tensor): + depth_map = depth_map.cpu().numpy() + if isinstance(extrinsics_cam, torch.Tensor): + extrinsics_cam = extrinsics_cam.cpu().numpy() + if isinstance(intrinsics_cam, torch.Tensor): + intrinsics_cam = intrinsics_cam.cpu().numpy() + + world_points_list = [] + for frame_idx in range(depth_map.shape[0]): + cur_world_points, _, _ = depth_to_world_coords_points( + depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx] + ) + world_points_list.append(cur_world_points) + world_points_array = np.stack(world_points_list, axis=0) + + return world_points_array + + +def depth_to_world_coords_points( + depth_map: np.ndarray, + extrinsic: np.ndarray, + intrinsic: np.ndarray, + eps=1e-8, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Convert a depth map to world coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world. + + Returns: + tuple[np.ndarray, np.ndarray]: World coordinates (H, W, 3) and valid depth mask (H, W). 
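+
+    Example (a shape-level sketch; the camera below is illustrative):
+        depth = np.random.rand(480, 640).astype(np.float32)   # (H, W)
+        extrinsic = np.eye(4)[:3]                             # (3, 4), cam from world
+        intrinsic = np.array([[500., 0., 320.],
+                              [0., 500., 240.],
+                              [0.,   0.,   1.]])
+        world_pts, cam_pts, mask = depth_to_world_coords_points(depth, extrinsic, intrinsic)
+        # world_pts: (480, 640, 3), cam_pts: (480, 640, 3), mask: (480, 640) boolean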
+ """ + if depth_map is None: + return None, None, None + + # Valid depth mask + point_mask = depth_map > eps + + # Convert depth map to camera coordinates + cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic) + + # Multiply with the inverse of extrinsic matrix to transform to world coordinates + # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4)) + cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0] + + R_cam_to_world = cam_to_world_extrinsic[:3, :3] + t_cam_to_world = cam_to_world_extrinsic[:3, 3] + + # Apply the rotation and translation to the camera coordinates + world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3 + # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world + + return world_coords_points, cam_coords_points, point_mask + + +def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """ + Convert a depth map to camera coordinates. + + Args: + depth_map (np.ndarray): Depth map of shape (H, W). + intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3). + + Returns: + tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3) + """ + H, W = depth_map.shape + assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3" + assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew" + + # Intrinsic parameters + fu, fv = intrinsic[0, 0], intrinsic[1, 1] + cu, cv = intrinsic[0, 2], intrinsic[1, 2] + + # Generate grid of pixel coordinates + u, v = np.meshgrid(np.arange(W), np.arange(H)) + + # Unproject to camera coordinates + x_cam = (u - cu) * depth_map / fu + y_cam = (v - cv) * depth_map / fv + z_cam = depth_map + + # Stack to form camera coordinates + cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32) + + return cam_coords + + +def closed_form_inverse_se3(se3, R=None, T=None): + """ + Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch. + + If `R` and `T` are provided, they must correspond to the rotation and translation + components of `se3`. Otherwise, they will be extracted from `se3`. + + Args: + se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices. + R (optional): Nx3x3 array or tensor of rotation matrices. + T (optional): Nx3x1 array or tensor of translation vectors. + + Returns: + Inverted SE3 matrices with the same type and device as `se3`. 
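+        Each inverse is computed in closed form as inv([R | t]) = [R^T | -R^T t],
+        avoiding a general-purpose matrix inversion.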
+
+    Shapes:
+        se3: (N, 4, 4)
+        R: (N, 3, 3)
+        T: (N, 3, 1)
+    """
+    # Check if se3 is a numpy array or a torch tensor
+    is_numpy = isinstance(se3, np.ndarray)
+
+    # Validate shapes
+    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+        raise ValueError(f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}.")
+
+    # Extract R and T if not provided
+    if R is None:
+        R = se3[:, :3, :3]  # (N,3,3)
+    if T is None:
+        T = se3[:, :3, 3:]  # (N,3,1)
+
+    # Transpose R
+    if is_numpy:
+        # Compute the transpose of the rotation for NumPy
+        R_transposed = np.transpose(R, (0, 2, 1))
+        # -R^T t for NumPy
+        top_right = -np.matmul(R_transposed, T)
+        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+    else:
+        R_transposed = R.transpose(1, 2)  # (N,3,3)
+        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
+        inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+    inverted_matrix[:, :3, :3] = R_transposed
+    inverted_matrix[:, :3, 3:] = top_right
+
+    return inverted_matrix
+
+def closed_form_inverse_se3_general(se3, R=None, T=None):
+    """
+    Closed-form SE3 inverse supporting arbitrary leading batch dimensions.
+    se3: (..., 4, 4) or (..., 3, 4)
+    """
+    batch_shape = se3.shape[:-2]
+    if R is None:
+        R = se3[..., :3, :3]
+    if T is None:
+        T = se3[..., :3, 3:]
+    R_transposed = R.transpose(-2, -1)
+    top_right = -R_transposed @ T
+    # Build identity matrices and fill in the inverse blocks
+    eye = torch.eye(4, 4, dtype=R.dtype, device=R.device)
+    inverted_matrix = eye.expand(*batch_shape, 4, 4).clone()
+    inverted_matrix[..., :3, :3] = R_transposed
+    inverted_matrix[..., :3, 3:] = top_right
+    return inverted_matrix
+
+
+# TODO: this code can be further cleaned up
+
+
+def project_world_points_to_camera_points_batch(world_points, cam_extrinsics):
+    """
+    Transforms 3D world points into camera coordinates using extrinsic parameters.
+    Args:
+        world_points (torch.Tensor): 3D points of shape BxSxHxWx3.
+        cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4.
+    Returns:
+        torch.Tensor: Points in camera coordinates of shape BxSxHxWx3.
+    """
+    # TODO: merge this into project_world_points_to_cam
+
+    # device = world_points.device
+    # with torch.autocast(device_type=device.type, enabled=False):
+    ones = torch.ones_like(world_points[..., :1])  # shape: (B, S, H, W, 1)
+    world_points_h = torch.cat([world_points, ones], dim=-1)  # shape: (B, S, H, W, 4)
+
+    # extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4)
+    extrinsics_exp = cam_extrinsics.unsqueeze(2).unsqueeze(3)
+
+    # world_points_h: (B, S, H, W, 4) -> (B, S, H, W, 4, 1)
+    world_points_h_exp = world_points_h.unsqueeze(-1)
+
+    # Now perform the matrix multiplication
+    # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) broadcasts to (B, S, H, W, 3, 1)
+    camera_points = torch.matmul(extrinsics_exp, world_points_h_exp).squeeze(-1)
+
+    return camera_points
+
+
+def project_world_points_to_cam(
+    world_points,
+    cam_extrinsics,
+    cam_intrinsics=None,
+    distortion_params=None,
+    default=0,
+    only_points_cam=False,
+):
+    """
+    Transforms 3D points to 2D using extrinsic and intrinsic parameters.
+    Args:
+        world_points (torch.Tensor): 3D points of shape Px3.
+        cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
+        cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
+        distortion_params (torch.Tensor): Extra parameters of shape BxN, used for radial distortion.
+    Returns:
+        tuple: (image_points, cam_points) where image_points are 2D pixel
+        coordinates of shape BxNx2 (or None if only_points_cam is True) and
+        cam_points are camera-frame points of shape Bx3xN.
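+
+    Example (a shape-level sketch; the camera below is illustrative):
+        world_points = torch.rand(1000, 3)                # P x 3
+        extrinsics = torch.eye(4)[:3].unsqueeze(0)        # B x 3 x 4
+        intrinsics = torch.tensor([[[500., 0., 320.],
+                                    [0., 500., 240.],
+                                    [0., 0., 1.]]])       # B x 3 x 3
+        pix, cam_pts = project_world_points_to_cam(world_points, extrinsics, intrinsics)
+        # pix: (1, 1000, 2) pixel coordinates, cam_pts: (1, 3, 1000) camera-frame points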
+ """ + device = world_points.device + # with torch.autocast(device_type=device.type, dtype=torch.double): + with torch.autocast(device_type=device.type, enabled=False): + N = world_points.shape[0] # Number of points + B = cam_extrinsics.shape[0] # Batch size, i.e., number of cameras + world_points_homogeneous = torch.cat( + [world_points, torch.ones_like(world_points[..., 0:1])], dim=1 + ) # Nx4 + # Reshape for batch processing + world_points_homogeneous = world_points_homogeneous.unsqueeze(0).expand( + B, -1, -1 + ) # BxNx4 + + # Step 1: Apply extrinsic parameters + # Transform 3D points to camera coordinate system for all cameras + cam_points = torch.bmm( + cam_extrinsics, world_points_homogeneous.transpose(-1, -2) + ) + + if only_points_cam: + return None, cam_points + + # Step 2: Apply intrinsic parameters and (optional) distortion + image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default) + + return image_points, cam_points + + + +def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0): + """ + Applies intrinsic parameters and optional distortion to the given 3D points. + + Args: + cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3. + cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN. + distortion_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + default (float, optional): Default value to replace NaNs in the output. + + Returns: + pixel_coords (torch.Tensor): 2D points in pixel coordinates of shape BxNx2. + """ + + # Normalized device coordinates (NDC) + cam_points = cam_points / cam_points[:, 2:3, :] + ndc_xy = cam_points[:, :2, :] + + # Apply distortion if distortion_params are provided + if distortion_params is not None: + x_distorted, y_distorted = apply_distortion(distortion_params, ndc_xy[:, 0], ndc_xy[:, 1]) + distorted_xy = torch.stack([x_distorted, y_distorted], dim=1) + else: + distorted_xy = ndc_xy + + # Prepare cam_points for batch matrix multiplication + cam_coords_homo = torch.cat( + (distorted_xy, torch.ones_like(distorted_xy[:, :1, :])), dim=1 + ) # Bx3xN + # Apply intrinsic parameters using batch matrix multiplication + pixel_coords = torch.bmm(cam_intrinsics, cam_coords_homo) # Bx3xN + + # Extract x and y coordinates + pixel_coords = pixel_coords[:, :2, :] # Bx2xN + + # Replace NaNs with default value + pixel_coords = torch.nan_to_num(pixel_coords, nan=default) + + return pixel_coords.transpose(1, 2) # BxNx2 + + + + +def cam_from_img(pred_tracks, intrinsics, extra_params=None): + """ + Normalize predicted tracks based on camera intrinsics. + Args: + intrinsics (torch.Tensor): The camera intrinsics tensor of shape [batch_size, 3, 3]. + pred_tracks (torch.Tensor): The predicted tracks tensor of shape [batch_size, num_tracks, 2]. + extra_params (torch.Tensor, optional): Distortion parameters of shape BxN, where N can be 1, 2, or 4. + Returns: + torch.Tensor: Normalized tracks tensor. 
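+
+    Example (illustrative values):
+        intrinsics = torch.tensor([[[500., 0., 320.],
+                                    [0., 500., 240.],
+                                    [0., 0., 1.]]])            # B x 3 x 3
+        tracks = torch.tensor([[[320., 240.], [420., 340.]]])  # B x N x 2, in pixels
+        normalized = cam_from_img(tracks, intrinsics)
+        # normalized: (1, 2, 2); the principal point maps to (0, 0),
+        # the second point to roughly (0.2, 0.2)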
+ """ + + # We don't want to do intrinsics_inv = torch.inverse(intrinsics) here + # otherwise we can use something like + # tracks_normalized_homo = torch.bmm(pred_tracks_homo, intrinsics_inv.transpose(1, 2)) + + principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2) + focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2) + tracks_normalized = (pred_tracks - principal_point) / focal_length + + if extra_params is not None: + # Apply iterative undistortion + try: + tracks_normalized = iterative_undistortion( + extra_params, tracks_normalized + ) + except: + tracks_normalized = single_undistortion( + extra_params, tracks_normalized + ) + + return tracks_normalized + +## Droid SLAM Part + +MIN_DEPTH = 0.2 + +def extract_intrinsics(intrinsics): + return intrinsics[...,None,None,:].unbind(dim=-1) + +def projective_transform( + poses, depths, intrinsics, ii, jj, jacobian=False, return_depth=False +): + """map points from ii->jj""" + + # inverse project (pinhole) + X0, Jz = iproj(depths[:, ii], intrinsics[:, ii], jacobian=jacobian) + + # transform + Gij = poses[:, jj] * poses[:, ii].inv() + + # Gij.data[:, ii == jj] = torch.as_tensor( + # [-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device="cuda" + # ) + X1, Ja = actp(Gij, X0, jacobian=jacobian) + + # project (pinhole) + x1, Jp = proj(X1, intrinsics[:, jj], jacobian=jacobian, return_depth=return_depth) + + # exclude points too close to camera + valid = ((X1[..., 2] > MIN_DEPTH) & (X0[..., 2] > MIN_DEPTH)).float() + valid = valid.unsqueeze(-1) + + if jacobian: + # Ji transforms according to dual adjoint + Jj = torch.matmul(Jp, Ja) + Ji = -Gij[:, :, None, None, None].adjT(Jj) + + Jz = Gij[:, :, None, None] * Jz + Jz = torch.matmul(Jp, Jz.unsqueeze(-1)) + + return x1, valid, (Ji, Jj, Jz) + + return x1, valid + + +def induced_flow(poses, disps, intrinsics, ii, jj): + """optical flow induced by camera motion""" + + ht, wd = disps.shape[2:] + y, x = torch.meshgrid( + torch.arange(ht, device=disps.device, dtype=torch.float), + torch.arange(wd, device=disps.device, dtype=torch.float), + indexing="ij", + ) + + coords0 = torch.stack([x, y], dim=-1) + coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj, False) + + return coords1[..., :2] - coords0, valid + +def all_pairs_distance_matrix(poses, beta=2.5): + """ compute distance matrix between all pairs of poses """ + poses = np.array(poses, dtype=np.float32) + poses[:,:3] *= beta # scale to balence rot + trans + poses = SE3(torch.from_numpy(poses)) + + r = (poses[:,None].inv() * poses[None,:]).log() + return r.norm(dim=-1).cpu().numpy() + +def pose_matrix_to_quaternion(pose): + """ convert 4x4 pose matrix to (t, q) """ + q = Rotation.from_matrix(pose[..., :3, :3]).as_quat() + return np.concatenate([pose[..., :3, 3], q], axis=-1) + +def compute_distance_matrix_flow(poses, disps, intrinsics): + """ compute flow magnitude between all pairs of frames """ + if not isinstance(poses, SE3): + poses = torch.from_numpy(poses).float().cuda()[None] + poses = SE3(poses).inv() + + disps = torch.from_numpy(disps).float().cuda()[None] + intrinsics = torch.from_numpy(intrinsics).float().cuda()[None] + + N = poses.shape[1] + + ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N)) + ii = ii.reshape(-1).cuda() + jj = jj.reshape(-1).cuda() + + MAX_FLOW = 100.0 + matrix = np.zeros((N, N), dtype=np.float32) + + s = 2048 + for i in range(0, ii.shape[0], s): + flow1, val1 = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s]) + flow2, val2 = induced_flow(poses, disps, intrinsics, jj[i:i+s], 
ii[i:i+s]) + + flow = torch.stack([flow1, flow2], dim=2) + val = torch.stack([val1, val2], dim=2) + + mag = flow.norm(dim=-1).clamp(max=MAX_FLOW) + mag = mag.view(mag.shape[1], -1) + val = val.view(val.shape[1], -1) + + mag = (mag * val).mean(-1) / val.mean(-1) + mag[val.mean(-1) < 0.7] = np.inf + + i1 = ii[i:i+s].cpu().numpy() + j1 = jj[i:i+s].cpu().numpy() + matrix[i1, j1] = mag.cpu().numpy() + + return matrix + + +def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4): + """ compute flow magnitude between all pairs of frames """ + # if not isinstance(poses, SE3): + # poses = torch.from_numpy(poses).float().cuda()[None] + # poses = SE3(poses).inv() + + # disps = torch.from_numpy(disps).float().cuda()[None] + # intrinsics = torch.from_numpy(intrinsics).float().cuda()[None] + + N = poses.shape[1] + + ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N)) + ii = ii.reshape(-1) + jj = jj.reshape(-1) + + MAX_FLOW = 128.0 + matrix = np.zeros((N, N), dtype=np.float32) + + s = 2048 + for i in range(0, ii.shape[0], s): + flow1a, val1a = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True) + flow1b, val1b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s]) + flow2a, val2a = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True) + flow2b, val2b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s]) + + flow1 = flow1a + beta * flow1b + val1 = val1a * val2b + + flow2 = flow2a + beta * flow2b + val2 = val2a * val2b + + flow = torch.stack([flow1, flow2], dim=2) + val = torch.stack([val1, val2], dim=2) + + mag = flow.norm(dim=-1).clamp(max=MAX_FLOW) + mag = mag.view(mag.shape[1], -1) + val = val.view(val.shape[1], -1) + + mag = (mag * val).mean(-1) / val.mean(-1) + mag[val.mean(-1) < 0.8] = np.inf + + i1 = ii[i:i+s].cpu().numpy() + j1 = jj[i:i+s].cpu().numpy() + matrix[i1, j1] = mag.cpu().numpy() + + return matrix + +def coords_grid(ht, wd, **kwargs): + y, x = torch.meshgrid( + torch.arange(ht, dtype=torch.float, **kwargs), + torch.arange(wd, dtype=torch.float, **kwargs), + indexing="ij", + ) + + return torch.stack([x, y], dim=-1) + + +def iproj(disps, intrinsics, jacobian=False): + """pinhole camera inverse projection""" + ht, wd = disps.shape[2:] + fx, fy, cx, cy = extract_intrinsics(intrinsics) + + y, x = torch.meshgrid( + torch.arange(ht, device=disps.device, dtype=torch.float), + torch.arange(wd, device=disps.device, dtype=torch.float), + indexing="ij", + ) + + i = torch.ones_like(disps) + X = (x - cx) / fx + Y = (y - cy) / fy + pts = torch.stack([X, Y, i, disps], dim=-1) + + if jacobian: + J = torch.zeros_like(pts) + J[..., -1] = 1.0 + return pts, J + + return pts, None + + +def proj(Xs, intrinsics, jacobian=False, return_depth=False): + """pinhole camera projection""" + fx, fy, cx, cy = extract_intrinsics(intrinsics) + X, Y, Z, D = Xs.unbind(dim=-1) + + Z = torch.where(Z < 0.5 * MIN_DEPTH, torch.ones_like(Z), Z) + d = 1.0 / Z + + x = fx * (X * d) + cx + y = fy * (Y * d) + cy + if return_depth: + coords = torch.stack([x, y, D * d], dim=-1) + else: + coords = torch.stack([x, y], dim=-1) + + if jacobian: + B, N, H, W = d.shape + o = torch.zeros_like(d) + proj_jac = torch.stack( + [ + fx * d, + o, + -fx * X * d * d, + o, + o, + fy * d, + -fy * Y * d * d, + o, + # o, o, -D*d*d, d, + ], + dim=-1, + ).view(B, N, H, W, 2, 4) + + return coords, proj_jac + + return coords, None + + +def actp(Gij, X0, jacobian=False): + """action on point cloud""" + X1 = Gij[:, :, None, None] * X0 + + if jacobian: + X, Y, Z, d = X1.unbind(dim=-1) + o 
= torch.zeros_like(d) + B, N, H, W = d.shape + + if isinstance(Gij, SE3): + Ja = torch.stack( + [ + d, + o, + o, + o, + Z, + -Y, + o, + d, + o, + -Z, + o, + X, + o, + o, + d, + Y, + -X, + o, + o, + o, + o, + o, + o, + o, + ], + dim=-1, + ).view(B, N, H, W, 4, 6) + + elif isinstance(Gij, Sim3): + Ja = torch.stack( + [ + d, + o, + o, + o, + Z, + -Y, + X, + o, + d, + o, + -Z, + o, + X, + Y, + o, + o, + d, + Y, + -X, + o, + Z, + o, + o, + o, + o, + o, + o, + o, + ], + dim=-1, + ).view(B, N, H, W, 4, 7) + + return X1, Ja + + return X1, None + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + +def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.shape[-1] != 3 or matrix.shape[-2] != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + return standardize_quaternion(out) + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + quaternions = F.normalize(quaternions, p=2, dim=-1) + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + +def umeyama(X, Y): + """ + Estimates the Sim(3) transformation between `X` and `Y` point sets. + + Estimates c, R and t such as c * R @ X + t ~ Y. + + Parameters + ---------- + X : numpy.array + (m, n) shaped numpy array. m is the dimension of the points, + n is the number of points in the point set. + Y : numpy.array + (m, n) shaped numpy array. Indexes should be consistent with `X`. + That is, Y[:, i] must be the point corresponding to X[:, i]. + + Returns + ------- + c : float + Scale factor. + R : numpy.array + (3, 3) shaped rotation matrix. + t : numpy.array + (3, 1) shaped translation vector. 
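+
+    Example (a small sanity check; the scale and translation are illustrative):
+        X = np.random.rand(3, 100)                  # 100 source points
+        t_true = np.array([[1.0], [2.0], [3.0]])
+        Y = 2.0 * X + t_true                        # scale 2, identity rotation
+        c, R, t = umeyama(X, Y)
+        # c ~ 2.0, R ~ np.eye(3), t ~ t_true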
+ """ + mu_x = X.mean(axis=1).reshape(-1, 1) + mu_y = Y.mean(axis=1).reshape(-1, 1) + var_x = np.square(X - mu_x).sum(axis=0).mean() + cov_xy = ((Y - mu_y) @ (X - mu_x).T) / X.shape[1] + U, D, VH = np.linalg.svd(cov_xy) + S = np.eye(X.shape[0]) + if np.linalg.det(U) * np.linalg.det(VH) < 0: + S[-1, -1] = -1 + c = np.trace(np.diag(D) @ S) / var_x + R = U @ S @ VH + t = mu_y - c * R @ mu_x + return c, R, t diff --git a/lingbot_map/utils/load_fn.py b/lingbot_map/utils/load_fn.py new file mode 100644 index 0000000..40209b7 --- /dev/null +++ b/lingbot_map/utils/load_fn.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from PIL import Image +from torchvision import transforms as TF +import numpy as np + + +def load_and_preprocess_images_square(image_path_list, target_size=1024): + """ + Load and preprocess images by center padding to square and resizing to target size. + Also returns the position information of original pixels after transformation. + + Args: + image_path_list (list): List of paths to image files + target_size (int, optional): Target size for both width and height. Defaults to 518. + + Returns: + tuple: ( + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size), + torch.Tensor: Array of shape (N, 5) containing [x1, y1, x2, y2, width, height] for each image + ) + + Raises: + ValueError: If the input list is empty + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + images = [] + original_coords = [] # Renamed from position_info to be more descriptive + to_tensor = TF.ToTensor() + + for image_path in image_path_list: + # Open image + img = Image.open(image_path) + + # If there's an alpha channel, blend onto white background + if img.mode == "RGBA": + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + img = Image.alpha_composite(background, img) + + # Convert to RGB + img = img.convert("RGB") + + # Get original dimensions + width, height = img.size + + # Make the image square by padding the shorter dimension + max_dim = max(width, height) + + # Calculate padding + left = (max_dim - width) // 2 + top = (max_dim - height) // 2 + + # Calculate scale factor for resizing + scale = target_size / max_dim + + # Calculate final coordinates of original image in target space + x1 = left * scale + y1 = top * scale + x2 = (left + width) * scale + y2 = (top + height) * scale + + # Store original image coordinates and scale + original_coords.append(np.array([x1, y1, x2, y2, width, height])) + + # Create a new black square image and paste original + square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0)) + square_img.paste(img, (left, top)) + + # Resize to target size + square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC) + + # Convert to tensor + img_tensor = to_tensor(square_img) + images.append(img_tensor) + + # Stack all images + images = torch.stack(images) + original_coords = torch.from_numpy(np.array(original_coords)).float() + + # Add additional dimension if single image to ensure correct shape + if len(image_path_list) == 1: + if images.dim() == 3: + images = images.unsqueeze(0) + original_coords = original_coords.unsqueeze(0) + + return images, original_coords + + +def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, 
mode="crop", image_size=512, patch_size=16): + """ + A quick start function to load and preprocess images for model input. + This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes. + + Args: + image_path_list (list): List of paths to image files + mode (str, optional): Preprocessing mode, either "crop" or "pad". + - "crop" (default): Sets width to 518px and center crops height if needed. + - "pad": Preserves all pixels by making the largest dimension 518px + and padding the smaller dimension to reach a square shape. + + Returns: + torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W) + + Raises: + ValueError: If the input list is empty or if mode is invalid + + Notes: + - Images with different dimensions will be padded with white (value=1.0) + - A warning is printed when images have different shapes + - When mode="crop": The function ensures width=518px while maintaining aspect ratio + and height is center-cropped if larger than 518px + - When mode="pad": The function ensures the largest dimension is 518px while maintaining aspect ratio + and the smaller dimension is padded to reach a square shape (518x518) + - Dimensions are adjusted to be divisible by 14 for compatibility with model requirements + """ + # Check for empty list + if len(image_path_list) == 0: + raise ValueError("At least 1 image is required") + + + + # Validate mode + if mode not in ["crop", "pad"]: + raise ValueError("Mode must be either 'crop' or 'pad'") + + images = [] + shapes = set() + to_tensor = TF.ToTensor() + target_size = image_size + + # First process all images and collect their shapes + for i, image_path in enumerate(image_path_list): + # Open image + img = Image.open(image_path) + + # If there's an alpha channel, blend onto white background: + if img.mode == "RGBA": + # Create white background + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) + # Alpha composite onto the white background + img = Image.alpha_composite(background, img) + + # Now convert to "RGB" (this step assigns white for transparent areas) + img = img.convert("RGB") + + width, height = img.size + + if fx is not None: + fx[i] = fx[i] * width + fy[i] = fy[i] * height + cx[i] = cx[i] * width + cy[i] = cy[i] * height + + if mode == "pad": + # Make the largest dimension 518px while maintaining aspect ratio + if width >= height: + new_width = target_size + new_height = round(height * (new_width / width) / patch_size) * patch_size # Make divisible by 14 + else: + new_height = target_size + new_width = round(width * (new_height / height) / patch_size) * patch_size # Make divisible by 14 + + else: # mode == "crop" + # Original behavior: set width to 518px + new_width = target_size + # Calculate height maintaining aspect ratio, divisible by 14 + new_height = round(height * (new_width / width) / patch_size) * patch_size + + # Resize with new dimensions (width, height) + img = img.resize((new_width, new_height), Image.Resampling.BICUBIC) + img = to_tensor(img) # Convert to tensor (0, 1) + + # Center crop height if it's larger than 518 (only in crop mode) + if mode == "crop" and new_height > target_size: + start_y = (new_height - target_size) // 2 + img = img[:, start_y : start_y + target_size, :] + if fx is not None: + fx[i] = fx[i] * new_width / width + fy[i] = fy[i] * new_height / height + + cx[i] = img.shape[2] / 2 + cy[i] = img.shape[1] / 2 + + # For pad mode, pad to make a square of target_size x target_size + if mode == "pad": + h_padding = 
target_size - img.shape[1] + w_padding = target_size - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + # Pad with white (value=1.0) + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + + shapes.add((img.shape[1], img.shape[2])) + images.append(img) + + # Check if we have different shapes + # In theory our model can also work well with different shapes + if len(shapes) > 1: + print(f"Warning: Found images with different shapes: {shapes}") + # Find maximum dimensions + max_height = max(shape[0] for shape in shapes) + max_width = max(shape[1] for shape in shapes) + + # Pad images if necessary + padded_images = [] + for img in images: + h_padding = max_height - img.shape[1] + w_padding = max_width - img.shape[2] + + if h_padding > 0 or w_padding > 0: + pad_top = h_padding // 2 + pad_bottom = h_padding - pad_top + pad_left = w_padding // 2 + pad_right = w_padding - pad_left + + img = torch.nn.functional.pad( + img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0 + ) + padded_images.append(img) + images = padded_images + + images = torch.stack(images) # concatenate images + + # Ensure correct shape when single image + if len(image_path_list) == 1: + # Verify shape is (1, C, H, W) + if images.dim() == 3: + images = images.unsqueeze(0) + if fx is not None: + return images, fx, fy, cx, cy + return images diff --git a/lingbot_map/utils/pose_enc.py b/lingbot_map/utils/pose_enc.py new file mode 100644 index 0000000..9d029fd --- /dev/null +++ b/lingbot_map/utils/pose_enc.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from .rotation import quat_to_mat, mat_to_quat +import os +import torch +import numpy as np +import gzip +import json +import random +import logging +import warnings + +from lingbot_map.utils.geometry import closed_form_inverse_se3, closed_form_inverse_se3_general + + +def extri_intri_to_pose_encoding( + extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV" # e.g., (256, 512) +): + """Convert camera extrinsics and intrinsics to a compact pose encoding. + + This function transforms camera parameters into a unified pose encoding format, + which can be used for various downstream tasks like pose prediction or representation. + + Args: + extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4, + where B is batch size and S is sequence length. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation. + The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector. + intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3. + Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for computing field of view values. For example: (256, 512). + pose_encoding_type (str): Type of pose encoding to use. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + + Returns: + torch.Tensor: Encoded camera pose parameters with shape BxSx9. 
+ For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + """ + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + + if pose_encoding_type == "absT_quaR_FoV": + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + else: + raise NotImplementedError + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True # e.g., (256, 512) +): + """Convert a pose encoding back to camera extrinsics and intrinsics. + + This function performs the inverse operation of extri_intri_to_pose_encoding, + reconstructing the full camera parameters from the compact encoding. + + Args: + pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9, + where B is batch size and S is sequence length. + For "absT_quaR_FoV" type, the 9 dimensions are: + - [:3] = absolute translation vector T (3D) + - [3:7] = rotation as quaternion quat (4D) + - [7:] = field of view (2D) + image_size_hw (tuple): Tuple of (height, width) of the image in pixels. + Required for reconstructing intrinsics from field of view values. + For example: (256, 512). + pose_encoding_type (str): Type of pose encoding used. Currently only + supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view). + build_intrinsics (bool): Whether to reconstruct the intrinsics matrix. + If False, only extrinsics are returned and intrinsics will be None. + + Returns: + tuple: (extrinsics, intrinsics) + - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4. + In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world + transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is + a 3x1 translation vector. + - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3, + or None if build_intrinsics is False. Defined in pixels, with format: + [[fx, 0, cx], + [0, fy, cy], + [0, 0, 1]] + where fx, fy are focal lengths and (cx, cy) is the principal point, + assumed to be at the center of the image (W/2, H/2). 
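+
+    Example (round trip with extri_intri_to_pose_encoding; shapes are
+    illustrative, and the recovered intrinsics assume a centered principal point):
+        pose_enc = extri_intri_to_pose_encoding(
+            extrinsics, intrinsics, image_size_hw=(294, 518)
+        )                                              # B x S x 9
+        extri, intri = pose_encoding_to_extri_intri(
+            pose_enc, image_size_hw=(294, 518)
+        )                                              # B x S x 3 x 4, B x S x 3 x 3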
+ """ + + intrinsics = None + + if pose_encoding_type == "absT_quaR_FoV": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + if build_intrinsics: + H, W = image_size_hw + fy = (H / 2.0) / torch.tan(fov_h / 2.0) + fx = (W / 2.0) / torch.tan(fov_w / 2.0) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + elif pose_encoding_type == "absT_quaR": + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + intrinsics = None + + return extrinsics, intrinsics + +def convert_pt3d_RT_to_opencv(Rot, Trans): + """ + Convert Point3D extrinsic matrices to OpenCV convention. + + Args: + Rot: 3D rotation matrix in Point3D format + Trans: 3D translation vector in Point3D format + + Returns: + extri_opencv: 3x4 extrinsic matrix in OpenCV format + """ + rot_pt3d = np.array(Rot) + trans_pt3d = np.array(Trans) + + trans_pt3d[:2] *= -1 + rot_pt3d[:, :2] *= -1 + rot_pt3d = rot_pt3d.transpose(1, 0) + extri_opencv = np.hstack((rot_pt3d, trans_pt3d[:, None])) + return extri_opencv + + +def build_pair_index(N, B=1): + """ + Build indices for all possible pairs of frames. + + Args: + N: Number of frames + B: Batch size + + Returns: + i1, i2: Indices for all possible pairs + """ + i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1) + i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]] + return i1, i2 + + +def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15): + """ + Calculate rotation angle error between ground truth and predicted rotations. + + Args: + rot_gt: Ground truth rotation matrices + rot_pred: Predicted rotation matrices + batch_size: Batch size for reshaping the result + eps: Small value to avoid numerical issues + + Returns: + Rotation angle error in degrees + """ + q_pred = mat_to_quat(rot_pred) + q_gt = mat_to_quat(rot_gt) + + loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps) + err_q = torch.arccos(1 - 2 * loss_q) + + rel_rangle_deg = err_q * 180 / np.pi + + if batch_size is not None: + rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1) + + return rel_rangle_deg + + +def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True): + """ + Calculate translation angle error between ground truth and predicted translations. + + Args: + tvec_gt: Ground truth translation vectors + tvec_pred: Predicted translation vectors + batch_size: Batch size for reshaping the result + ambiguity: Whether to handle direction ambiguity + + Returns: + Translation angle error in degrees + """ + rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred) + rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi + + if ambiguity: + rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs()) + + if batch_size is not None: + rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1) + + return rel_tangle_deg + + +def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6): + """ + Normalize the translation vectors and compute the angle between them. 
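+    The returned error is arccos(|<t_hat, t_gt_hat>|), i.e. the unsigned angle
+    between the two unit directions, so the overall scale of the translations
+    does not affect the result.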
+ + Args: + t_gt: Ground truth translation vectors + t: Predicted translation vectors + eps: Small value to avoid division by zero + default_err: Default error value for invalid cases + + Returns: + Angular error between translation vectors in radians + """ + t_norm = torch.norm(t, dim=1, keepdim=True) + t = t / (t_norm + eps) + + t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True) + t_gt = t_gt / (t_gt_norm + eps) + + loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps) + err_t = torch.acos(torch.sqrt(1 - loss_t)) + + err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err + return err_t + + +def calculate_auc_np(r_error, t_error, max_threshold=30): + """ + Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy. + + Args: + r_error: numpy array representing R error values (Degree) + t_error: numpy array representing T error values (Degree) + max_threshold: Maximum threshold value for binning the histogram + + Returns: + AUC value and the normalized histogram + """ + error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1) + max_errors = np.max(error_matrix, axis=1) + bins = np.arange(max_threshold + 1) + histogram, _ = np.histogram(max_errors, bins=bins) + num_pairs = float(len(max_errors)) + normalized_histogram = histogram.astype(float) / num_pairs + return np.mean(np.cumsum(normalized_histogram)), normalized_histogram + + +def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames): + """ + Compute rotation and translation errors between predicted and ground truth poses. + This function assumes the input poses are world-to-camera (w2c) transformations. + + Args: + pred_se3: Predicted SE(3) transformations (w2c), shape (N, 4, 4) + gt_se3: Ground truth SE(3) transformations (w2c), shape (N, 4, 4) + num_frames: Number of frames (N) + + Returns: + Rotation and translation angle errors in degrees + """ + pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames) + + relative_pose_gt = gt_se3[pair_idx_i1].bmm( + closed_form_inverse_se3(gt_se3[pair_idx_i2]) + ) + relative_pose_pred = pred_se3[pair_idx_i1].bmm( + closed_form_inverse_se3(pred_se3[pair_idx_i2]) + ) + + rel_rangle_deg = rotation_angle( + relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3] + ) + rel_tangle_deg = translation_angle( + relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3] + ) + + return rel_rangle_deg, rel_tangle_deg + + +def colmap_to_opencv_intrinsics(K): + """ + Modify camera intrinsics to follow a different convention. + Coordinates of the center of the top-left pixels are by default: + - (0.5, 0.5) in Colmap + - (0,0) in OpenCV + """ + K = K.copy() + K[..., 0, 2] -= 0.5 + K[..., 1, 2] -= 0.5 + return K + +def read_camera_parameters(filename): + with open(filename) as f: + lines = f.readlines() + lines = [line.rstrip() for line in lines] + # extrinsics: line [1,5), 4x4 matrix + extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) + # intrinsics: line [7-10), 3x3 matrix + intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) + + return intrinsics, extrinsics \ No newline at end of file diff --git a/lingbot_map/utils/rotation.py b/lingbot_map/utils/rotation.py new file mode 100644 index 0000000..f972afd --- /dev/null +++ b/lingbot_map/utils/rotation.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
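These helpers use scalar-last (XYZW) quaternions, whereas the PyTorch3D originals are scalar-first, so a quick convention check can save confusion. A minimal round-trip sketch, assuming the quat_to_mat / mat_to_quat helpers defined below in this file:

    import torch
    from lingbot_map.utils.rotation import quat_to_mat, mat_to_quat

    q = torch.tensor([[0.0, 0.0, 0.0, 1.0]])   # identity rotation in XYZW (scalar-last) order
    R = quat_to_mat(q)                          # -> (1, 3, 3) identity rotation matrix
    assert torch.allclose(mat_to_quat(R), q)    # round-trips back to [0, 0, 0, 1]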
+ +# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d + +import torch +import numpy as np +import torch.nn.functional as F + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1) + + q_abs = _sqrt_positive_part( + torch.stack( + [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1 + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,)) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. 
+ """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) diff --git a/lingbot_map/vis/__init__.py b/lingbot_map/vis/__init__.py new file mode 100644 index 0000000..c0422e6 --- /dev/null +++ b/lingbot_map/vis/__init__.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +GCT Visualization Module + +This module provides visualization utilities for 3D reconstruction results: +- PointCloudViewer: Interactive point cloud viewer with camera visualization +- viser_wrapper: Quick visualization wrapper for predictions +- predictions_to_glb: Export predictions to GLB 3D format +- Colorization and utility functions + +Usage: + from lingbot_map.vis import PointCloudViewer, viser_wrapper, predictions_to_glb + + # Interactive visualization + viewer = PointCloudViewer(pred_dict=predictions, port=8080) + viewer.run() + + # Quick visualization + viser_wrapper(predictions, port=8080) + + # Export to GLB + scene = predictions_to_glb(predictions) + scene.export("output.glb") +""" + +from lingbot_map.vis.point_cloud_viewer import PointCloudViewer +from lingbot_map.vis.viser_wrapper import viser_wrapper +from lingbot_map.vis.utils import CameraState, colorize, colorize_np, get_vertical_colorbar +from lingbot_map.vis.sky_segmentation import ( + apply_sky_segmentation, + download_skyseg_model, + load_or_create_sky_masks, + segment_sky, +) +from lingbot_map.vis.glb_export import predictions_to_glb + +__all__ = [ + # Main viewer + "PointCloudViewer", + # Quick visualization + "viser_wrapper", + # GLB export + "predictions_to_glb", + # Utilities + "CameraState", + "colorize", + "colorize_np", + "get_vertical_colorbar", + # Sky segmentation + "apply_sky_segmentation", + "segment_sky", + "download_skyseg_model", + "load_or_create_sky_masks", +] diff --git a/lingbot_map/vis/glb_export.py b/lingbot_map/vis/glb_export.py new file mode 100644 index 0000000..b2ccd74 --- /dev/null +++ b/lingbot_map/vis/glb_export.py @@ -0,0 +1,509 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +GLB 3D export utilities for GCT predictions. +""" + +import os +import copy +from typing import Optional, Tuple + +import numpy as np +import cv2 +import matplotlib +from scipy.spatial.transform import Rotation + +from lingbot_map.vis.sky_segmentation import ( + _SKYSEG_INPUT_SIZE, + _SKYSEG_SOFT_THRESHOLD, + _mask_to_float, + _mask_to_uint8, + _result_map_to_non_sky_conf, +) + +try: + import trimesh +except ImportError: + trimesh = None + print("trimesh not found. 
GLB export will not work.") + + +def predictions_to_glb( + predictions: dict, + conf_thres: float = 50.0, + filter_by_frames: str = "all", + mask_black_bg: bool = False, + mask_white_bg: bool = False, + show_cam: bool = True, + mask_sky: bool = False, + target_dir: Optional[str] = None, + prediction_mode: str = "Predicted Pointmap", +) -> "trimesh.Scene": + """ + Converts GCT predictions to a 3D scene represented as a GLB file. + + Args: + predictions: Dictionary containing model predictions with keys: + - world_points: 3D point coordinates (S, H, W, 3) + - world_points_conf: Confidence scores (S, H, W) + - images: Input images (S, H, W, 3) or (S, 3, H, W) + - extrinsic: Camera extrinsic matrices (S, 3, 4) + conf_thres: Percentage of low-confidence points to filter out + filter_by_frames: Frame filter specification ("all" or frame index) + mask_black_bg: Mask out black background pixels + mask_white_bg: Mask out white background pixels + show_cam: Include camera visualization + mask_sky: Apply sky segmentation mask + target_dir: Output directory for intermediate files + prediction_mode: "Predicted Pointmap" or "Predicted Depthmap" + + Returns: + trimesh.Scene: Processed 3D scene containing point cloud and cameras + + Raises: + ValueError: If input predictions structure is invalid + ImportError: If trimesh is not available + """ + if trimesh is None: + raise ImportError("trimesh is required for GLB export. Install with: pip install trimesh") + + if not isinstance(predictions, dict): + raise ValueError("predictions must be a dictionary") + + if conf_thres is None: + conf_thres = 10.0 + + print("Building GLB scene") + + # Parse frame filter + selected_frame_idx = None + if filter_by_frames != "all" and filter_by_frames != "All": + try: + selected_frame_idx = int(filter_by_frames.split(":")[0]) + except (ValueError, IndexError): + pass + + # Select prediction source + if "Pointmap" in prediction_mode: + print("Using Pointmap Branch") + if "world_points" in predictions: + pred_world_points = predictions["world_points"] + pred_world_points_conf = predictions.get( + "world_points_conf", np.ones_like(pred_world_points[..., 0]) + ) + else: + print("Warning: world_points not found, falling back to depth-based points") + pred_world_points = predictions["world_points_from_depth"] + pred_world_points_conf = predictions.get( + "depth_conf", np.ones_like(pred_world_points[..., 0]) + ) + else: + print("Using Depthmap and Camera Branch") + pred_world_points = predictions["world_points_from_depth"] + pred_world_points_conf = predictions.get( + "depth_conf", np.ones_like(pred_world_points[..., 0]) + ) + + images = predictions["images"] + camera_matrices = predictions["extrinsic"] + + # Apply sky segmentation if enabled + if mask_sky and target_dir is not None: + pred_world_points_conf = _apply_sky_mask( + pred_world_points_conf, target_dir, images + ) + + # Apply frame filter + if selected_frame_idx is not None: + pred_world_points = pred_world_points[selected_frame_idx][None] + pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None] + images = images[selected_frame_idx][None] + camera_matrices = camera_matrices[selected_frame_idx][None] + + # Prepare vertices and colors + vertices_3d = pred_world_points.reshape(-1, 3) + + # Handle different image formats + if images.ndim == 4 and images.shape[1] == 3: # NCHW format + colors_rgb = np.transpose(images, (0, 2, 3, 1)) + else: + colors_rgb = images + colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8) + + # Apply confidence 
filtering + conf = pred_world_points_conf.reshape(-1) + conf_threshold = np.percentile(conf, conf_thres) if conf_thres > 0 else 0.0 + conf_mask = (conf >= conf_threshold) & (conf > 1e-5) + + # Apply background masking + if mask_black_bg: + black_bg_mask = colors_rgb.sum(axis=1) >= 16 + conf_mask = conf_mask & black_bg_mask + + if mask_white_bg: + white_bg_mask = ~( + (colors_rgb[:, 0] > 240) & + (colors_rgb[:, 1] > 240) & + (colors_rgb[:, 2] > 240) + ) + conf_mask = conf_mask & white_bg_mask + + vertices_3d = vertices_3d[conf_mask] + colors_rgb = colors_rgb[conf_mask] + + # Handle empty point cloud + if vertices_3d is None or np.asarray(vertices_3d).size == 0: + vertices_3d = np.array([[1, 0, 0]]) + colors_rgb = np.array([[255, 255, 255]]) + scene_scale = 1 + else: + lower_percentile = np.percentile(vertices_3d, 5, axis=0) + upper_percentile = np.percentile(vertices_3d, 95, axis=0) + scene_scale = np.linalg.norm(upper_percentile - lower_percentile) + + colormap = matplotlib.colormaps.get_cmap("gist_rainbow") + + # Build scene + scene_3d = trimesh.Scene() + point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) + scene_3d.add_geometry(point_cloud_data) + + # Prepare camera matrices + num_cameras = len(camera_matrices) + extrinsics_matrices = np.zeros((num_cameras, 4, 4)) + extrinsics_matrices[:, :3, :4] = camera_matrices + extrinsics_matrices[:, 3, 3] = 1 + + # Add cameras + if show_cam: + for i in range(num_cameras): + world_to_camera = extrinsics_matrices[i] + camera_to_world = np.linalg.inv(world_to_camera) + rgba_color = colormap(i / num_cameras) + current_color = tuple(int(255 * x) for x in rgba_color[:3]) + integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale) + + # Align scene + scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices) + + print("GLB Scene built") + return scene_3d + + +def _apply_sky_mask( + conf: np.ndarray, + target_dir: str, + images: np.ndarray +) -> np.ndarray: + """Apply sky segmentation mask to confidence scores.""" + try: + import onnxruntime + except ImportError: + print("Warning: onnxruntime not available, skipping sky masking") + return conf + + target_dir_images = os.path.join(target_dir, "images") + if not os.path.exists(target_dir_images): + print(f"Warning: Images directory not found at {target_dir_images}") + return conf + + image_list = sorted(os.listdir(target_dir_images)) + S, H, W = conf.shape if hasattr(conf, "shape") else (len(images), images.shape[1], images.shape[2]) + + skyseg_model_path = "skyseg.onnx" + if not os.path.exists(skyseg_model_path): + print("Downloading skyseg.onnx...") + download_file_from_url( + "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", + skyseg_model_path + ) + + skyseg_session = onnxruntime.InferenceSession(skyseg_model_path) + sky_mask_list = [] + + for i, image_name in enumerate(image_list[:S]): + image_filepath = os.path.join(target_dir_images, image_name) + mask_filepath = os.path.join(target_dir, "sky_masks", image_name) + + if os.path.exists(mask_filepath): + sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) + else: + sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath) + + if sky_mask.shape[0] != H or sky_mask.shape[1] != W: + sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_LINEAR) + + sky_mask_list.append(_mask_to_float(sky_mask)) + + sky_mask_array = np.array(sky_mask_list) + sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32) + return conf * sky_mask_binary + 
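For orientation, here is a minimal usage sketch of predictions_to_glb. The shapes follow the docstring above; the random arrays, the identity extrinsics, and the output filename are illustrative only, and trimesh must be installed:

    import numpy as np
    from lingbot_map.vis.glb_export import predictions_to_glb

    S, H, W = 2, 8, 8  # two tiny synthetic frames
    predictions = {
        "world_points": np.random.rand(S, H, W, 3).astype(np.float32),
        "world_points_conf": np.ones((S, H, W), dtype=np.float32),
        "images": np.random.rand(S, 3, H, W).astype(np.float32),          # NCHW, values in [0, 1]
        "extrinsic": np.tile(np.eye(3, 4, dtype=np.float32), (S, 1, 1)),  # camera-from-world [R|t]
    }

    # conf_thres is a percentile: 50.0 drops the lowest 50% of points by confidence.
    scene = predictions_to_glb(predictions, conf_thres=50.0, show_cam=True)
    scene.export("demo_scene.glb")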
+ +def integrate_camera_into_scene( + scene: "trimesh.Scene", + transform: np.ndarray, + face_colors: Tuple[int, int, int], + scene_scale: float, + frustum_thickness: float = 1.0, +): + """ + Integrates a camera mesh into the 3D scene. + + Args: + scene: The 3D scene to add the camera model + transform: Transformation matrix for camera positioning + face_colors: RGB color tuple for the camera + scene_scale: Scale of the scene + frustum_thickness: Multiplier for frustum edge thickness (>1 = thicker) + """ + cam_width = scene_scale * 0.05 + cam_height = scene_scale * 0.1 + + rot_45_degree = np.eye(4) + rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix() + rot_45_degree[2, 3] = -cam_height + + opengl_transform = get_opengl_conversion_matrix() + complete_transform = transform @ opengl_transform @ rot_45_degree + camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4) + + # Build thicker frustum by stacking rotated copies + slight_rotation = np.eye(4) + slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix() + + shell_scales = [1.0, 0.95] + shell_transforms = [np.eye(4), slight_rotation] + # Add extra shells for thickness + if frustum_thickness > 1.0: + n_extra = max(1, int(frustum_thickness - 1)) + for k in range(1, n_extra + 1): + # Progressively rotated and scaled copies + angle = 2.0 + k * 2.0 + scale = 1.0 + k * 0.02 + rot = np.eye(4) + rot[:3, :3] = Rotation.from_euler("z", angle, degrees=True).as_matrix() + shell_scales.append(scale) + shell_transforms.append(rot) + rot_neg = np.eye(4) + rot_neg[:3, :3] = Rotation.from_euler("z", -angle, degrees=True).as_matrix() + shell_scales.append(scale) + shell_transforms.append(rot_neg) + + vertices_parts = [] + for s, t_mat in zip(shell_scales, shell_transforms): + vertices_parts.append( + transform_points(t_mat, s * camera_cone_shape.vertices) + ) + vertices_combined = np.concatenate(vertices_parts) + vertices_transformed = transform_points(complete_transform, vertices_combined) + + mesh_faces = compute_camera_faces_multi(camera_cone_shape, len(shell_scales)) + camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces) + camera_mesh.visual.face_colors[:, :3] = face_colors + scene.add_geometry(camera_mesh) + + +def apply_scene_alignment( + scene_3d: "trimesh.Scene", + extrinsics_matrices: np.ndarray +) -> "trimesh.Scene": + """ + Aligns the 3D scene based on the extrinsics of the first camera. + + Args: + scene_3d: The 3D scene to be aligned + extrinsics_matrices: Camera extrinsic matrices + + Returns: + Aligned 3D scene + """ + opengl_conversion_matrix = get_opengl_conversion_matrix() + + align_rotation = np.eye(4) + align_rotation[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix() + + initial_transformation = ( + np.linalg.inv(extrinsics_matrices[0]) @ opengl_conversion_matrix @ align_rotation + ) + scene_3d.apply_transform(initial_transformation) + return scene_3d + + +def get_opengl_conversion_matrix() -> np.ndarray: + """Returns the OpenGL conversion matrix (flips Y and Z axes).""" + matrix = np.identity(4) + matrix[1, 1] = -1 + matrix[2, 2] = -1 + return matrix + + +def transform_points( + transformation: np.ndarray, + points: np.ndarray, + dim: Optional[int] = None +) -> np.ndarray: + """ + Applies a 4x4 transformation to a set of points. 
+ + Args: + transformation: Transformation matrix + points: Points to be transformed + dim: Dimension for reshaping the result + + Returns: + Transformed points + """ + points = np.asarray(points) + initial_shape = points.shape[:-1] + dim = dim or points.shape[-1] + + transformation = transformation.swapaxes(-1, -2) + points = points @ transformation[..., :-1, :] + transformation[..., -1:, :] + + return points[..., :dim].reshape(*initial_shape, dim) + + +def compute_camera_faces(cone_shape: "trimesh.Trimesh") -> np.ndarray: + """Computes the faces for the camera mesh.""" + faces_list = [] + num_vertices_cone = len(cone_shape.vertices) + + for face in cone_shape.faces: + if 0 in face: + continue + v1, v2, v3 = face + v1_offset, v2_offset, v3_offset = face + num_vertices_cone + v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone + + faces_list.extend([ + (v1, v2, v2_offset), + (v1, v1_offset, v3), + (v3_offset, v2, v3), + (v1, v2, v2_offset_2), + (v1, v1_offset_2, v3), + (v3_offset_2, v2, v3), + ]) + + faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list] + return np.array(faces_list) + + +def compute_camera_faces_multi(cone_shape: "trimesh.Trimesh", num_shells: int) -> np.ndarray: + """Computes faces for a camera mesh with multiple shells (for thicker frustums). + + Connects each consecutive pair of vertex shells to form the frustum edges. + """ + faces_list = [] + nv = len(cone_shape.vertices) + + for s in range(num_shells - 1): + off_a = s * nv + off_b = (s + 1) * nv + for face in cone_shape.faces: + if 0 in face: + continue + v1, v2, v3 = face + faces_list.extend([ + (v1 + off_a, v2 + off_a, v2 + off_b), + (v1 + off_a, v1 + off_b, v3 + off_a), + (v3 + off_b, v2 + off_a, v3 + off_a), + ]) + + faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list] + return np.array(faces_list) + + +def segment_sky( + image_path: str, + onnx_session, + mask_filename: str +) -> np.ndarray: + """ + Segments sky from an image using an ONNX model. + + Args: + image_path: Path to input image + onnx_session: ONNX runtime session with loaded model + mask_filename: Path to save the output mask + + Returns: + Continuous non-sky confidence map in [0, 1] + """ + image = cv2.imread(image_path) + result_map = run_skyseg(onnx_session, _SKYSEG_INPUT_SIZE, image) + result_map_original = cv2.resize( + result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR + ) + output_mask = _result_map_to_non_sky_conf(result_map_original) + + os.makedirs(os.path.dirname(mask_filename), exist_ok=True) + cv2.imwrite(mask_filename, _mask_to_uint8(output_mask)) + return output_mask + + +def run_skyseg( + onnx_session, + input_size: Tuple[int, int], + image: np.ndarray +) -> np.ndarray: + """ + Runs sky segmentation inference using ONNX model. 
+ + Args: + onnx_session: ONNX runtime session + input_size: Target size for model input (width, height) + image: Input image in BGR format + + Returns: + Segmentation mask + """ + temp_image = copy.deepcopy(image) + resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1])) + x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) + x = np.array(x, dtype=np.float32) + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + x = (x / 255 - mean) / std + x = x.transpose(2, 0, 1) + x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32") + + input_name = onnx_session.get_inputs()[0].name + output_name = onnx_session.get_outputs()[0].name + onnx_result = onnx_session.run([output_name], {input_name: x}) + + onnx_result = np.array(onnx_result).squeeze() + min_value = np.min(onnx_result) + max_value = np.max(onnx_result) + onnx_result = (onnx_result - min_value) / (max_value - min_value) + onnx_result *= 255 + return onnx_result.astype("uint8") + + +def download_file_from_url(url: str, filename: str): + """Downloads a file from a URL, handling redirects.""" + import requests + + try: + response = requests.get(url, allow_redirects=False) + response.raise_for_status() + + if response.status_code == 302: + redirect_url = response.headers["Location"] + response = requests.get(redirect_url, stream=True) + response.raise_for_status() + else: + print(f"Unexpected status code: {response.status_code}") + return + + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + print(f"Downloaded {filename} successfully.") + + except requests.exceptions.RequestException as e: + print(f"Error downloading file: {e}") diff --git a/lingbot_map/vis/point_cloud_viewer.py b/lingbot_map/vis/point_cloud_viewer.py new file mode 100644 index 0000000..b90f281 --- /dev/null +++ b/lingbot_map/vis/point_cloud_viewer.py @@ -0,0 +1,1780 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Interactive 3D Point Cloud Viewer using Viser. + +This module provides the PointCloudViewer class for visualizing 3D reconstruction results, +including point clouds, camera poses, and animated playback. +""" + +import os +import time +import threading +import subprocess +import tempfile +import shutil +from typing import List, Optional, Dict, Any, Tuple + +import numpy as np +import torch +import cv2 +import matplotlib.cm as cm +from tqdm.auto import tqdm + +import viser +import viser.transforms as tf + +from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map +from lingbot_map.vis.utils import CameraState +from lingbot_map.vis.sky_segmentation import apply_sky_segmentation + + +class PointCloudViewer: + """ + Interactive 3D point cloud viewer with camera visualization. 
+ + Features: + - Point cloud visualization with confidence-based filtering + - Camera frustum visualization with gradient colors + - Frame-by-frame playback animation (3D/4D modes) + - Range-based and recent-N-frames visualization modes + - Navigation mode (camera following) + - Video export with FFmpeg + + Args: + model: Optional model for interactive inference + state_args: Optional state arguments + pc_list: List of point clouds per frame + color_list: List of colors per frame + conf_list: List of confidence scores per frame + cam_dict: Camera dictionary with focal, pp, R, t + image_mask: Optional image mask + edge_color_list: Optional edge colors + device: Device for computation + port: Viser server port + show_camera: Whether to show camera frustums + vis_threshold: Visibility threshold for filtering + size: Image size + downsample_factor: Point cloud downsample factor + point_size: Initial point size + pred_dict: Prediction dictionary (alternative to pc_list/color_list/conf_list) + init_conf_threshold: Initial confidence threshold percentage + use_point_map: Use point map instead of depth-based points + mask_sky: Apply sky segmentation + image_folder: Path to image folder (for sky segmentation) + """ + + def __init__( + self, + model=None, + state_args=None, + pc_list=None, + color_list=None, + conf_list=None, + cam_dict=None, + image_mask=None, + edge_color_list=None, + device: str = "cpu", + port: int = 8080, + show_camera: bool = True, + vis_threshold: float = 1.0, + size: int = 512, + downsample_factor: int = 10, + point_size: float = 0.00001, + pred_dict: Optional[Dict] = None, + init_conf_threshold: float = 50.0, + use_point_map: bool = False, + mask_sky: bool = False, + image_folder: Optional[str] = None, + depth_stride: int = 1, + ): + self.model = model + self.size = size + self.state_args = state_args + self.server = viser.ViserServer(host="0.0.0.0", port=port) + self.server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") + self.device = device + self.conf_list = conf_list + self.vis_threshold = vis_threshold + self.point_size = point_size + self.tt = lambda x: torch.from_numpy(x).float().to(device) + + # Process the prediction dictionary to create pc_list, color_list, conf_list + if pred_dict is not None: + pc_list, color_list, conf_list, cam_dict = self._process_pred_dict( + pred_dict, use_point_map, mask_sky, image_folder, + depth_stride=depth_stride, + ) + else: + self.original_images = [] + self.tsdf_depth_maps = None + self.tsdf_extrinsics = None + self.tsdf_intrinsics = None + self.tsdf_images = None + + self.pcs, self.all_steps = self.read_data( + pc_list, color_list, conf_list, edge_color_list + ) + self.cam_dict = cam_dict + self.num_frames = len(self.all_steps) + self.image_mask = image_mask + self.show_camera = show_camera + self.on_replay = False + self.vis_pts_list = [] + self.traj_list = [] + self.orig_img_list = [x[0] for x in color_list if len(x) > 0] if color_list else [] + self.via_points = [] + + self._setup_gui() + self.server.on_client_connect(self._connect_client) + + def _process_pred_dict( + self, + pred_dict: Dict, + use_point_map: bool, + mask_sky: bool, + image_folder: Optional[str], + depth_stride: int = 1, + ) -> Tuple[List, List, List, Dict]: + """Process prediction dictionary to extract visualization data. + + Args: + pred_dict: Model prediction dictionary. + use_point_map: Use point map instead of depth-based projection. + mask_sky: Apply sky segmentation to filter sky points. 
+ image_folder: Path to images for sky segmentation. + depth_stride: Only project depth to point cloud every N frames. + Frames not projected will have empty point clouds but still + show camera frustums and images. 1 = every frame (default). + """ + images = pred_dict["images"] # (S, 3, H, W) + + depth_map = pred_dict.get("depth") # (S, H, W, 1) + depth_conf = pred_dict.get("depth_conf") # (S, H, W) + + extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4) + intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3) + + # Compute world points from depth if not using the precomputed point map + if not use_point_map: + world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam) + conf = depth_conf + else: + world_points = pred_dict["world_points"] # (S, H, W, 3) + conf = pred_dict.get("world_points_conf", depth_conf) # (S, H, W) + + # Apply sky segmentation if enabled + if mask_sky: + conf = apply_sky_segmentation(conf, image_folder=image_folder, images=images) + + # Convert images from (S, 3, H, W) to (S, H, W, 3) + colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3) + S = world_points.shape[0] + + # Store raw data for TSDF fusion + self.tsdf_depth_maps = depth_map # (S, H, W, 1) + self.tsdf_extrinsics = extrinsics_cam # (S, 3, 4) camera-from-world + self.tsdf_intrinsics = intrinsics_cam # (S, 3, 3) + self.tsdf_images = images # (S, 3, H, W) + + # Store original images for camera frustum display + self.original_images = [] + for i in range(S): + img = images[i] # shape (3, H, W) + img = (img.transpose(1, 2, 0) * 255).astype(np.uint8) + self.original_images.append(img) + + # Create lists - apply depth_stride to skip frames for point projection + H, W = world_points.shape[1], world_points.shape[2] + pc_list = [] + color_list = [] + conf_list = [] + skipped = 0 + for i in range(S): + if depth_stride > 1 and i % depth_stride != 0: + # Empty point cloud for skipped frames + pc_list.append(np.zeros((0, 0, 3), dtype=np.float32)) + color_list.append(np.zeros((0, 0, 3), dtype=np.float32)) + conf_list.append(np.zeros((0, 0), dtype=np.float32)) + skipped += 1 + else: + pc_list.append(world_points[i]) + color_list.append(colors[i]) + if conf is not None: + conf_list.append(conf[i]) + else: + conf_list.append(np.ones(world_points[i].shape[:2], dtype=np.float32)) + + if depth_stride > 1: + print(f' depth_stride={depth_stride}: projecting {S - skipped}/{S} frames, skipping {skipped}') + + # Create camera dictionary (all frames keep cameras) + cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam) + cam_dict = { + "focal": [intrinsics_cam[i, 0, 0] for i in range(S)], + "pp": [(intrinsics_cam[i, 0, 2], intrinsics_cam[i, 1, 2]) for i in range(S)], + "R": [cam_to_world_mat[i, :3, :3] for i in range(S)], + "t": [cam_to_world_mat[i, :3, 3] for i in range(S)], + } + + return pc_list, color_list, conf_list, cam_dict + + def _compute_scene_center_and_scale(self) -> Tuple[np.ndarray, float]: + """Compute scene center and scale from camera positions and point clouds. + + Returns: + Tuple of (center as 3D array, scale as float distance). 
+ """ + # Use camera positions as primary reference (more reliable than noisy points) + if self.cam_dict is not None and "t" in self.cam_dict: + cam_positions = np.array([self.cam_dict["t"][s] for s in self.all_steps]) + center = np.mean(cam_positions, axis=0) + if len(cam_positions) > 1: + extent = np.ptp(cam_positions, axis=0) # range per axis + scale = np.linalg.norm(extent) + else: + scale = 1.0 + else: + # Fallback: use point cloud data + all_pts = [] + for step in self.all_steps: + pc = self.pcs[step]["pc"].reshape(-1, 3) + # subsample for speed + if len(pc) > 1000: + pc = pc[::len(pc) // 1000] + all_pts.append(pc) + all_pts = np.concatenate(all_pts, axis=0) + center = np.median(all_pts, axis=0) + extent = np.percentile(all_pts, 95, axis=0) - np.percentile(all_pts, 5, axis=0) + scale = np.linalg.norm(extent) + + return center, max(scale, 0.1) + + def _reset_view_to_direction( + self, + direction: np.ndarray, + up: np.ndarray = np.array([0.0, -1.0, 0.0]), + distance_scale: float = 1.5, + smooth: bool = True, + ): + """Reset the viewer camera to look at scene center from a given direction. + + Args: + direction: Unit vector pointing FROM the scene center TO the camera. + up: Up vector for the camera. + distance_scale: Multiplier on scene scale for camera distance. + smooth: Whether to smoothly transition. + """ + center, scale = self._compute_scene_center_and_scale() + distance = scale * distance_scale + position = center + direction * distance + + for client in self.server.get_clients().values(): + if smooth: + self._smooth_camera_transition( + client, + target_position=position, + target_look_at=center, + target_up=up, + duration=0.4, + ) + else: + client.camera.up_direction = tuple(up) + client.camera.position = tuple(position) + client.camera.look_at = tuple(center) + + def _setup_gui(self): + """Setup GUI controls.""" + gui_reset_up = self.server.gui.add_button( + "Reset up direction", + hint="Set the camera control 'up' direction to the current camera's 'up'.", + ) + + @gui_reset_up.on_click + def _(event: viser.GuiEvent) -> None: + client = event.client + assert client is not None + client.camera.up_direction = tf.SO3(client.camera.wxyz) @ np.array( + [0.0, -1.0, 0.0] + ) + + # Preset view direction buttons + with self.server.gui.add_folder("Reset View Direction"): + btn_look_at_center = self.server.gui.add_button( + "Look At Scene Center", + hint="Reset orbit center to the scene center (fixes orbit after dragging).", + ) + btn_overview = self.server.gui.add_button( + "Overview", + hint="Reset to a 3/4 overview of the scene.", + ) + btn_front = self.server.gui.add_button( + "Front (+Z)", + hint="View scene from the front.", + ) + btn_back = self.server.gui.add_button( + "Back (-Z)", + hint="View scene from the back.", + ) + btn_top = self.server.gui.add_button( + "Top (-Y)", + hint="View scene from above (bird's eye).", + ) + btn_left = self.server.gui.add_button( + "Left (-X)", + hint="View scene from the left.", + ) + btn_right = self.server.gui.add_button( + "Right (+X)", + hint="View scene from the right.", + ) + btn_first_cam = self.server.gui.add_button( + "First Camera", + hint="Reset to the first camera's viewpoint.", + ) + + @btn_look_at_center.on_click + def _(_) -> None: + center, _ = self._compute_scene_center_and_scale() + for client in self.server.get_clients().values(): + client.camera.look_at = tuple(center) + + @btn_overview.on_click + def _(_) -> None: + d = np.array([0.5, -0.6, 0.6]) + self._reset_view_to_direction(d / np.linalg.norm(d)) + + 
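Each preset below picks a unit direction and delegates to _reset_view_to_direction, which places the viewer camera at center + direction * scale * distance_scale and aims it back at the scene center. A small numeric sketch with made-up values:

    import numpy as np

    center, scale = np.zeros(3), 4.0               # as returned by _compute_scene_center_and_scale
    direction = np.array([0.0, 0.0, 1.0])          # the "Front (+Z)" preset
    position = center + direction * scale * 1.5    # default distance_scale=1.5 -> camera at [0, 0, 6]
    look_at = center                               # camera looks back at the scene center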
@btn_front.on_click + def _(_) -> None: + self._reset_view_to_direction(np.array([0.0, 0.0, 1.0])) + + @btn_back.on_click + def _(_) -> None: + self._reset_view_to_direction(np.array([0.0, 0.0, -1.0])) + + @btn_top.on_click + def _(_) -> None: + self._reset_view_to_direction( + np.array([0.0, -1.0, 0.0]), + up=np.array([0.0, 0.0, 1.0]), + ) + + @btn_left.on_click + def _(_) -> None: + self._reset_view_to_direction(np.array([-1.0, 0.0, 0.0])) + + @btn_right.on_click + def _(_) -> None: + self._reset_view_to_direction(np.array([1.0, 0.0, 0.0])) + + @btn_first_cam.on_click + def _(_) -> None: + self._move_to_camera(0, smooth=True) + + button3 = self.server.gui.add_button("4D (Only Show Current Frame)") + button4 = self.server.gui.add_button("3D (Show All Frames)") + self.is_render = False + self.fourd = False + + @button3.on_click + def _(event: viser.GuiEvent) -> None: + self.fourd = True + + @button4.on_click + def _(event: viser.GuiEvent) -> None: + self.fourd = False + + self.focal_slider = self.server.gui.add_slider( + "Focal Length", min=0.1, max=99999, step=1, initial_value=533 + ) + self.psize_slider = self.server.gui.add_slider( + "Point Size", min=0.00001, max=0.1, step=0.00001, initial_value=self.point_size + ) + self.camsize_slider = self.server.gui.add_slider( + "Camera Size", min=0.01, max=0.5, step=0.01, initial_value=0.1 + ) + self.downsample_slider = self.server.gui.add_slider( + "Downsample Factor", min=1, max=1000, step=1, initial_value=10 + ) + self.show_camera_checkbox = self.server.gui.add_checkbox( + "Show Camera", initial_value=self.show_camera + ) + self.vis_threshold_slider = self.server.gui.add_slider( + "Visibility Threshold", min=0.1, max=30.0, step=0.1, initial_value=self.vis_threshold + ) + self.camera_downsample_slider = self.server.gui.add_slider( + "Camera Downsample Factor", min=1, max=50, step=1, initial_value=1 + ) + + # Point cloud filtering controls + with self.server.gui.add_folder("Point Cloud Filtering"): + self.conf_percentile_slider = self.server.gui.add_slider( + "Confidence Percentile (%)", + min=0, max=95, step=1, initial_value=0, + hint="Remove the lowest N% of points by confidence. 0 = disabled.", + ) + self.bbox_clip_slider = self.server.gui.add_slider( + "Bounding Box Keep (%)", + min=50.0, max=100.0, step=0.5, initial_value=100.0, + hint="Keep the central N% of points per axis. 100 = no clipping.", + ) + self.sor_checkbox = self.server.gui.add_checkbox( + "Statistical Outlier Removal", + initial_value=False, + hint="Remove isolated floating points based on KNN distance.", + ) + self.sor_neighbors_slider = self.server.gui.add_slider( + "SOR Neighbors (K)", + min=5, max=50, step=1, initial_value=20, disabled=True, + hint="Number of nearest neighbors for outlier detection.", + ) + self.sor_std_slider = self.server.gui.add_slider( + "SOR Std Ratio", + min=0.5, max=5.0, step=0.1, initial_value=2.0, disabled=True, + hint="Lower = more aggressive filtering. 
Points beyond mean + ratio*std are removed.", + ) + self.filter_apply_button = self.server.gui.add_button( + "Apply Filters", + hint="Regenerate point clouds with current filter settings.", + ) + + @self.sor_checkbox.on_update + def _(_) -> None: + self.sor_neighbors_slider.disabled = not self.sor_checkbox.value + self.sor_std_slider.disabled = not self.sor_checkbox.value + + @self.filter_apply_button.on_click + def _(_) -> None: + self._regenerate_point_clouds() + + # TSDF Fusion controls + with self.server.gui.add_folder("TSDF Fusion"): + self.tsdf_voxel_size_slider = self.server.gui.add_slider( + "Voxel Size", min=0.001, max=0.1, step=0.001, initial_value=0.01, + hint="TSDF voxel size. Smaller = finer detail but slower.", + ) + self.tsdf_sdf_trunc_slider = self.server.gui.add_slider( + "SDF Truncation", min=0.01, max=0.5, step=0.01, initial_value=0.04, + hint="Truncation distance. Typically 3-5x voxel size.", + ) + self.tsdf_depth_scale_slider = self.server.gui.add_slider( + "Depth Scale", min=1.0, max=10000.0, step=1.0, initial_value=1.0, + hint="Depth scale factor. 1.0 if depth is in meters.", + ) + self.tsdf_depth_trunc_slider = self.server.gui.add_slider( + "Depth Truncation", min=0.5, max=50.0, step=0.5, initial_value=5.0, + hint="Max depth value to integrate (meters).", + ) + self.tsdf_run_button = self.server.gui.add_button( + "Run TSDF Fusion", + hint="Fuse all frames into a single point cloud via TSDF.", + ) + self.tsdf_clear_button = self.server.gui.add_button( + "Clear TSDF Result", + hint="Remove the TSDF fused point cloud from the scene.", + ) + self.tsdf_status = self.server.gui.add_text( + "Status", initial_value="Ready", + ) + + self._tsdf_handle = None + + @self.tsdf_run_button.on_click + def _(_) -> None: + self._run_tsdf_fusion() + + @self.tsdf_clear_button.on_click + def _(_) -> None: + if self._tsdf_handle is not None: + try: + self._tsdf_handle.remove() + except (KeyError, AttributeError): + pass + self._tsdf_handle = None + self.tsdf_status.value = "Cleared" + + # Range visualization controls + with self.server.gui.add_folder("Frame Range Control"): + self.range_mode_checkbox = self.server.gui.add_checkbox("Range Mode", initial_value=False) + self.range_start_slider = self.server.gui.add_slider( + "Start Frame", min=0, max=len(self.all_steps) - 1, step=1, initial_value=0, disabled=True + ) + self.range_end_slider = self.server.gui.add_slider( + "End Frame", min=0, max=len(self.all_steps) - 1, step=1, + initial_value=len(self.all_steps) - 1, disabled=True + ) + self.recent_n_mode_checkbox = self.server.gui.add_checkbox("Recent N Frames Mode", initial_value=False) + self.recent_n_slider = self.server.gui.add_slider( + "Recent N Frames", min=1, max=len(self.all_steps), step=1, + initial_value=min(10, len(self.all_steps)), disabled=True + ) + + # Navigation mode controls + with self.server.gui.add_folder("Navigation Mode"): + self.navigation_mode_checkbox = self.server.gui.add_checkbox("Follow Camera", initial_value=False) + self.smooth_navigation_checkbox = self.server.gui.add_checkbox("Smooth Transition", initial_value=True) + self.navigation_offset_slider = self.server.gui.add_slider( + "Camera Offset", min=0.0, max=2.0, step=0.05, initial_value=0.5 + ) + self.navigation_fov_checkbox = self.server.gui.add_checkbox("Match FOV", initial_value=True) + self.go_to_camera_button = self.server.gui.add_button("Go to Current Camera") + + @self.go_to_camera_button.on_click + def _(_) -> None: + if hasattr(self, 'gui_timestep'): + self._move_to_camera(self.gui_timestep.value, 
smooth=self.smooth_navigation_checkbox.value) + + # Video frame display controls + with self.server.gui.add_folder("Video Display"): + self.show_video_checkbox = self.server.gui.add_checkbox("Show Current Frame", initial_value=True) + if hasattr(self, 'original_images') and len(self.original_images) > 0: + self.current_frame_image = self.server.gui.add_image( + self.original_images[0], label="Current Frame" + ) + else: + self.current_frame_image = None + + # Screenshot controls + with self.server.gui.add_folder("Screenshot"): + self.screenshot_button = self.server.gui.add_button("Take Screenshot") + self.screenshot_resolution = self.server.gui.add_dropdown( + "Resolution", + options=["1920x1080", "2560x1440", "3840x2160", "Current"], + initial_value="1920x1080", + ) + self.screenshot_path = self.server.gui.add_text( + "Save Path", initial_value="screenshot.png" + ) + self.screenshot_status = self.server.gui.add_text( + "Status", initial_value="Ready" + ) + + @self.screenshot_button.on_click + def _(event: viser.GuiEvent) -> None: + self._take_screenshot(event.client) + + # GLB export controls + with self.server.gui.add_folder("Export GLB"): + self.glb_output_path = self.server.gui.add_text( + "Output Path", initial_value="export.glb" + ) + self.glb_show_cam_checkbox = self.server.gui.add_checkbox( + "Include Cameras", initial_value=True, + ) + self.glb_cam_scale_slider = self.server.gui.add_slider( + "Camera Scale", min=0.01, max=5.0, step=0.01, initial_value=1.0, + hint="Scale factor for camera size in GLB.", + ) + self.glb_frustum_thickness_slider = self.server.gui.add_slider( + "Frustum Thickness", min=1.0, max=10.0, step=0.5, initial_value=3.0, + hint="Thickness multiplier for camera frustum edges.", + ) + self.glb_trajectory_checkbox = self.server.gui.add_checkbox( + "Show Trajectory", initial_value=True, + hint="Connect cameras with a trajectory line.", + ) + self.glb_trajectory_radius_slider = self.server.gui.add_slider( + "Trajectory Radius", min=0.001, max=0.05, step=0.001, initial_value=0.005, + hint="Radius of the trajectory tube.", + ) + self.glb_mode_dropdown = self.server.gui.add_dropdown( + "Export Mode", + options=["Points", "Spheres"], + initial_value="Points", + hint="Points: raw (fast). Spheres: each point becomes a small sphere (prettier, slower).", + ) + self.glb_sphere_radius_slider = self.server.gui.add_slider( + "Sphere Radius", min=0.001, max=0.1, step=0.001, initial_value=0.005, + hint="Radius of each sphere in Spheres mode.", + disabled=True, + ) + self.glb_max_sphere_pts_slider = self.server.gui.add_slider( + "Max Sphere Points", min=10000, max=500000, step=10000, initial_value=100000, + hint="Cap point count for Spheres mode to keep file size manageable.", + disabled=True, + ) + self.glb_opacity_slider = self.server.gui.add_slider( + "Opacity", min=0.0, max=1.0, step=0.05, initial_value=1.0, + hint="Point/sphere opacity (alpha). <1.0 = semi-transparent.", + ) + self.glb_saturation_slider = self.server.gui.add_slider( + "Saturation Boost", min=0.0, max=2.0, step=0.1, initial_value=1.0, + hint="Color saturation multiplier. 
>1 = more vivid, <1 = washed out.", + ) + self.glb_brightness_slider = self.server.gui.add_slider( + "Brightness Boost", min=0.5, max=2.0, step=0.1, initial_value=1.0, + hint="Color brightness multiplier.", + ) + self.glb_export_button = self.server.gui.add_button( + "Export GLB", + hint="Export current filtered point clouds and cameras as GLB.", + ) + self.glb_status = self.server.gui.add_text("Status", initial_value="Ready") + + @self.glb_mode_dropdown.on_update + def _(_) -> None: + is_sphere = self.glb_mode_dropdown.value == "Spheres" + self.glb_sphere_radius_slider.disabled = not is_sphere + self.glb_max_sphere_pts_slider.disabled = not is_sphere + + @self.glb_export_button.on_click + def _(_) -> None: + self._export_glb() + + # Video saving controls + with self.server.gui.add_folder("Video Saving"): + self.save_video_button = self.server.gui.add_button("Save Video", disabled=False) + self.video_output_path = self.server.gui.add_text("Output Path", initial_value="output_pointcloud.mp4") + self.video_save_fps = self.server.gui.add_slider("Video FPS", min=10, max=60, step=1, initial_value=30) + self.video_resolution = self.server.gui.add_dropdown( + "Resolution", options=["1920x1080", "1280x720", "3840x2160"], initial_value="1920x1080" + ) + self.save_original_video_checkbox = self.server.gui.add_checkbox("Also Save Original Video", initial_value=True) + self.video_status = self.server.gui.add_text("Status", initial_value="Ready to save") + + @self.save_video_button.on_click + def _(_) -> None: + self.save_video( + output_path=self.video_output_path.value, + fps=self.video_save_fps.value, + resolution=self.video_resolution.value, + save_original_video=self.save_original_video_checkbox.value + ) + + @self.show_video_checkbox.on_update + def _(_) -> None: + if self.current_frame_image is not None: + self.current_frame_image.visible = self.show_video_checkbox.value + + self.pc_handles = [] + self.cam_handles = [] + + @self.psize_slider.on_update + def _(_) -> None: + for handle in self.pc_handles: + handle.point_size = self.psize_slider.value + + @self.camsize_slider.on_update + def _(_) -> None: + for handle in self.cam_handles: + handle.scale = self.camsize_slider.value + handle.line_thickness = 0.03 * handle.scale + + @self.downsample_slider.on_update + def _(_) -> None: + self._regenerate_point_clouds() + + @self.show_camera_checkbox.on_update + def _(_) -> None: + self.show_camera = self.show_camera_checkbox.value + if self.show_camera: + self._regenerate_cameras() + else: + for handle in self.cam_handles: + handle.visible = False + + @self.vis_threshold_slider.on_update + def _(_) -> None: + self.vis_threshold = self.vis_threshold_slider.value + self._regenerate_point_clouds() + + @self.camera_downsample_slider.on_update + def _(_) -> None: + self._regenerate_cameras() + + @self.range_mode_checkbox.on_update + def _(_) -> None: + self.range_start_slider.disabled = not self.range_mode_checkbox.value + self.range_end_slider.disabled = not self.range_mode_checkbox.value + if self.range_mode_checkbox.value: + self.recent_n_mode_checkbox.value = False + if hasattr(self, 'frame_nodes'): + self.update_frame_visibility() + + @self.recent_n_mode_checkbox.on_update + def _(_) -> None: + self.recent_n_slider.disabled = not self.recent_n_mode_checkbox.value + if self.recent_n_mode_checkbox.value: + self.range_mode_checkbox.value = False + if hasattr(self, 'frame_nodes'): + self.update_frame_visibility() + + @self.recent_n_slider.on_update + def _(_) -> None: + if hasattr(self, 'frame_nodes'): + 
self.update_frame_visibility() + + @self.range_start_slider.on_update + def _(_) -> None: + if self.range_start_slider.value > self.range_end_slider.value: + self.range_end_slider.value = self.range_start_slider.value + if hasattr(self, 'frame_nodes'): + self.update_frame_visibility() + + @self.range_end_slider.on_update + def _(_) -> None: + if self.range_end_slider.value < self.range_start_slider.value: + self.range_start_slider.value = self.range_end_slider.value + if hasattr(self, 'frame_nodes'): + self.update_frame_visibility() + + def _regenerate_point_clouds(self): + """Regenerate all point clouds with current settings.""" + if not hasattr(self, 'frame_nodes'): + return + + for handle in self.pc_handles: + try: + handle.remove() + except (KeyError, AttributeError): + pass + self.pc_handles.clear() + self.vis_pts_list.clear() + + for i, step in enumerate(self.all_steps): + pc = self.pcs[step]["pc"] + color = self.pcs[step]["color"] + conf = self.pcs[step]["conf"] + edge_color = self.pcs[step].get("edge_color", None) + + pred_pts, pc_color = self.parse_pc_data( + pc, color, conf, edge_color, set_border_color=True, + downsample_factor=self.downsample_slider.value + ) + + self.vis_pts_list.append(pred_pts) + handle = self.server.scene.add_point_cloud( + name=f"/frames/{step}/pred_pts", + points=pred_pts, + colors=pc_color, + point_size=self.psize_slider.value, + ) + self.pc_handles.append(handle) + + def _regenerate_cameras(self): + """Regenerate camera visualizations with current settings.""" + if not hasattr(self, 'frame_nodes'): + return + + for handle in self.cam_handles: + try: + handle.remove() + except (KeyError, AttributeError): + pass + self.cam_handles.clear() + + if self.show_camera: + downsample_factor = int(self.camera_downsample_slider.value) + for i, step in enumerate(self.all_steps): + if i % downsample_factor == 0: + self.add_camera(step) + + def _run_tsdf_fusion(self): + """Run TSDF fusion on all frames and display result as a point cloud.""" + if not hasattr(self, 'tsdf_depth_maps') or self.tsdf_depth_maps is None: + self.tsdf_status.value = "Error: no depth data (need pred_dict)" + return + + try: + import open3d as o3d + except ImportError: + self.tsdf_status.value = "Error: pip install open3d" + return + + self.tsdf_status.value = "Running TSDF fusion..." 
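        # The loop below wraps each frame's depth and color as an Open3D RGBDImage and
        # integrates it into the ScalableTSDFVolume using the stored camera-from-world
        # extrinsics and per-frame pinhole intrinsics; as the slider hint notes,
        # sdf_trunc is typically chosen as 3-5x the voxel size.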
+ print("Starting TSDF fusion...") + + voxel_size = self.tsdf_voxel_size_slider.value + sdf_trunc = self.tsdf_sdf_trunc_slider.value + depth_scale = self.tsdf_depth_scale_slider.value + depth_trunc = self.tsdf_depth_trunc_slider.value + + volume = o3d.pipelines.integration.ScalableTSDFVolume( + voxel_length=voxel_size, + sdf_trunc=sdf_trunc, + color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8, + ) + + S = self.tsdf_depth_maps.shape[0] + H, W = self.tsdf_depth_maps.shape[1], self.tsdf_depth_maps.shape[2] + + for i in tqdm(range(S), desc="TSDF integrating"): + # Depth: (H, W, 1) -> (H, W) + depth = self.tsdf_depth_maps[i] + if depth.ndim == 3: + depth = depth[..., 0] + + # Color: (3, H, W) -> (H, W, 3), uint8 + color = self.tsdf_images[i].transpose(1, 2, 0) # (H, W, 3) + color = (np.clip(color, 0, 1) * 255).astype(np.uint8) + + # Camera extrinsic: (3, 4) -> (4, 4) camera-from-world + extr_34 = self.tsdf_extrinsics[i] + extr_44 = np.eye(4, dtype=np.float64) + extr_44[:3, :] = extr_34 + + intrinsic = o3d.camera.PinholeCameraIntrinsic( + width=W, height=H, + fx=float(self.tsdf_intrinsics[i, 0, 0]), + fy=float(self.tsdf_intrinsics[i, 1, 1]), + cx=float(self.tsdf_intrinsics[i, 0, 2]), + cy=float(self.tsdf_intrinsics[i, 1, 2]), + ) + + depth_o3d = o3d.geometry.Image( + (depth.astype(np.float32) * depth_scale).astype(np.float32) + ) + color_o3d = o3d.geometry.Image(np.ascontiguousarray(color)) + + rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( + color_o3d, depth_o3d, + depth_scale=depth_scale, + depth_trunc=depth_trunc, + convert_rgb_to_intensity=False, + ) + + volume.integrate(rgbd, intrinsic, extr_44) + + print("Extracting point cloud from TSDF volume...") + pcd = volume.extract_point_cloud() + + points = np.asarray(pcd.points, dtype=np.float32) + colors = np.asarray(pcd.colors, dtype=np.float32) # already 0-1 + + if len(points) == 0: + self.tsdf_status.value = "Error: empty result, try adjusting parameters" + print("TSDF fusion produced 0 points.") + return + + # Remove previous TSDF result + if self._tsdf_handle is not None: + try: + self._tsdf_handle.remove() + except (KeyError, AttributeError): + pass + + self._tsdf_handle = self.server.scene.add_point_cloud( + name="/tsdf_fusion", + points=points, + colors=colors, + point_size=self.psize_slider.value, + ) + + self.tsdf_status.value = f"Done: {len(points):,} points" + print(f"TSDF fusion complete: {len(points):,} points") + + def _export_glb(self): + """Export current filtered point clouds and cameras as a GLB file.""" + try: + import trimesh + except ImportError: + self.glb_status.value = "Error: pip install trimesh" + return + + self.glb_status.value = "Collecting points..." 
+ print("Exporting GLB...") + + # Collect all currently visible, filtered points and colors + all_points = [] + all_colors = [] + for step in self.all_steps: + pc = self.pcs[step]["pc"] + color = self.pcs[step]["color"] + conf = self.pcs[step]["conf"] + edge_color = self.pcs[step].get("edge_color", None) + + pts, cols = self.parse_pc_data( + pc, color, conf, edge_color, set_border_color=False, + downsample_factor=self.downsample_slider.value, + ) + if len(pts) > 0: + all_points.append(pts) + if cols.dtype != np.uint8: + cols = (np.clip(cols, 0, 1) * 255).astype(np.uint8) + all_colors.append(cols) + + if not all_points: + self.glb_status.value = "Error: no points to export" + return + + vertices = np.concatenate(all_points, axis=0) + colors_rgb = np.concatenate(all_colors, axis=0) + + # --- Color enhancement --- + colors_float = colors_rgb.astype(np.float32) / 255.0 + + sat_boost = self.glb_saturation_slider.value + if sat_boost != 1.0: + gray = colors_float.mean(axis=1, keepdims=True) + colors_float = gray + sat_boost * (colors_float - gray) + + bri_boost = self.glb_brightness_slider.value + if bri_boost != 1.0: + colors_float = colors_float * bri_boost + + colors_float = np.clip(colors_float, 0.0, 1.0) + + # --- Opacity --- + # Simulate opacity by blending colors toward white (works in all viewers). + # For Spheres mode, also set true alpha for viewers that support it. + alpha = self.glb_opacity_slider.value + if alpha < 1.0: + bg = np.ones_like(colors_float) # white background + colors_float = colors_float * alpha + bg * (1.0 - alpha) + colors_float = np.clip(colors_float, 0.0, 1.0) + + colors_u8 = (colors_float * 255).astype(np.uint8) + colors_rgba = np.concatenate([ + colors_u8, + np.full((len(colors_u8), 1), int(alpha * 255), dtype=np.uint8), + ], axis=1) # (N, 4) + + # Compute scene scale for camera sizing + lo = np.percentile(vertices, 5, axis=0) + hi = np.percentile(vertices, 95, axis=0) + scene_scale = max(np.linalg.norm(hi - lo), 0.1) + + scene_3d = trimesh.Scene() + + # --- Export mode --- + export_mode = self.glb_mode_dropdown.value + if export_mode == "Spheres": + self.glb_status.value = "Building spheres..." 
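            # Each point becomes an icosphere instance; with subdivisions=1 that is
            # roughly 42 vertices and 80 faces per point, so the Max Sphere Points cap
            # below keeps the exported mesh and file size manageable.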
+ max_pts = int(self.glb_max_sphere_pts_slider.value) + radius = self.glb_sphere_radius_slider.value + + # Subsample if too many points + if len(vertices) > max_pts: + idx = np.random.choice(len(vertices), max_pts, replace=False) + idx.sort() + vertices = vertices[idx] + colors_rgba = colors_rgba[idx] + + sphere_template = trimesh.creation.icosphere(subdivisions=1, radius=radius) + n_verts_per = len(sphere_template.vertices) + n_faces_per = len(sphere_template.faces) + + all_verts = np.empty((len(vertices) * n_verts_per, 3), dtype=np.float32) + all_faces = np.empty((len(vertices) * n_faces_per, 3), dtype=np.int64) + all_face_colors = np.empty((len(vertices) * n_faces_per, 4), dtype=np.uint8) + + for i, (pt, rgba) in enumerate(zip(vertices, colors_rgba)): + v_off = i * n_verts_per + f_off = i * n_faces_per + all_verts[v_off:v_off + n_verts_per] = sphere_template.vertices + pt + all_faces[f_off:f_off + n_faces_per] = sphere_template.faces + v_off + all_face_colors[f_off:f_off + n_faces_per] = rgba + + mesh = trimesh.Trimesh(vertices=all_verts, faces=all_faces) + mesh.visual.face_colors = all_face_colors + # Enable alpha blending in glTF material for true transparency + if alpha < 1.0: + mesh.visual.material.alphaMode = 'BLEND' + scene_3d.add_geometry(mesh) + print(f"Spheres mode: {len(vertices):,} spheres, {len(all_faces):,} faces") + else: + # Points mode (GLB viewers ignore alpha on points, so use blended RGB) + scene_3d.add_geometry(trimesh.PointCloud(vertices=vertices, colors=colors_u8)) + + # Add cameras and trajectory + if self.glb_show_cam_checkbox.value and self.cam_dict is not None: + from lingbot_map.vis.glb_export import integrate_camera_into_scene + import matplotlib + colormap = matplotlib.colormaps.get_cmap("gist_rainbow") + num_cameras = len(self.all_steps) + cam_positions = [] + + frustum_thickness = self.glb_frustum_thickness_slider.value + effective_cam_scale = scene_scale * self.glb_cam_scale_slider.value + + for i, step in enumerate(self.all_steps): + R = self.cam_dict["R"][step] if "R" in self.cam_dict else np.eye(3) + t = self.cam_dict["t"][step] if "t" in self.cam_dict else np.zeros(3) + + c2w = np.eye(4) + c2w[:3, :3] = R + c2w[:3, 3] = t + cam_positions.append(np.array(t, dtype=np.float64)) + + rgba_c = colormap(i / max(num_cameras - 1, 1)) + cam_color = tuple(int(255 * x) for x in rgba_c[:3]) + integrate_camera_into_scene( + scene_3d, c2w, cam_color, + effective_cam_scale, + frustum_thickness=frustum_thickness, + ) + + # Add trajectory line as a tube connecting camera positions + if self.glb_trajectory_checkbox.value and len(cam_positions) >= 2: + traj_pts = np.array(cam_positions) + traj_radius = self.glb_trajectory_radius_slider.value * self.glb_cam_scale_slider.value + traj_mesh = self._build_trajectory_tube( + traj_pts, traj_radius, colormap, num_cameras + ) + if traj_mesh is not None: + scene_3d.add_geometry(traj_mesh) + + # Align scene using first camera extrinsic + if self.cam_dict is not None and len(self.all_steps) > 0: + from lingbot_map.vis.glb_export import apply_scene_alignment + step0 = self.all_steps[0] + R0 = self.cam_dict["R"][step0] if "R" in self.cam_dict else np.eye(3) + t0 = self.cam_dict["t"][step0] if "t" in self.cam_dict else np.zeros(3) + c2w_0 = np.eye(4) + c2w_0[:3, :3] = R0 + c2w_0[:3, 3] = t0 + w2c_0 = np.linalg.inv(c2w_0) + extrinsics = np.expand_dims(w2c_0, 0) + scene_3d = apply_scene_alignment(scene_3d, extrinsics) + + output_path = self.glb_output_path.value + scene_3d.export(output_path) + + n_pts = len(vertices) + mode_str = 
f"spheres r={self.glb_sphere_radius_slider.value}" if export_mode == "Spheres" else "points" + self.glb_status.value = f"Saved: {output_path} ({n_pts:,} {mode_str})" + print(f"GLB exported to {output_path} ({n_pts:,} {mode_str})") + + @staticmethod + def _build_trajectory_tube(positions, radius, colormap, num_cameras): + """Build a tube mesh following camera trajectory with per-segment color. + + Args: + positions: (N, 3) camera positions. + radius: Tube radius. + colormap: Matplotlib colormap for gradient coloring. + num_cameras: Total number of cameras (for color normalization). + + Returns: + trimesh.Trimesh or None. + """ + import trimesh + + segments = [] + for i in range(len(positions) - 1): + p0, p1 = positions[i], positions[i + 1] + seg_len = np.linalg.norm(p1 - p0) + if seg_len < 1e-8: + continue + + # Create cylinder along Z, then transform + cyl = trimesh.creation.cylinder(radius=radius, height=seg_len, sections=8) + + # Direction vector + direction = (p1 - p0) / seg_len + mid = (p0 + p1) / 2.0 + + # Build rotation: default cylinder is along Z + z_axis = np.array([0.0, 0.0, 1.0]) + v = np.cross(z_axis, direction) + c = np.dot(z_axis, direction) + + if np.linalg.norm(v) < 1e-8: + rot = np.eye(3) if c > 0 else np.diag([1, -1, -1]) + else: + vx = np.array([[0, -v[2], v[1]], + [v[2], 0, -v[0]], + [-v[1], v[0], 0]]) + rot = np.eye(3) + vx + vx @ vx / (1.0 + c) + + transform = np.eye(4) + transform[:3, :3] = rot + transform[:3, 3] = mid + cyl.apply_transform(transform) + + # Color: midpoint index + t_color = (i + 0.5) / max(num_cameras - 1, 1) + rgba = colormap(t_color) + color_rgb = tuple(int(255 * x) for x in rgba[:3]) + cyl.visual.face_colors[:, :3] = color_rgb + segments.append(cyl) + + if not segments: + return None + return trimesh.util.concatenate(segments) + + def update_frame_visibility(self): + """Update frame visibility based on range mode settings.""" + if not hasattr(self, 'frame_nodes') or not hasattr(self, 'gui_timestep'): + return + + current_timestep = self.gui_timestep.value + + if self.recent_n_mode_checkbox.value: + n = int(self.recent_n_slider.value) + start_idx = max(0, current_timestep - n + 1) + end_idx = current_timestep + for i, frame_node in enumerate(self.frame_nodes): + frame_node.visible = start_idx <= i <= end_idx + elif self.range_mode_checkbox.value: + start_idx = self.range_start_slider.value + end_idx = self.range_end_slider.value + for i, frame_node in enumerate(self.frame_nodes): + frame_node.visible = start_idx <= i <= end_idx + else: + for i, frame_node in enumerate(self.frame_nodes): + frame_node.visible = ( + i <= current_timestep if not self.fourd else i == current_timestep + ) + + def _move_to_camera(self, frame_idx: int, smooth: bool = True): + """Move viewer camera to match reconstructed camera at given frame.""" + if self.cam_dict is None: + return + + step = self.all_steps[frame_idx] if frame_idx < len(self.all_steps) else self.all_steps[-1] + + R = self.cam_dict["R"][step] if "R" in self.cam_dict else np.eye(3) + t = self.cam_dict["t"][step] if "t" in self.cam_dict else np.zeros(3) + focal = self.cam_dict["focal"][step] if "focal" in self.cam_dict else 1.0 + pp = self.cam_dict["pp"][step] if "pp" in self.cam_dict else (1.0, 1.0) + + offset = self.navigation_offset_slider.value if hasattr(self, 'navigation_offset_slider') else 0.5 + viewing_dir = R[:, 2] # camera Z axis in world frame + position = t - viewing_dir * offset + look_at = t + viewing_dir * 0.5 # look slightly ahead of camera + + fov = 2 * np.arctan(pp[0] / focal) if 
self.navigation_fov_checkbox.value else None + up = -R[:, 1] # camera -Y axis in world frame + + for client in self.server.get_clients().values(): + if smooth: + self._smooth_camera_transition( + client, + target_position=position, + target_look_at=look_at, + target_up=up, + target_fov=fov, + duration=0.3, + ) + else: + client.camera.up_direction = tuple(up) + client.camera.position = tuple(position) + client.camera.look_at = tuple(look_at) + if fov is not None: + client.camera.fov = fov + + def _smooth_camera_transition( + self, + client, + target_position, + target_look_at=None, + target_up=None, + target_fov=None, + duration=0.3, + ): + """Smoothly transition camera to target pose using look_at based control. + + Args: + client: Viser client handle. + target_position: Target camera position (3,). + target_look_at: Target look-at point (3,). If None, keeps current. + target_up: Target up direction (3,). If None, keeps current. + target_fov: Target FOV. If None, keeps current. + duration: Transition duration in seconds. + """ + def interpolate(): + num_steps = 15 + dt = duration / num_steps + + start_position = np.array(client.camera.position, dtype=np.float64) + start_look_at = np.array(client.camera.look_at, dtype=np.float64) + start_fov = client.camera.fov + + end_position = np.asarray(target_position, dtype=np.float64) + end_look_at = np.asarray(target_look_at, dtype=np.float64) if target_look_at is not None else start_look_at + + # Set up direction once at the start (not interpolated to avoid flicker) + if target_up is not None: + client.camera.up_direction = tuple(np.asarray(target_up, dtype=np.float64)) + + for i in range(num_steps + 1): + alpha = i / num_steps + # Smooth ease-in-out + alpha_smooth = alpha * alpha * (3 - 2 * alpha) + + interp_pos = start_position + (end_position - start_position) * alpha_smooth + interp_look = start_look_at + (end_look_at - start_look_at) * alpha_smooth + + # Set position first (this auto-moves look_at), then override look_at + client.camera.position = tuple(interp_pos) + client.camera.look_at = tuple(interp_look) + + if target_fov is not None: + interp_fov = start_fov + (target_fov - start_fov) * alpha_smooth + client.camera.fov = interp_fov + + time.sleep(dt) + + thread = threading.Thread(target=interpolate, daemon=True) + thread.start() + + def _slerp(self, q1, q2, t): + """Spherical linear interpolation between quaternions.""" + dot = np.dot(q1, q2) + + if abs(dot) > 0.9995: + result = q1 + t * (q2 - q1) + return result / np.linalg.norm(result) + + dot = np.clip(dot, -1.0, 1.0) + theta_0 = np.arccos(dot) + theta = theta_0 * t + + q2_orthogonal = q2 - q1 * dot + q2_orthogonal = q2_orthogonal / np.linalg.norm(q2_orthogonal) + + return q1 * np.cos(theta) + q2_orthogonal * np.sin(theta) + + def get_camera_state(self, client: viser.ClientHandle) -> CameraState: + """Get current camera state from client.""" + camera = client.camera + c2w = np.concatenate([ + np.concatenate([tf.SO3(camera.wxyz).as_matrix(), camera.position[:, None]], 1), + [[0, 0, 0, 1]], + ], 0) + return CameraState(fov=camera.fov, aspect=camera.aspect, c2w=c2w) + + @staticmethod + def generate_pseudo_intrinsics(h: int, w: int) -> np.ndarray: + """Generate pseudo intrinsics from image size.""" + focal = (h**2 + w**2) ** 0.5 + return np.array([[focal, 0, w // 2], [0, focal, h // 2], [0, 0, 1]]).astype(np.float32) + + def _connect_client(self, client: viser.ClientHandle): + """Setup client connection callbacks.""" + wxyz_panel = client.gui.add_text("wxyz:", f"{client.camera.wxyz}") + 
position_panel = client.gui.add_text("position:", f"{client.camera.position}") + fov_panel = client.gui.add_text( + "fov:", f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}" + ) + aspect_panel = client.gui.add_text("aspect:", "1.0") + + @client.camera.on_update + def _(_: viser.CameraHandle): + with self.server.atomic(): + wxyz_panel.value = f"{client.camera.wxyz}" + position_panel.value = f"{client.camera.position}" + fov_panel.value = f"{2 * np.arctan(self.size/self.focal_slider.value) * 180 / np.pi}" + aspect_panel.value = "1.0" + + @staticmethod + def set_color_border(image, border_width=5, color=[1, 0, 0]): + """Add colored border to image.""" + image[:border_width, :, 0] = color[0] + image[:border_width, :, 1] = color[1] + image[:border_width, :, 2] = color[2] + image[-border_width:, :, 0] = color[0] + image[-border_width:, :, 1] = color[1] + image[-border_width:, :, 2] = color[2] + image[:, :border_width, 0] = color[0] + image[:, :border_width, 1] = color[1] + image[:, :border_width, 2] = color[2] + image[:, -border_width:, 0] = color[0] + image[:, -border_width:, 1] = color[1] + image[:, -border_width:, 2] = color[2] + return image + + def read_data(self, pc_list, color_list, conf_list, edge_color_list=None): + """Read and organize point cloud data.""" + pcs = {} + step_list = [] + for i, pc in enumerate(pc_list): + step = i + pcs.update({ + step: { + "pc": pc, + "color": color_list[i], + "conf": conf_list[i], + "edge_color": ( + None if edge_color_list is None or edge_color_list[i] is None + else edge_color_list[i] + ), + } + }) + step_list.append(step) + + # Generate camera gradient colors + num_cameras = len(pc_list) + if num_cameras > 1: + normalized_indices = np.array(list(range(num_cameras))) / (num_cameras - 1) + else: + normalized_indices = np.array([0.0]) + cmap = cm.get_cmap('viridis') + self.camera_colors = cmap(normalized_indices) + return pcs, step_list + + def parse_pc_data( + self, + pc, + color, + conf=None, + edge_color=[0.251, 0.702, 0.902], + set_border_color=False, + downsample_factor=1, + ): + """Parse and filter point cloud data.""" + pred_pts = pc.reshape(-1, 3) + + if set_border_color and edge_color is not None: + color = self.set_color_border(color[0], color=edge_color) + if np.isnan(color).any(): + color = np.zeros((pred_pts.shape[0], 3)) + color[:, 2] = 1 + else: + color = color.reshape(-1, 3) + + # Remove NaN / Inf points + valid = np.isfinite(pred_pts).all(axis=1) + if not valid.all(): + pred_pts = pred_pts[valid] + color = color[valid] + if conf is not None: + conf = conf.reshape(-1)[valid] + + # Confidence threshold filter + if conf is not None: + conf_flat = conf.reshape(-1) if conf.ndim > 1 else conf + mask = conf_flat > self.vis_threshold + pred_pts = pred_pts[mask] + color = color[mask] + + if len(pred_pts) == 0: + return pred_pts, color + + # Confidence percentile filter + if conf is not None and hasattr(self, 'conf_percentile_slider'): + pct = self.conf_percentile_slider.value + if pct > 0: + conf_remaining = conf_flat[mask] if conf is not None else None + if conf_remaining is not None and len(conf_remaining) > 0: + threshold = np.percentile(conf_remaining, pct) + pct_mask = conf_remaining >= threshold + pred_pts = pred_pts[pct_mask] + color = color[pct_mask] + + if len(pred_pts) == 0: + return pred_pts, color + + # Bounding box clip: remove points far from the scene center + if hasattr(self, 'bbox_clip_slider'): + clip_pct = self.bbox_clip_slider.value + if clip_pct < 100.0: + lo = np.percentile(pred_pts, (100.0 - clip_pct) / 
2, axis=0) + hi = np.percentile(pred_pts, 100.0 - (100.0 - clip_pct) / 2, axis=0) + bbox_mask = np.all((pred_pts >= lo) & (pred_pts <= hi), axis=1) + pred_pts = pred_pts[bbox_mask] + color = color[bbox_mask] + + if len(pred_pts) == 0: + return pred_pts, color + + # Statistical Outlier Removal (SOR) + if hasattr(self, 'sor_checkbox') and self.sor_checkbox.value and len(pred_pts) > 0: + pred_pts, color = self._statistical_outlier_removal( + pred_pts, color, + nb_neighbors=int(self.sor_neighbors_slider.value), + std_ratio=self.sor_std_slider.value, + ) + + # Downsample + if downsample_factor > 1 and len(pred_pts) > 0: + indices = np.arange(0, len(pred_pts), downsample_factor) + pred_pts = pred_pts[indices] + color = color[indices] + + return pred_pts, color + + @staticmethod + def _statistical_outlier_removal( + points: np.ndarray, + colors: np.ndarray, + nb_neighbors: int = 20, + std_ratio: float = 2.0, + ) -> Tuple[np.ndarray, np.ndarray]: + """Remove statistical outliers based on mean distance to k-nearest neighbors. + + Args: + points: (N, 3) point positions. + colors: (N, 3) point colors. + nb_neighbors: Number of nearest neighbors to consider. + std_ratio: Standard deviation multiplier for the distance threshold. + + Returns: + Filtered (points, colors) tuple. + """ + if len(points) <= nb_neighbors: + return points, colors + + try: + from scipy.spatial import cKDTree + except ImportError: + # Fallback: skip SOR if scipy not available + return points, colors + + # Subsample for KD-tree if too many points (speed) + max_pts_for_tree = 200_000 + if len(points) > max_pts_for_tree: + subsample_idx = np.random.choice(len(points), max_pts_for_tree, replace=False) + tree = cKDTree(points[subsample_idx]) + else: + tree = cKDTree(points) + + dists, _ = tree.query(points, k=nb_neighbors + 1) # +1 because first is self + mean_dists = dists[:, 1:].mean(axis=1) # exclude self + + threshold = mean_dists.mean() + std_ratio * mean_dists.std() + inlier_mask = mean_dists < threshold + + return points[inlier_mask], colors[inlier_mask] + + def add_pc(self, step): + """Add point cloud for a frame.""" + pc = self.pcs[step]["pc"] + color = self.pcs[step]["color"] + conf = self.pcs[step]["conf"] + edge_color = self.pcs[step].get("edge_color", None) + + pred_pts, color = self.parse_pc_data( + pc, color, conf, edge_color, set_border_color=True, + downsample_factor=self.downsample_slider.value + ) + + self.vis_pts_list.append(pred_pts) + self.pc_handles.append( + self.server.scene.add_point_cloud( + name=f"/frames/{step}/pred_pts", + points=pred_pts, + colors=color, + point_size=0.005, + ) + ) + + def add_camera(self, step): + """Add camera visualization for a frame.""" + cam = self.cam_dict + focal = cam["focal"][step] if cam and "focal" in cam else 1.0 + pp = cam["pp"][step] if cam and "pp" in cam else (1.0, 1.0) + R = cam["R"][step] if cam and "R" in cam else np.eye(3) + t = cam["t"][step] if cam and "t" in cam else np.zeros(3) + + q = tf.SO3.from_matrix(R).wxyz + fov = 2 * np.arctan(pp[0] / focal) + aspect = pp[0] / pp[1] + self.traj_list.append((q, t)) + + step_index = self.all_steps.index(step) if step in self.all_steps else 0 + camera_color = self.camera_colors[step_index] + camera_color_rgb = tuple((camera_color[:3] * 255).astype(int)) + + self.server.scene.add_frame( + f"/frames/{step}/camera_frame", + wxyz=q, + position=t, + axes_length=0.05, + axes_radius=0.002, + origin_radius=0.002, + ) + + frustum_handle = self.server.scene.add_camera_frustum( + name=f"/frames/{step}/camera", + fov=fov, + 
aspect=aspect, + wxyz=q, + position=t, + scale=0.1, + color=camera_color_rgb, + ) + + @frustum_handle.on_click + def _(event) -> None: + look_at_pt = t + R[:, 2] * 0.5 # look ahead along camera Z + up_dir = -R[:, 1] + for client in self.server.get_clients().values(): + client.camera.up_direction = tuple(up_dir) + client.camera.position = tuple(t) + client.camera.look_at = tuple(look_at_pt) + + self.cam_handles.append(frustum_handle) + + def animate(self): + """Setup and run animation controls.""" + with self.server.gui.add_folder("Playback"): + self.gui_timestep = self.server.gui.add_slider( + "Train Step", min=0, max=self.num_frames - 1, step=1, initial_value=0, disabled=False + ) + gui_next_frame = self.server.gui.add_button("Next Step", disabled=False) + gui_prev_frame = self.server.gui.add_button("Prev Step", disabled=False) + gui_playing = self.server.gui.add_checkbox("Playing", False) + gui_framerate = self.server.gui.add_slider("FPS", min=1, max=60, step=0.1, initial_value=1) + gui_framerate_options = self.server.gui.add_button_group("FPS options", ("10", "20", "30", "60")) + + @gui_next_frame.on_click + def _(_) -> None: + self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames + + @gui_prev_frame.on_click + def _(_) -> None: + self.gui_timestep.value = (self.gui_timestep.value - 1) % self.num_frames + + @gui_playing.on_update + def _(_) -> None: + self.gui_timestep.disabled = gui_playing.value + gui_next_frame.disabled = gui_playing.value + gui_prev_frame.disabled = gui_playing.value + + @gui_framerate_options.on_click + def _(_) -> None: + gui_framerate.value = int(gui_framerate_options.value) + + prev_timestep = self.gui_timestep.value + + @self.gui_timestep.on_update + def _(_) -> None: + nonlocal prev_timestep + current_timestep = self.gui_timestep.value + + if self.current_frame_image is not None and hasattr(self, 'original_images'): + if current_timestep < len(self.original_images): + self.current_frame_image.image = self.original_images[current_timestep] + + if hasattr(self, 'navigation_mode_checkbox') and self.navigation_mode_checkbox.value: + self._move_to_camera(current_timestep, smooth=self.smooth_navigation_checkbox.value) + + if self.recent_n_mode_checkbox.value: + self.update_frame_visibility() + elif not self.range_mode_checkbox.value: + with self.server.atomic(): + self.frame_nodes[current_timestep].visible = True + self.frame_nodes[prev_timestep].visible = False + self.server.flush() + + prev_timestep = current_timestep + + self.server.scene.add_frame("/frames", show_axes=False) + self.frame_nodes = [] + for i in range(self.num_frames): + step = self.all_steps[i] + self.frame_nodes.append( + self.server.scene.add_frame(f"/frames/{step}", show_axes=False) + ) + self.add_pc(step) + if self.show_camera: + downsample_factor = int(self.camera_downsample_slider.value) + if i % downsample_factor == 0: + self.add_camera(step) + + prev_timestep = self.gui_timestep.value + while True: + if self.on_replay: + pass + else: + if gui_playing.value: + self.gui_timestep.value = (self.gui_timestep.value + 1) % self.num_frames + self.update_frame_visibility() + + time.sleep(1.0 / gui_framerate.value) + + def _take_screenshot(self, client: Optional[Any] = None): + """Capture a screenshot from the current view and save to file. + + Args: + client: The viser client that triggered the action. If None, + uses the first connected client. 
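+
+        Note:
+            When the resolution dropdown is set to "Current", the capture
+            falls back to a default render size of 1920x1080.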
+ """ + output_path = self.screenshot_path.value + res_str = self.screenshot_resolution.value + + # Resolve client + if client is None: + clients = list(self.server.get_clients().values()) + if not clients: + self.screenshot_status.value = "Error: no client connected" + return + client = clients[0] + + try: + self.screenshot_status.value = "Capturing..." + + if res_str == "Current": + # Use default render size + width, height = 1920, 1080 + else: + width, height = map(int, res_str.split("x")) + + render = client.camera.get_render(height=height, width=width) + + if render is not None: + frame = np.array(render) + if frame.shape[2] == 4: + frame = frame[:, :, :3] + frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + cv2.imwrite(output_path, frame_bgr) + self.screenshot_status.value = f"Saved: {output_path}" + print(f"Screenshot saved to {output_path} ({width}x{height})") + else: + self.screenshot_status.value = "Error: render returned None" + print("Screenshot failed: render returned None") + + except Exception as e: + self.screenshot_status.value = f"Error: {e}" + print(f"Screenshot error: {e}") + + def save_video( + self, + output_path: str = "output_pointcloud.mp4", + fps: int = 30, + resolution: str = "1920x1080", + save_original_video: bool = True + ): + """Save point cloud animation as video.""" + try: + if hasattr(self, 'video_status'): + self.video_status.value = "Saving video..." + print(f"Saving video to {output_path}...") + + width, height = map(int, resolution.split('x')) + temp_dir = tempfile.mkdtemp(prefix="viser_video_") + print(f"Temporary directory: {temp_dir}") + + print("Waiting for client connection...") + timeout = 10 + start_time = time.time() + while len(self.server.get_clients()) == 0: + time.sleep(0.1) + if time.time() - start_time > timeout: + raise RuntimeError("No client connected. Please open the visualization in a browser first.") + + print("Client connected. Starting to render frames...") + clients = list(self.server.get_clients().values()) + client = clients[0] + + if not hasattr(self, 'gui_timestep'): + raise RuntimeError("Animation not initialized. 
Please ensure animate() is called before save_video().") + + for i in tqdm(range(self.num_frames), desc="Rendering frames"): + self.gui_timestep.value = i + time.sleep(0.1) + + try: + screenshot = client.camera.get_render(height=height, width=width) + if screenshot is not None: + frame = np.array(screenshot) + if frame.shape[2] == 4: + frame = frame[:, :, :3] + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png") + cv2.imwrite(frame_path, frame) + else: + frame = self._render_frame_fallback(i, width, height) + frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png") + cv2.imwrite(frame_path, frame) + except Exception as e: + print(f"Warning: Error capturing frame {i}: {e}, using fallback") + frame = self._render_frame_fallback(i, width, height) + frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png") + cv2.imwrite(frame_path, frame) + + print("Encoding video with ffmpeg...") + ffmpeg_cmd = [ + 'ffmpeg', '-y', '-framerate', str(fps), + '-i', os.path.join(temp_dir, 'frame_%06d.png'), + '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18', + output_path + ] + + result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True) + + if result.returncode == 0: + print(f"Point cloud video saved successfully to {output_path}") + if hasattr(self, 'video_status'): + self.video_status.value = f"Saved to {output_path}" + else: + print(f"FFmpeg error: {result.stderr}") + if hasattr(self, 'video_status'): + self.video_status.value = "Error: FFmpeg failed" + + if save_original_video and hasattr(self, 'original_images') and len(self.original_images) > 0: + self._save_original_video(output_path, fps, width, height) + + shutil.rmtree(temp_dir) + print("Temporary files cleaned up") + + except Exception as e: + print(f"Error saving video: {e}") + import traceback + traceback.print_exc() + if hasattr(self, 'video_status'): + self.video_status.value = f"Error: {str(e)}" + + def _save_original_video(self, pointcloud_video_path: str, fps: int, width: int, height: int): + """Save original images as video.""" + base_path = os.path.splitext(pointcloud_video_path)[0] + original_video_path = f"{base_path}_original.mp4" + + print(f"Saving original images video to {original_video_path}...") + + try: + temp_dir = tempfile.mkdtemp(prefix="original_video_") + + for i, img in enumerate(tqdm(self.original_images, desc="Saving original frames")): + frame = cv2.resize(img, (width, height)) + if len(frame.shape) == 3 and frame.shape[2] == 3: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png") + cv2.imwrite(frame_path, frame) + + print("Encoding original video with ffmpeg...") + ffmpeg_cmd = [ + 'ffmpeg', '-y', '-framerate', str(fps), + '-i', os.path.join(temp_dir, 'frame_%06d.png'), + '-c:v', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18', + original_video_path + ] + + result = subprocess.run(ffmpeg_cmd, capture_output=True, text=True) + + if result.returncode == 0: + print(f"Original video saved successfully to {original_video_path}") + else: + print(f"FFmpeg error for original video: {result.stderr}") + + shutil.rmtree(temp_dir) + + except Exception as e: + print(f"Error saving original video: {e}") + import traceback + traceback.print_exc() + + def _render_frame_fallback(self, frame_idx: int, width: int, height: int) -> np.ndarray: + """Fallback rendering when screenshot capture fails.""" + if hasattr(self, 'original_images') and frame_idx < len(self.original_images): + frame = 
self.original_images[frame_idx].copy() + frame = cv2.resize(frame, (width, height)) + cv2.putText(frame, f"Frame {frame_idx}", (10, 30), + cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) + return frame + else: + frame = np.zeros((height, width, 3), dtype=np.uint8) + cv2.putText(frame, f"Frame {frame_idx} - No render available", + (width//4, height//2), + cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2) + return frame + + def run(self, background_mode: bool = False): + """Run the viewer.""" + self.animate() + if background_mode: + def server_loop(): + while True: + time.sleep(0.001) + + thread = threading.Thread(target=server_loop, daemon=True) + thread.start() + else: + while True: + time.sleep(10.0) diff --git a/lingbot_map/vis/sky_segmentation.py b/lingbot_map/vis/sky_segmentation.py new file mode 100644 index 0000000..4cdaa80 --- /dev/null +++ b/lingbot_map/vis/sky_segmentation.py @@ -0,0 +1,473 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Sky segmentation utilities for filtering sky points from point clouds. +""" + +import glob +import os +from typing import Optional, Tuple + +import numpy as np +import cv2 +from tqdm.auto import tqdm + +try: + import onnxruntime +except ImportError: + onnxruntime = None + print("onnxruntime not found. Sky segmentation may not work.") + + +_SKYSEG_INPUT_SIZE = (320, 320) +_SKYSEG_SOFT_THRESHOLD = 0.1 +_SKYSEG_CACHE_VERSION = "imagenet_norm_softmap_inverted_v3" + + +def _get_cache_version_path(sky_mask_dir: str) -> str: + return os.path.join(sky_mask_dir, ".skyseg_cache_version") + + +def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool: + if sky_mask_dir is None: + return False + + os.makedirs(sky_mask_dir, exist_ok=True) + version_path = _get_cache_version_path(sky_mask_dir) + refresh_cache = True + if os.path.exists(version_path): + with open(version_path, "r", encoding="utf-8") as f: + refresh_cache = f.read().strip() != _SKYSEG_CACHE_VERSION + + if refresh_cache: + print( + f"Sky mask cache at {sky_mask_dir} uses an older format; " + "regenerating masks with ImageNet-normalized skyseg input" + ) + with open(version_path, "w", encoding="utf-8") as f: + f.write(_SKYSEG_CACHE_VERSION) + + return refresh_cache + + +def run_skyseg( + onnx_session, + input_size: Tuple[int, int], + image: np.ndarray, +) -> np.ndarray: + """ + Run ONNX sky segmentation on a BGR image and return an 8-bit score map. 
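+
+    Args:
+        onnx_session: ONNX runtime inference session for the skyseg model.
+        input_size: (width, height) that the image is resized to before inference.
+        image: Input image in BGR channel order, e.g. as returned by cv2.imread.
+
+    Returns:
+        uint8 score map, min-max normalized per image to [0, 255], with higher
+        values on sky regions.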
+    """
+    resize_image = cv2.resize(image, dsize=(input_size[0], input_size[1]))
+    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB).astype(np.float32)
+    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+    x = (x / 255.0 - mean) / std
+    x = x.transpose(2, 0, 1)
+    x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")
+
+    input_name = onnx_session.get_inputs()[0].name
+    output_name = onnx_session.get_outputs()[0].name
+    onnx_result = onnx_session.run([output_name], {input_name: x})
+
+    onnx_result = np.array(onnx_result).squeeze()
+    min_value = np.min(onnx_result)
+    max_value = np.max(onnx_result)
+    denom = max(max_value - min_value, 1e-8)
+    onnx_result = (onnx_result - min_value) / denom
+    onnx_result *= 255.0
+    return onnx_result.astype(np.uint8)
+
+
+def _mask_to_float(mask: np.ndarray) -> np.ndarray:
+    mask = np.asarray(mask).astype(np.float32)
+    if mask.size == 0:
+        return mask
+    # Raw score maps and cached masks arrive as uint8 in [0, 255]; rescale them
+    # so downstream soft-thresholding always sees a continuous [0, 1] map.
+    if mask.max() > 1.0:
+        mask = mask / 255.0
+    return np.clip(mask, 0.0, 1.0)
+
+
+def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
+    mask = np.asarray(mask)
+    if mask.dtype == np.uint8:
+        return mask
+    mask = mask.astype(np.float32)
+    if mask.size > 0 and mask.max() <= 1.0:
+        mask = mask * 255.0
+    return np.clip(mask, 0.0, 255.0).astype(np.uint8)
+
+
+def _result_map_to_non_sky_conf(result_map: np.ndarray) -> np.ndarray:
+    # The raw skyseg map is higher on sky and lower on non-sky.
+    return 1.0 - _mask_to_float(result_map)
+
+
+def segment_sky_from_array(
+    image: np.ndarray,
+    skyseg_session,
+    target_h: int,
+    target_w: int
+) -> np.ndarray:
+    """
+    Segment sky from an image array using ONNX model.
+
+    Args:
+        image: Input image as numpy array (H, W, 3) or (3, H, W), values in [0, 1] or [0, 255]
+        skyseg_session: ONNX runtime inference session
+        target_h: Target output height
+        target_w: Target output width
+
+    Returns:
+        Continuous non-sky confidence map in [0, 1].
+    """
+    image_rgb = _image_to_rgb_uint8(image)
+    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
+    result_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, image_bgr)
+    result_map = cv2.resize(result_map, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
+    return _result_map_to_non_sky_conf(result_map)
+
+
+def segment_sky(
+    image_path: str,
+    skyseg_session,
+    output_path: Optional[str] = None
+) -> np.ndarray:
+    """
+    Segment sky from an image using ONNX model.
+
+    Args:
+        image_path: Path to the input image
+        skyseg_session: ONNX runtime inference session
+        output_path: Optional path to save the mask
+
+    Returns:
+        Continuous non-sky confidence map in [0, 1].
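+
+    Example:
+        # Hypothetical paths; the ONNX session is built the same way as in
+        # load_or_create_sky_masks below.
+        session = onnxruntime.InferenceSession("skyseg.onnx")
+        non_sky = segment_sky("frames/000000.png", session, "sky_masks/000000.png")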
+ """ + image = cv2.imread(image_path) + if image is None: + raise ValueError(f"Failed to read image: {image_path}") + + result_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, image) + result_map = cv2.resize(result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR) + mask = _result_map_to_non_sky_conf(result_map) + + if output_path is not None: + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + cv2.imwrite(output_path, _mask_to_uint8(mask)) + + return mask + + +def _list_image_files(image_folder: str) -> list[str]: + image_files = sorted(glob.glob(os.path.join(image_folder, "*"))) + image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"} + return [f for f in image_files if os.path.splitext(f.lower())[1] in image_extensions] + + +def _image_to_rgb_uint8(image: np.ndarray) -> np.ndarray: + if image.ndim == 3 and image.shape[0] == 3 and image.shape[-1] != 3: + image = image.transpose(1, 2, 0) + + if image.ndim != 3 or image.shape[2] != 3: + raise ValueError(f"Expected image with shape (H, W, 3) or (3, H, W), got {image.shape}") + + if image.dtype != np.uint8: + image = image.astype(np.float32) + if image.max() <= 1.0: + image = image * 255.0 + image = np.clip(image, 0.0, 255.0).astype(np.uint8) + + return image + + +def _get_mask_filename(image_paths: Optional[list[str]], index: int) -> str: + if image_paths is not None and index < len(image_paths): + return os.path.basename(image_paths[index]) + return f"frame_{index:06d}.png" + + +def _save_sky_mask_visualization( + image: np.ndarray, + sky_mask: np.ndarray, + output_path: str, +) -> None: + image_rgb = _image_to_rgb_uint8(image) + if sky_mask.shape[:2] != image_rgb.shape[:2]: + sky_mask = cv2.resize( + sky_mask, + (image_rgb.shape[1], image_rgb.shape[0]), + interpolation=cv2.INTER_NEAREST, + ) + + mask_uint8 = _mask_to_uint8(sky_mask) + mask_rgb = np.repeat(mask_uint8[..., None], 3, axis=2) + overlay = image_rgb.astype(np.float32).copy() + sky_pixels = _mask_to_float(sky_mask) <= _SKYSEG_SOFT_THRESHOLD + overlay[sky_pixels] = overlay[sky_pixels] * 0.35 + np.array([255, 64, 64], dtype=np.float32) * 0.65 + overlay = np.clip(overlay, 0.0, 255.0).astype(np.uint8) + + panel = np.concatenate([image_rgb, mask_rgb, overlay], axis=1) + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + cv2.imwrite(output_path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR)) + + +def load_or_create_sky_masks( + image_folder: Optional[str] = None, + image_paths: Optional[list[str]] = None, + images: Optional[np.ndarray] = None, + skyseg_model_path: str = "skyseg.onnx", + sky_mask_dir: Optional[str] = None, + sky_mask_visualization_dir: Optional[str] = None, + target_shape: Optional[Tuple[int, int]] = None, + num_frames: Optional[int] = None, +) -> Optional[np.ndarray]: + """ + Load cached sky masks or generate them with the ONNX model. + + Args: + image_folder: Folder containing input images. + image_paths: Optional explicit image file list, in the exact order to process. + images: Optional image array with shape (S, 3, H, W) or (S, H, W, 3). + skyseg_model_path: Path to the sky segmentation ONNX model. + sky_mask_dir: Optional directory for cached raw masks. + sky_mask_visualization_dir: Optional directory for side-by-side visualizations. + target_shape: Optional output mask shape (H, W) after resizing. + num_frames: Optional maximum number of frames to process. 
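+
+    Note:
+        Masks are cached in sky_mask_dir under the source image filename and are
+        regenerated when the cache-version marker changes, a cached file cannot
+        be read, or its size no longer matches the image.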
+ + Returns: + Sky masks with shape (S, H, W), or None if sky segmentation could not run. + """ + if onnxruntime is None: + print("Warning: onnxruntime not available, skipping sky segmentation") + return None + + if image_folder is None and image_paths is None and images is None: + print("Warning: Neither image_folder/image_paths nor images provided, skipping sky segmentation") + return None + + if not os.path.exists(skyseg_model_path): + print(f"Sky segmentation model not found at {skyseg_model_path}, downloading...") + try: + download_skyseg_model(skyseg_model_path) + except Exception as e: + print(f"Warning: Failed to download sky segmentation model: {e}") + return None + + skyseg_session = onnxruntime.InferenceSession(skyseg_model_path) + sky_masks = [] + + if sky_mask_visualization_dir is not None: + os.makedirs(sky_mask_visualization_dir, exist_ok=True) + print(f"Saving sky mask visualizations to {sky_mask_visualization_dir}") + + if images is not None: + if image_paths is None and image_folder is not None: + image_paths = _list_image_files(image_folder) + + num_images = images.shape[0] + if num_frames is not None: + num_images = min(num_images, num_frames) + if image_paths is not None: + image_paths = image_paths[:num_images] + + if sky_mask_dir is None and image_folder is not None: + sky_mask_dir = image_folder.rstrip("/") + "_sky_masks" + refresh_cache = _prepare_sky_mask_cache(sky_mask_dir) + + print("Generating sky masks from image array...") + for i in tqdm(range(num_images)): + image_rgb = _image_to_rgb_uint8(images[i]) + image_h, image_w = image_rgb.shape[:2] + image_name = _get_mask_filename(image_paths, i) + mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None + + if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath): + sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) + if sky_mask is None: + print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it") + sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w) + cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask)) + elif sky_mask.shape[:2] != (image_h, image_w): + print( + f"Cached sky mask shape {sky_mask.shape[:2]} does not match resized image " + f"shape {(image_h, image_w)} for {image_name}; regenerating it" + ) + sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w) + cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask)) + else: + sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w) + if mask_filepath is not None: + cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask)) + + if sky_mask_visualization_dir is not None: + _save_sky_mask_visualization( + image_rgb, + sky_mask, + os.path.join(sky_mask_visualization_dir, image_name), + ) + + if target_shape is not None and sky_mask.shape[:2] != target_shape: + sky_mask = cv2.resize( + sky_mask, + (target_shape[1], target_shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + + sky_masks.append(_mask_to_float(sky_mask)) + + else: + if image_paths is None and image_folder is not None: + image_paths = _list_image_files(image_folder) + + if images is None and image_paths is not None: + if len(image_paths) == 0: + print("Warning: No image files provided, skipping sky segmentation") + return None + + if num_frames is not None: + image_paths = image_paths[:num_frames] + + if sky_mask_dir is None: + if image_folder is None: + image_folder = os.path.dirname(image_paths[0]) + sky_mask_dir = image_folder.rstrip("/") + 
"_sky_masks" + refresh_cache = _prepare_sky_mask_cache(sky_mask_dir) + + print("Generating sky masks from image files...") + for image_path in tqdm(image_paths): + image_name = os.path.basename(image_path) + mask_filepath = os.path.join(sky_mask_dir, image_name) + + if not refresh_cache and os.path.exists(mask_filepath): + sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE) + if sky_mask is None: + print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it") + sky_mask = segment_sky(image_path, skyseg_session, mask_filepath) + else: + sky_mask = segment_sky(image_path, skyseg_session, mask_filepath) + + if sky_mask is None: + print(f"Warning: Failed to produce sky mask for {image_path}, skipping frame") + continue + + if sky_mask_visualization_dir is not None: + image_bgr = cv2.imread(image_path) + if image_bgr is not None: + image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + _save_sky_mask_visualization( + image_rgb, + sky_mask, + os.path.join(sky_mask_visualization_dir, image_name), + ) + + if target_shape is not None and sky_mask.shape[:2] != target_shape: + sky_mask = cv2.resize( + sky_mask, + (target_shape[1], target_shape[0]), + interpolation=cv2.INTER_LINEAR, + ) + + sky_masks.append(_mask_to_float(sky_mask)) + + if len(sky_masks) == 0: + print("Warning: No sky masks generated, skipping sky segmentation") + return None + + try: + return np.stack(sky_masks, axis=0) + except ValueError: + return np.array(sky_masks, dtype=object) + + +def apply_sky_segmentation( + conf: np.ndarray, + image_folder: Optional[str] = None, + image_paths: Optional[list[str]] = None, + images: Optional[np.ndarray] = None, + skyseg_model_path: str = "skyseg.onnx", + sky_mask_dir: Optional[str] = None, + sky_mask_visualization_dir: Optional[str] = None, +) -> np.ndarray: + """ + Apply sky segmentation to confidence scores. + + Args: + conf: Confidence scores with shape (S, H, W) + image_folder: Path to the folder containing input images (optional if images provided) + image_paths: Optional explicit image file list in processing order + images: Image array with shape (S, 3, H, W) or (S, H, W, 3) (optional if image_folder provided) + skyseg_model_path: Path to the sky segmentation ONNX model + sky_mask_dir: Optional directory for cached raw masks + sky_mask_visualization_dir: Optional directory for side-by-side mask visualization images + + Returns: + Updated confidence scores with sky regions masked out + """ + S, H, W = conf.shape + + sky_mask_array = load_or_create_sky_masks( + image_folder=image_folder, + image_paths=image_paths, + images=images, + skyseg_model_path=skyseg_model_path, + sky_mask_dir=sky_mask_dir, + sky_mask_visualization_dir=sky_mask_visualization_dir, + target_shape=(H, W), + num_frames=S, + ) + if sky_mask_array is None: + return conf + + if sky_mask_array.shape[0] < S: + print( + f"Warning: Only {sky_mask_array.shape[0]} sky masks generated for {S} frames; " + "leaving the remaining frames unmasked" + ) + padded = np.zeros((S, H, W), dtype=sky_mask_array.dtype) + padded[: sky_mask_array.shape[0]] = sky_mask_array + sky_mask_array = padded + elif sky_mask_array.shape[0] > S: + sky_mask_array = sky_mask_array[:S] + + sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32) + conf = conf * sky_mask_binary + + print("Sky segmentation applied successfully") + return conf + + +def download_skyseg_model(output_path: str = "skyseg.onnx") -> str: + """ + Download sky segmentation model from HuggingFace. 
+ + Args: + output_path: Path to save the model + + Returns: + Path to the downloaded model + """ + import requests + + url = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx" + + print(f"Downloading sky segmentation model from {url}...") + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get('content-length', 0)) + + with open(output_path, 'wb') as f: + with tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") as pbar: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + pbar.update(len(chunk)) + + print(f"Model saved to {output_path}") + return output_path diff --git a/lingbot_map/vis/utils.py b/lingbot_map/vis/utils.py new file mode 100644 index 0000000..affdeb5 --- /dev/null +++ b/lingbot_map/vis/utils.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Visualization utility functions for colorization and color bars. +""" + +import dataclasses +from typing import Optional, Tuple + +import numpy as np +import torch +import cv2 +import matplotlib.cm as cm + + +@dataclasses.dataclass +class CameraState: + """Camera state for rendering.""" + fov: float + aspect: float + c2w: np.ndarray + + def get_K(self, img_wh: Tuple[int, int]) -> np.ndarray: + """Get camera intrinsic matrix from FOV and image size.""" + W, H = img_wh + focal_length = H / 2.0 / np.tan(self.fov / 2.0) + K = np.array([ + [focal_length, 0.0, W / 2.0], + [0.0, focal_length, H / 2.0], + [0.0, 0.0, 1.0], + ]) + return K + + +def get_vertical_colorbar( + h: int, + vmin: float, + vmax: float, + cmap_name: str = "jet", + label: Optional[str] = None, + cbar_precision: int = 2 +) -> np.ndarray: + """ + Create a vertical colorbar image. 
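+    The bar is rendered off-screen with Matplotlib's Agg backend and then
+    resized to the requested height.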
+ + Args: + h: Height in pixels + vmin: Minimum value + vmax: Maximum value + cmap_name: Colormap name + label: Optional label for the colorbar + cbar_precision: Decimal precision for tick labels + + Returns: + Colorbar image as numpy array (H, W, 3) + """ + from matplotlib.figure import Figure + from matplotlib.backends.backend_agg import FigureCanvasAgg + import matplotlib as mpl + + fig = Figure(figsize=(2, 8), dpi=100) + fig.subplots_adjust(right=1.5) + canvas = FigureCanvasAgg(fig) + + ax = fig.add_subplot(111) + cmap = cm.get_cmap(cmap_name) + norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + tick_cnt = 6 + tick_loc = np.linspace(vmin, vmax, tick_cnt) + cb1 = mpl.colorbar.ColorbarBase( + ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical" + ) + + tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc] + if cbar_precision == 0: + tick_label = [x[:-2] for x in tick_label] + + cb1.set_ticklabels(tick_label) + cb1.ax.tick_params(labelsize=18, rotation=0) + if label is not None: + cb1.set_label(label) + + canvas.draw() + s, (width, height) = canvas.print_to_buffer() + + im = np.frombuffer(s, np.uint8).reshape((height, width, 4)) + im = im[:, :, :3].astype(np.float32) / 255.0 + + if h != im.shape[0]: + w = int(im.shape[1] / im.shape[0] * h) + im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA) + + return im + + +def colorize_np( + x: np.ndarray, + cmap_name: str = "jet", + mask: Optional[np.ndarray] = None, + range: Optional[Tuple[float, float]] = None, + append_cbar: bool = False, + cbar_in_image: bool = False, + cbar_precision: int = 2, +) -> np.ndarray: + """ + Turn a grayscale image into a color image. + + Args: + x: Input grayscale image [H, W] + cmap_name: Colormap name + mask: Optional mask image [H, W] + range: Value range for scaling [min, max], automatic if None + append_cbar: Whether to append colorbar + cbar_in_image: Put colorbar inside image + cbar_precision: Colorbar tick precision + + Returns: + Colorized image [H, W, 3] + """ + if range is not None: + vmin, vmax = range + elif mask is not None: + vmin = np.min(x[mask][np.nonzero(x[mask])]) + vmax = np.max(x[mask]) + x[np.logical_not(mask)] = vmin + else: + vmin, vmax = np.percentile(x, (1, 100)) + vmax += 1e-6 + + x = np.clip(x, vmin, vmax) + x = (x - vmin) / (vmax - vmin) + + cmap = cm.get_cmap(cmap_name) + x_new = cmap(x)[:, :, :3] + + if mask is not None: + mask = np.float32(mask[:, :, np.newaxis]) + x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask) + + cbar = get_vertical_colorbar( + h=x.shape[0], + vmin=vmin, + vmax=vmax, + cmap_name=cmap_name, + cbar_precision=cbar_precision, + ) + + if append_cbar: + if cbar_in_image: + x_new[:, -cbar.shape[1]:, :] = cbar + else: + x_new = np.concatenate( + (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1 + ) + return x_new + else: + return x_new + + +def colorize( + x: torch.Tensor, + cmap_name: str = "jet", + mask: Optional[torch.Tensor] = None, + range: Optional[Tuple[float, float]] = None, + append_cbar: bool = False, + cbar_in_image: bool = False +) -> torch.Tensor: + """ + Turn a grayscale image into a color image (PyTorch tensor version). 
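+    Thin wrapper around colorize_np: tensors are moved to CPU, the optional
+    mask is eroded with a 3x3 kernel, and the result is returned on the
+    original device.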
+ + Args: + x: Grayscale image tensor [H, W] or [B, H, W] + cmap_name: Colormap name + mask: Optional mask tensor [H, W] or [B, H, W] + range: Value range for scaling + append_cbar: Whether to append colorbar + cbar_in_image: Put colorbar inside image + + Returns: + Colorized tensor + """ + device = x.device + x = x.cpu().numpy() + if mask is not None: + mask = mask.cpu().numpy() > 0.99 + kernel = np.ones((3, 3), np.uint8) + + if x.ndim == 2: + x = x[None] + if mask is not None: + mask = mask[None] + + out = [] + for x_ in x: + if mask is not None: + mask = cv2.erode(mask.astype(np.uint8), kernel, iterations=1).astype(bool) + + x_ = colorize_np(x_, cmap_name, mask, range, append_cbar, cbar_in_image) + out.append(torch.from_numpy(x_).to(device).float()) + out = torch.stack(out).squeeze(0) + return out diff --git a/lingbot_map/vis/viser_wrapper.py b/lingbot_map/vis/viser_wrapper.py new file mode 100644 index 0000000..32572ea --- /dev/null +++ b/lingbot_map/vis/viser_wrapper.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +""" +Quick visualization wrapper for GCT predictions using Viser. +""" + +import time +import threading +from typing import List, Optional + +import numpy as np +import viser +import viser.transforms as tf +from tqdm.auto import tqdm + +from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map +from lingbot_map.vis.sky_segmentation import apply_sky_segmentation + + +def viser_wrapper( + pred_dict: dict, + port: int = 8080, + init_conf_threshold: float = 50.0, + use_point_map: bool = False, + background_mode: bool = False, + mask_sky: bool = False, + image_folder: Optional[str] = None, +): + """ + Visualize predicted 3D points and camera poses with viser. + + This is a simplified wrapper for quick visualization without the full + PointCloudViewer controls. 
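+    Points are recentered around the scene mean and randomly subsampled to at
+    most 6 million before being added to the viser scene.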
+ + Args: + pred_dict: Dictionary containing predictions with keys: + - images: (S, 3, H, W) - Input images + - world_points: (S, H, W, 3) + - world_points_conf: (S, H, W) + - depth: (S, H, W, 1) + - depth_conf: (S, H, W) + - extrinsic: (S, 3, 4) + - intrinsic: (S, 3, 3) + port: Port number for the viser server + init_conf_threshold: Initial percentage of low-confidence points to filter out + use_point_map: Whether to visualize world_points or use depth-based points + background_mode: Whether to run the server in background thread + mask_sky: Whether to apply sky segmentation to filter out sky points + image_folder: Path to the folder containing input images (for sky segmentation) + + Returns: + viser.ViserServer: The viser server instance + """ + print(f"Starting viser server on port {port}") + + server = viser.ViserServer(host="0.0.0.0", port=port) + server.gui.configure_theme(titlebar_content=None, control_layout="collapsible") + + # Unpack prediction dict + images = pred_dict["images"] # (S, 3, H, W) + world_points_map = pred_dict["world_points"] # (S, H, W, 3) + conf_map = pred_dict["world_points_conf"] # (S, H, W) + + depth_map = pred_dict["depth"] # (S, H, W, 1) + depth_conf = pred_dict["depth_conf"] # (S, H, W) + + extrinsics_cam = pred_dict["extrinsic"] # (S, 3, 4) + intrinsics_cam = pred_dict["intrinsic"] # (S, 3, 3) + + # Compute world points from depth if not using the precomputed point map + if not use_point_map: + world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam) + conf = depth_conf + else: + world_points = world_points_map + conf = conf_map + + # Apply sky segmentation if enabled + if mask_sky and image_folder is not None: + conf = apply_sky_segmentation(conf, image_folder) + + # Convert images from (S, 3, H, W) to (S, H, W, 3) + colors = images.transpose(0, 2, 3, 1) # now (S, H, W, 3) + shape = world_points.shape + S: int = shape[0] + H: int = shape[1] + W: int = shape[2] + + # Flatten + points = world_points.reshape(-1, 3) + colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8) + conf_flat = conf.reshape(-1) + + # Random sample points if too many + indices = None + if points.shape[0] > 6000000: + print(f"Too many points ({points.shape[0]}), randomly sampling 6M points") + indices = np.random.choice(points.shape[0], size=6000000, replace=False) + points = points[indices] + colors_flat = colors_flat[indices] + conf_flat = conf_flat[indices] + + cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam) + cam_to_world = cam_to_world_mat[:, :3, :] + + # Compute scene center and recenter + scene_center = np.mean(points, axis=0) + points_centered = points - scene_center + cam_to_world[..., -1] -= scene_center + + # Store frame indices for filtering + frame_indices = ( + np.repeat(np.arange(S), H * W)[indices] + if indices is not None + else np.repeat(np.arange(S), H * W) + ) + + # Build the viser GUI + gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True) + gui_points_conf = server.gui.add_slider( + "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold + ) + gui_frame_selector = server.gui.add_dropdown( + "Show Points from Frames", + options=["All"] + [str(i) for i in range(S)], + initial_value="All" + ) + + # Create the main point cloud + init_threshold_val = np.percentile(conf_flat, init_conf_threshold) + init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1) + point_cloud = server.scene.add_point_cloud( + name="viser_pcd", + 
points=points_centered[init_conf_mask], + colors=colors_flat[init_conf_mask], + point_size=0.0005, + point_shape="circle", + ) + + frames: List[viser.FrameHandle] = [] + frustums: List[viser.CameraFrustumHandle] = [] + + def visualize_frames(extrinsics, images_: np.ndarray) -> None: + """Add camera frames and frustums to the scene.""" + for f in frames: + f.remove() + frames.clear() + for fr in frustums: + fr.remove() + frustums.clear() + + def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None: + @frustum.on_click + def _(_) -> None: + for client in server.get_clients().values(): + client.camera.wxyz = frame.wxyz + client.camera.position = frame.position + + for img_id in tqdm(range(S)): + cam2world_3x4 = extrinsics[img_id] + T_world_camera = tf.SE3.from_matrix(cam2world_3x4) + + frame_axis = server.scene.add_frame( + f"frame_{img_id}", + wxyz=T_world_camera.rotation().wxyz, + position=T_world_camera.translation(), + axes_length=0.05, + axes_radius=0.002, + origin_radius=0.002, + ) + frames.append(frame_axis) + + img = images_[img_id] + img = (img.transpose(1, 2, 0) * 255).astype(np.uint8) + h, w = img.shape[:2] + + fy = 1.1 * h + fov = 2 * np.arctan2(h / 2, fy) + + frustum_cam = server.scene.add_camera_frustum( + f"frame_{img_id}/frustum", + fov=fov, + aspect=w / h, + scale=0.05, + image=img, + line_width=1.0 + ) + frustums.append(frustum_cam) + attach_callback(frustum_cam, frame_axis) + + def update_point_cloud() -> None: + """Update point cloud based on current GUI selections.""" + current_percentage = gui_points_conf.value + threshold_val = np.percentile(conf_flat, current_percentage) + print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%") + + conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5) + + if gui_frame_selector.value == "All": + frame_mask = np.ones_like(conf_mask, dtype=bool) + else: + selected_idx = int(gui_frame_selector.value) + frame_mask = frame_indices == selected_idx + + combined_mask = conf_mask & frame_mask + point_cloud.points = points_centered[combined_mask] + point_cloud.colors = colors_flat[combined_mask] + + @gui_points_conf.on_update + def _(_) -> None: + update_point_cloud() + + @gui_frame_selector.on_update + def _(_) -> None: + update_point_cloud() + + @gui_show_frames.on_update + def _(_) -> None: + for f in frames: + f.visible = gui_show_frames.value + for fr in frustums: + fr.visible = gui_show_frames.value + + # Add camera frames + import torch + if torch.is_tensor(cam_to_world): + cam_to_world_np = cam_to_world.cpu().numpy() + else: + cam_to_world_np = cam_to_world + visualize_frames(cam_to_world_np, images) + + print("Starting viser server...") + if background_mode: + def server_loop(): + while True: + time.sleep(0.001) + + thread = threading.Thread(target=server_loop, daemon=True) + thread.start() + else: + while True: + time.sleep(0.01) + + return server diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..29c6fb8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[project] +name = "lingbot-map" +version = "0.1.0" +description = "LingBot-Map: Geometric Context Transformer for Streaming 3D Reconstruction" +requires-python = ">= 3.10" +dependencies = [ + "Pillow", + "huggingface_hub", + "einops", + "safetensors", + "opencv-python", + "tqdm", + "scipy", + "torchvision", +] + +[project.optional-dependencies] +vis = ["viser>=0.2.23", "trimesh", "matplotlib", "onnxruntime", "requests"] +demo = ["lingbot-map[vis]"] + +[build-system] +requires = 
["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["."] +include = ["lingbot_map*"]