first commit

This commit is contained in:
LinZhuoChen
2026-04-16 09:51:30 +08:00
commit f9b3ae457a
44 changed files with 11994 additions and 0 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

12
.gitignore vendored Normal file
View File

@@ -0,0 +1,12 @@
__pycache__/
*.pyc
*.pyo
*.egg-info/
dist/
build/
*.so
.eggs/
demo_render/
CLAUDE.md
.claude/
.agents/

399
LICENSE.txt Normal file
View File

@@ -0,0 +1,399 @@
Attribution-NonCommercial 4.0 International
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More_considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial 4.0 International Public
License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial 4.0 International Public License ("Public
License"). To the extent this Public License may be interpreted as a
contract, You are granted the Licensed Rights in consideration of Your
acceptance of these terms and conditions, and the Licensor grants You
such rights in consideration of benefits the Licensor receives from
making the Licensed Material available under these terms and
conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
d. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
e. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
f. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
g. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
h. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
i. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
j. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
k. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
l. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
4. If You Share Adapted Material You produce, the Adapter's
License You apply must not prevent recipients of the Adapted
Material from complying with this Public License.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material; and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.

141
README.md Normal file
View File

@@ -0,0 +1,141 @@
<h1 align="center">LingBot-Map: Geometric Context Transformer for Streaming 3D Reconstruction</h1>
<p align="center">
<a href="lingbot-map_paper.pdf"><img src="https://img.shields.io/static/v1?label=Paper&message=PDF&color=red&logo=arxiv"></a>
<a href="https://technology.robbyant.com/lingbot-map"><img src="https://img.shields.io/badge/Project-Website-blue"></a>
<a href="https://huggingface.co/robbyant/lingbot-map"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Model&message=HuggingFace&color=orange"></a>
<a href="LICENSE.txt"><img src="https://img.shields.io/badge/License-CC%20BY--NC%204.0-green"></a>
</p>
<p align="center">
<img src="assets/teaser.png" width="100%">
</p>
<p align="center">
<video src="https://gw.alipayobjects.com/v/huamei_vaouhm/afts/video/q0sdTr9Mm6IAAAAAmyAAAAgADglFAQJr" width="100%" autoplay loop muted playsinline></video>
</p>
---
# Quick Start
## Installation
**1. Create conda environment**
```bash
conda create -n lingbot-map python=3.10 -y
conda activate lingbot-map
```
**2. Install PyTorch (CUDA 12.8)**
```bash
pip install torch==2.9.1 torchvision==0.24.1 --index-url https://download.pytorch.org/whl/cu128
```
> For other CUDA versions, see [PyTorch Get Started](https://pytorch.org/get-started/locally/).
**3. Install lingbot-map**
```bash
pip install -e .
```
**4. Install FlashInfer (recommended)**
FlashInfer provides paged KV cache attention for efficient streaming inference:
```bash
# CUDA 12.8 + PyTorch 2.9
pip install flashinfer-python -i https://flashinfer.ai/whl/cu128/torch2.9/
```
> For other CUDA/PyTorch combinations, see [FlashInfer installation](https://docs.flashinfer.ai/installation.html).
> If FlashInfer is not installed, the model falls back to SDPA (PyTorch native attention) via `--use_sdpa`.
**5. Visualization dependencies (optional)**
```bash
pip install -e ".[vis]"
```
# Demo
## Streaming Inference from Images
```bash
python demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/
```
## Streaming Inference from Video
```bash
python demo.py --model_path /path/to/checkpoint.pt \
--video_path video.mp4 --fps 10
```
## Streaming with Keyframe Interval
Use `--keyframe_interval` to reduce KV cache memory by only keeping every N-th frame as a keyframe. Non-keyframe frames still produce predictions but are not stored in the cache. This is useful for long sequences which exceed 320 frames.
```bash
python demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/ --keyframe_interval 6
```
## Windowed Inference (for long sequences, >3000 frames)
```bash
python demo.py --model_path /path/to/checkpoint.pt \
--video_path video.mp4 --fps 10 \
--mode windowed --window_size 64
```
## With Sky Masking
```bash
python demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/ --mask_sky
```
## Without FlashInfer (SDPA fallback)
```bash
python demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/ --use_sdpa
```
# Model Download
| Model Name | Huggingface Repository | Description |
| :--- | :--- | :--- |
| lingbot-map | [robbyant/lingbot-map](https://huggingface.co/robbyant/lingbot-map) | Base model checkpoint (4.63 GB) |
# License
This project is released under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license. See the [LICENSE](LICENSE.txt) file for details.
# Citation
```bibtex
@article{lingbot-map2026,
title={},
author={},
journal={arXiv preprint arXiv:},
year={2026}
}
```
# Acknowledgments
This work builds upon several excellent open-source projects:
- [VGGT](https://github.com/facebookresearch/vggt)
- [DINOv2](https://github.com/facebookresearch/dinov2)
- [Flashinfer](https://github.com/flashinfer-ai/flashinfer)
---

BIN
assets/teaser.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 MiB

346
demo.py Normal file
View File

@@ -0,0 +1,346 @@
"""LingBot-MAP demo: streaming 3D reconstruction from images or video.
Usage:
# Streaming inference (frame-by-frame with KV cache)
python examples/demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/
# Streaming inference with keyframe KV caching
python examples/demo.py --model_path /path/to/checkpoint.pt \
--image_folder /path/to/images/ --mode streaming --keyframe_interval 6
# Windowed inference (for very long sequences, >500 frames)
python examples/demo.py --model_path /path/to/checkpoint.pt \
--video_path video.mp4 --fps 10 --mode windowed --window_size 64
# From video with custom FPS sampling
python examples/demo.py --model_path /path/to/checkpoint.pt \
--video_path video.mp4 --fps 10
"""
import argparse
import glob
import os
import time
import cv2
import numpy as np
import torch
from tqdm.auto import tqdm
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
from lingbot_map.utils.geometry import closed_form_inverse_se3_general
from lingbot_map.utils.load_fn import load_and_preprocess_images
# =============================================================================
# Image loading
# =============================================================================
def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png",
                first_k=None, stride=1, image_size=518, patch_size=14, num_workers=8):
    """Load frames from an image folder or a video and preprocess them.

    Args:
        image_folder: Directory of image files (used when ``video_path`` is None).
        video_path: Path to a video; frames are extracted next to it at ~``fps``.
        fps: Target sampling rate when extracting frames from a video.
        image_ext: Comma-separated extensions matched inside ``image_folder``.
        first_k: If set, keep only the first k frames (applied after ``stride``).
        stride: Keep every ``stride``-th image from the sorted folder listing.
        image_size: Target size passed to the preprocessing routine.
        patch_size: ViT patch size the preprocessing aligns dimensions to.
        num_workers: Unused; kept only for backward signature compatibility.

    Returns:
        Tuple of (preprocessed image tensor, list of source frame paths).

    Raises:
        IOError: If the video file cannot be opened.
        ValueError: If no input frames were found.
    """
    if video_path is not None:
        video_name = os.path.splitext(os.path.basename(video_path))[0]
        out_dir = os.path.join(os.path.dirname(video_path), f"{video_name}_frames")
        os.makedirs(out_dir, exist_ok=True)
        cap = cv2.VideoCapture(video_path)
        # Fail fast instead of silently producing zero frames downstream.
        if not cap.isOpened():
            raise IOError(f"Cannot open video: {video_path}")
        # Some containers report 0 FPS; fall back to a sane default.
        src_fps = cap.get(cv2.CAP_PROP_FPS) or 30
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        interval = max(1, round(src_fps / fps))
        idx, saved = 0, []
        pbar = tqdm(total=total_frames, desc="Extracting frames", unit="frame")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % interval == 0:
                path = os.path.join(out_dir, f"{len(saved):06d}.jpg")
                cv2.imwrite(path, frame)
                saved.append(path)
            idx += 1
            pbar.update(1)
        pbar.close()
        cap.release()
        paths = saved
        print(f"Extracted {len(paths)} frames from video ({total_frames} total, interval={interval})")
    else:
        exts = image_ext.split(",")
        paths = []
        for ext in exts:
            paths.extend(glob.glob(os.path.join(image_folder, f"*{ext}")))
        paths = sorted(paths)
        if stride > 1:
            paths = paths[::stride]
        if first_k is not None and first_k > 0:
            paths = paths[:first_k]
    # Guard against an empty result (wrong folder, extension mismatch, empty video).
    if not paths:
        raise ValueError("No input frames found; check --image_folder/--video_path.")
    print(f"Loading {len(paths)} images...")
    images = load_and_preprocess_images(
        paths,
        mode="crop",
        image_size=image_size,
        patch_size=patch_size,
    )
    h, w = images.shape[-2:]
    print(f"Preprocessed images to {w}x{h} using canonical crop mode")
    return images, paths
# =============================================================================
# Model loading
# =============================================================================
def load_model(args, device):
    """Build the GCTStream model for the requested mode and restore its weights.

    The windowed and streaming variants share a class name but live in
    different modules; ``args.mode`` selects which one is imported. Returns
    the model moved to ``device`` and switched to eval mode.
    """
    inference_mode = getattr(args, "mode", "streaming")
    if inference_mode == "windowed":
        from lingbot_map.models.gct_stream_window import GCTStream
    else:
        from lingbot_map.models.gct_stream import GCTStream
    print("Building model...")
    model_kwargs = dict(
        img_size=args.image_size,
        patch_size=args.patch_size,
        enable_3d_rope=args.enable_3d_rope,
        max_frame_num=args.max_frame_num,
        kv_cache_sliding_window=args.kv_cache_sliding_window,
        kv_cache_scale_frames=args.kv_cache_scale_frames,
        kv_cache_cross_frame_special=True,
        kv_cache_include_scale_frames=True,
        use_sdpa=args.use_sdpa,
    )
    model = GCTStream(**model_kwargs)
    if args.model_path:
        print(f"Loading checkpoint: {args.model_path}")
        # NOTE(review): weights_only=False deserializes arbitrary pickled
        # objects -- only load checkpoints from trusted sources.
        ckpt = torch.load(args.model_path, map_location=device, weights_only=False)
        state_dict = ckpt.get("model", ckpt)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print(f" Missing keys: {len(missing)}")
        if unexpected:
            print(f" Unexpected keys: {len(unexpected)}")
        print(" Checkpoint loaded.")
    return model.to(device).eval()
# =============================================================================
# Post-processing
# =============================================================================
# Expected tensor rank for each prediction key while a leading batch
# dimension (size 1, single sequence) is still present.
_BATCHED_NDIMS = {
    "pose_enc": 3,
    "depth": 5,
    "depth_conf": 4,
    "world_points": 5,
    "world_points_conf": 4,
    "extrinsic": 4,
    "intrinsic": 4,
    "chunk_sim3_scales": 2,
    "chunk_sim3_poses": 4,
    "chunk_se3_poses": 4,
    "images": 5,
}


def _squeeze_single_batch(key, value):
    """Return ``value`` without its size-1 leading batch dim, when applicable.

    Values with an unknown key, without an ``ndim`` attribute, or whose rank
    or leading dimension does not match the batched layout pass through
    unchanged.
    """
    expected_ndim = _BATCHED_NDIMS.get(key)
    if expected_ndim is None or not hasattr(value, "ndim"):
        return value
    is_singleton_batch = value.ndim == expected_ndim and value.shape[0] == 1
    return value[0] if is_singleton_batch else value


def postprocess(predictions, images):
    """Derive camera matrices from the pose encoding and move results to CPU.

    Converts the predicted pose encoding into extrinsics/intrinsics, flips
    the extrinsics from world-to-camera (w2c) to camera-to-world (c2w),
    drops bulky intermediates, and transfers every tensor to the CPU.
    Note: mutates ``predictions`` in place and returns it together with a
    CPU copy of ``images``.
    """
    extrinsic, intrinsic = pose_encoding_to_extri_intri(
        predictions["pose_enc"], images.shape[-2:]
    )
    # Promote the 3x4 w2c extrinsics to homogeneous 4x4 so they can be inverted.
    homo = torch.zeros(
        (*extrinsic.shape[:-2], 4, 4), device=extrinsic.device, dtype=extrinsic.dtype
    )
    homo[..., :3, :4] = extrinsic
    homo[..., 3, 3] = 1.0
    # Invert w2c -> c2w via the closed-form SE(3) inverse, then drop back to 3x4.
    predictions["extrinsic"] = closed_form_inverse_se3_general(homo)[..., :3, :4]
    predictions["intrinsic"] = intrinsic
    for obsolete in ("pose_enc_list", "images"):
        predictions.pop(obsolete, None)
    print("Moving results to CPU...")
    for key, value in list(predictions.items()):
        if isinstance(value, torch.Tensor):
            predictions[key] = _squeeze_single_batch(
                key, value.to("cpu", non_blocking=True)
            )
    images_cpu = images.to("cpu", non_blocking=True)
    if torch.cuda.is_available():
        # non_blocking transfers must finish before the tensors are consumed.
        torch.cuda.synchronize()
    return predictions, images_cpu


def prepare_for_visualization(predictions, images=None):
    """Convert predictions into the unbatched NumPy dict used by the viewer.

    Tensors are detached, moved to CPU, squeezed, and converted to NumPy;
    non-array values pass through untouched. Frame images come from
    ``images`` when provided, otherwise from ``predictions["images"]``.
    """
    vis = {}
    for key, value in predictions.items():
        if isinstance(value, torch.Tensor):
            vis[key] = _squeeze_single_batch(key, value.detach().cpu()).numpy()
        elif isinstance(value, np.ndarray):
            vis[key] = _squeeze_single_batch(key, value)
        else:
            vis[key] = value
    frames = images
    if frames is None:
        frames = predictions.get("images")
        if isinstance(frames, torch.Tensor):
            frames = frames.detach().cpu()
    if isinstance(frames, np.ndarray):
        frames = _squeeze_single_batch("images", frames)
    elif isinstance(frames, torch.Tensor):
        frames = _squeeze_single_batch("images", frames).numpy()
    if isinstance(frames, torch.Tensor):  # defensive: unreachable after the branch above
        frames = frames.numpy()
    if frames is not None:
        vis["images"] = frames
    return vis
# =============================================================================
# Main
# =============================================================================
def main():
    """CLI entry point: load inputs, run streaming/windowed inference, visualize.

    Parses arguments, loads frames and the model, runs inference under
    autocast (on CUDA), post-processes the predictions, and launches the
    point-cloud viewer when the optional vis dependencies are installed.
    """
    import contextlib  # local import: only needed for the CPU fallback below

    parser = argparse.ArgumentParser(description="LingBot-MAP: Streaming 3D Reconstruction Demo")
    # Input
    parser.add_argument("--image_folder", type=str, default=None)
    parser.add_argument("--video_path", type=str, default=None)
    parser.add_argument("--fps", type=int, default=10)
    parser.add_argument("--first_k", type=int, default=None)
    parser.add_argument("--stride", type=int, default=1)
    # Model
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--image_size", type=int, default=518)
    parser.add_argument("--patch_size", type=int, default=14)
    # Inference mode
    parser.add_argument("--mode", type=str, default="streaming", choices=["streaming", "windowed"],
                        help="streaming: frame-by-frame with KV cache; windowed: overlapping windows for long sequences")
    # Streaming options
    # NOTE(review): store_true with default=True can never be disabled from
    # the CLI -- confirm whether a disable flag is intended.
    parser.add_argument("--enable_3d_rope", action="store_true", default=True)
    parser.add_argument("--max_frame_num", type=int, default=1024)
    parser.add_argument("--num_scale_frames", type=int, default=8)
    parser.add_argument(
        "--keyframe_interval",
        type=int,
        default=1,
        help="Streaming only. Every N-th frame after scale frames is kept as a keyframe. 1 = every frame.",
    )
    parser.add_argument("--kv_cache_sliding_window", type=int, default=64)
    parser.add_argument("--kv_cache_scale_frames", type=int, default=8)
    parser.add_argument("--use_sdpa", action="store_true", default=False,
                        help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")
    # Windowed options
    parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)")
    parser.add_argument("--overlap_size", type=int, default=16, help="Overlap between windows")
    parser.add_argument("--sim3", action="store_true", default=True, help="Use Sim(3) alignment between windows")
    parser.add_argument("--no_sim3", dest="sim3", action="store_false", help="Disable Sim(3), use SE(3) instead")
    # Visualization
    parser.add_argument("--port", type=int, default=8080)
    parser.add_argument("--conf_threshold", type=float, default=1.0)
    parser.add_argument("--downsample_factor", type=int, default=10)
    parser.add_argument("--point_size", type=float, default=0.005)
    parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")
    args = parser.parse_args()
    assert args.image_folder or args.video_path, \
        "Provide --image_folder or --video_path"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # -- Load images & model -------------------------------------------------
    t0 = time.time()
    images, paths = load_images(
        image_folder=args.image_folder, video_path=args.video_path,
        fps=args.fps, first_k=args.first_k, stride=args.stride,
        image_size=args.image_size, patch_size=args.patch_size,
    )
    model = load_model(args, device)
    print(f"Total load time: {time.time() - t0:.1f}s")
    images = images.to(device)
    num_frames = images.shape[0]
    print(f"Input: {num_frames} frames, shape {tuple(images.shape)}")
    print(f"Mode: {args.mode}")
    if args.mode != "streaming" and args.keyframe_interval != 1:
        print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.")
        args.keyframe_interval = 1
    elif args.mode == "streaming" and args.keyframe_interval > 1:
        print(
            f"Keyframe streaming enabled: interval={args.keyframe_interval} "
            f"(after the first {args.num_scale_frames} scale frames)."
        )
    # -- Inference -----------------------------------------------------------
    # Fix: only query CUDA device capability when CUDA is available; the
    # previous unconditional call crashed on CPU-only machines even though
    # `device` already falls back to CPU above.
    if device.type == "cuda":
        # bf16 on Ampere+ (compute capability >= 8), fp16 otherwise.
        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
        amp_ctx = torch.amp.autocast("cuda", dtype=dtype)
    else:
        dtype = torch.float32
        amp_ctx = contextlib.nullcontext()
    print(f"Running {args.mode} inference (dtype={dtype})...")
    t0 = time.time()
    with torch.no_grad(), amp_ctx:
        if args.mode == "streaming":
            predictions = model.inference_streaming(
                images,
                num_scale_frames=args.num_scale_frames,
                keyframe_interval=args.keyframe_interval,
            )
        else:  # windowed
            predictions = model.inference_windowed(
                images,
                window_size=args.window_size,
                overlap_size=args.overlap_size,
                num_scale_frames=args.num_scale_frames,
                sim3=args.sim3,
                se3=not args.sim3,
            )
    t_infer = time.time() - t0
    print(f"Inference done: {t_infer:.1f}s ({num_frames / t_infer:.1f} FPS)")
    # -- Post-process --------------------------------------------------------
    predictions, images_cpu = postprocess(predictions, images)
    # -- Visualize -----------------------------------------------------------
    try:
        from lingbot_map.vis import PointCloudViewer
        viewer = PointCloudViewer(
            pred_dict=prepare_for_visualization(predictions, images_cpu),
            port=args.port,
            init_conf_threshold=args.conf_threshold,
            downsample_factor=args.downsample_factor,
            point_size=args.point_size,
            mask_sky=args.mask_sky,
        )
        print(f"3D viewer at http://localhost:{args.port}")
        viewer.run()
    except ImportError:
        print("viser not installed. Install with: pip install lingbot-map[vis]")
        print(f"Predictions contain keys: {list(predictions.keys())}")


if __name__ == "__main__":
    main()

BIN
docs/.DS_Store vendored Normal file

Binary file not shown.

BIN
lingbot-map_paper.pdf Normal file

Binary file not shown.

BIN
lingbot_map/.DS_Store vendored Normal file

Binary file not shown.

0
lingbot_map/__init__.py Normal file
View File

View File

@@ -0,0 +1,2 @@
from .stream import AggregatorStream
from .base import AggregatorBase

View File

@@ -0,0 +1,608 @@
"""
AggregatorBase - Base class for all Aggregator implementations.
Provides shared functionality:
- Patch embedding (DINOv2)
- Special tokens (camera, register, scale)
- Block building
- Common forward pass structure
Subclasses implement mode-specific attention logic.
"""
import logging
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional, Tuple, List
from lingbot_map.layers import PatchEmbed
from lingbot_map.layers.block import Block
from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter
from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
logger = logging.getLogger(__name__)
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
def slice_expand_and_flatten(token, B, S, first_num_frame=1):
    """
    Slice, expand and flatten special-token templates.

    The template stores two variants of each token: index 0 along dim 1 is
    used for the leading ``first_num_frame`` frames, index 1 for every
    remaining frame.

    Args:
        token: Token tensor [1, 2, N, C] (dim 1 selects the first/rest variant)
        B: Batch size
        S: Sequence length (number of frames)
        first_num_frame: Number of leading frames that use the first variant

    Returns:
        Flattened tokens [B*S, N, C]
    """
    # token shape: [1, 2, N, C] -> expand to [B, S, N, C].
    # NOTE: the previous implementation special-cased first_num_frame == 1
    # with a duplicated branch; the general expression below covers that
    # case identically (expand(B, 1, ...) / expand(B, S - 1, ...)).
    token_first = token[:, :1].expand(B, first_num_frame, -1, -1)      # [B, first_num_frame, N, C]
    token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1)   # [B, S-first_num_frame, N, C]
    token_expanded = torch.cat([token_first, token_rest], dim=1)       # [B, S, N, C]
    # Flatten to [B*S, N, C]
    return token_expanded.reshape(B * S, -1, token.shape[-1])
class AggregatorBase(nn.Module, ABC):
    """
    Base class for all Aggregator implementations.

    Handles shared components:
    - Patch embedding (DINOv2 or conv)
    - Special tokens (camera, register, optionally scale)
    - Block creation (frame + global)
    - RoPE (2D rotary position embeddings)
    - Common forward pass scaffolding

    Subclasses must implement:
    - _build_blocks(): create self.frame_blocks / self.global_blocks
    - _setup_special_tokens(): create tokens and self.patch_start_idx
    - _prepare_special_tokens(): expand special tokens for a batch
    - _process_global_attention(): mode-specific cross-frame attention logic
    """

    def __init__(
        self,
        # Architecture parameters
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        # Block configuration
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        qk_norm=True,
        init_values=0.01,
        # Patch embedding
        patch_embed="dinov2_vitl14_reg",
        pretrained_path=None,
        # Attention pattern
        aa_order=("frame", "global"),  # tuple default: avoids the shared-mutable-default pitfall
        aa_block_size=1,
        # RoPE
        rope_freq=100,
        disable_global_rope=False,
        # Gradient checkpointing
        use_reentrant: bool = False,
        use_gradient_checkpoint: bool = True,
    ):
        super().__init__()
        # Store configuration
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.num_register_tokens = num_register_tokens
        self.aa_order = aa_order
        self.aa_block_size = aa_block_size
        self.disable_global_rope = disable_global_rope
        self.use_reentrant = use_reentrant
        self.use_gradient_checkpoint = use_gradient_checkpoint
        self.pretrained_path = pretrained_path
        self.enable_ulysses_cp = False  # CP disabled
        print("pretrained_path:", self.pretrained_path)
        # Validate depth: alternating-attention groups must tile the depth exactly
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
        self.aa_block_num = self.depth // self.aa_block_size
        # Build patch embedding
        self._build_patch_embed(
            patch_embed=patch_embed,
            img_size=img_size,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            embed_dim=embed_dim,
            pretrained_path=pretrained_path
        )
        # Initialize RoPE (disabled entirely when rope_freq <= 0)
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None
        # Build blocks (frame + global) — implemented by the subclass
        self._build_blocks(
            block_fn=block_fn,
            depth=depth,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            init_values=init_values,
            qk_norm=qk_norm,
        )
        # Setup special tokens (camera, register, optionally scale)
        self._setup_special_tokens()
        # Register normalization constants (not persisted in checkpoints)
        for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
            self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)
        # Initialize blocks from the DINO checkpoint if one was loaded
        if hasattr(self, '_dino_checkpoint') and self._dino_checkpoint is not None:
            self._init_blocks_from_dino(self._dino_checkpoint)
            del self._dino_checkpoint  # Free memory

    def _build_patch_embed(
        self,
        patch_embed: str,
        img_size: int,
        patch_size: int,
        num_register_tokens: int,
        embed_dim: int,
        pretrained_path: str,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
    ):
        """
        Build patch embedding layer.

        Supports:
        - "conv": Simple convolutional patch embedding
        - "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2)
        """
        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim
            )
            self._dino_checkpoint = None
        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }
            if patch_embed not in vit_models:
                raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}")
            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )
            # Load pretrained weights (best-effort: failure falls back to random init)
            try:
                # map_location="cpu" so GPU-saved checkpoints also load on CPU-only hosts;
                # weights are moved to the right device when the module is.
                ckpt = torch.load(pretrained_path, map_location="cpu")
                # pos_embed is resolution-dependent; drop it if present rather than
                # raising KeyError (which would silently skip ALL pretrained loading).
                ckpt.pop('pos_embed', None)
                logger.info("Loading pretrained weights for DINOv2")
                missing, unexpected = self.patch_embed.load_state_dict(ckpt, strict=False)
                logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
                # Store checkpoint for block initialization
                self._dino_checkpoint = ckpt
            except Exception as e:
                logger.warning(f"Failed to load pretrained weights: {e}")
                self._dino_checkpoint = None
        # Disable gradients for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    @abstractmethod
    def _build_blocks(
        self,
        block_fn,
        depth: int,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        proj_bias: bool,
        ffn_bias: bool,
        init_values: float,
        qk_norm: bool,
    ):
        """
        Build frame_blocks and global_blocks.
        Subclasses implement mode-specific block creation.

        Must create:
        - self.frame_blocks: nn.ModuleList of frame attention blocks
        - self.global_blocks: nn.ModuleList of global attention blocks
        """
        pass

    @abstractmethod
    def _setup_special_tokens(self):
        """
        Setup camera token, register tokens, and optionally scale token.
        Subclasses implement mode-specific token initialization.

        Must create:
        - self.camera_token
        - self.register_token
        - self.scale_token (optional, for causal mode)
        - self.patch_start_idx
        - self.num_special_tokens
        """
        pass

    @staticmethod
    def _extract_block_state(dino_ckpt: dict, block_idx: int) -> dict:
        """Return the state dict of DINO block `block_idx` with its prefix stripped."""
        prefix = f'blocks.{block_idx}.'
        return {k[len(prefix):]: v for k, v in dino_ckpt.items() if k.startswith(prefix)}

    def _init_blocks_from_dino(self, dino_ckpt: dict):
        """
        Initialize frame_blocks and global_blocks from DINOv2 pretrained weights.

        DINO blocks are reused cyclically (i % num_dino_blocks) when the
        aggregator has more blocks than the checkpoint provides.

        Args:
            dino_ckpt: Checkpoint dictionary from DINOv2 model
        """
        logger.info("Initializing blocks from DINOv2 pretrained weights")
        # Extract block keys
        dino_block_keys = [k for k in dino_ckpt.keys() if k.startswith('blocks.')]
        if not dino_block_keys:
            logger.warning("No 'blocks' found in DINO checkpoint")
            return
        # Get block indices
        block_indices = set()
        for key in dino_block_keys:
            parts = key.split('.')
            if len(parts) > 1 and parts[1].isdigit():
                block_indices.add(int(parts[1]))
        num_dino_blocks = len(block_indices)
        print(f"Found {num_dino_blocks} blocks in DINO checkpoint")
        # Initialize both block lists with identical logic (previously duplicated)
        for label, blocks in (("Frame", self.frame_blocks), ("Global", self.global_blocks)):
            for i, block in enumerate(blocks):
                block_state_dict = self._extract_block_state(dino_ckpt, i % num_dino_blocks)
                if block_state_dict:
                    missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
                    if i == 0:  # Only log for first block to avoid spam
                        print(f"{label} block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
        logger.info("Successfully initialized blocks from DINOv2 weights")

    def _embed_images(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
    ) -> Tuple[torch.Tensor, int, int, int, int, int]:
        """
        Embed images and prepare for attention processing.

        Handles:
        - Image normalization
        - Patch embedding
        - Special token concatenation
        - Position embedding

        Args:
            images: Input images [B, S, 3, H, W] in range [0, 1]
            num_frame_for_scale: Number of frames for scale estimation (passed to special tokens)

        Returns:
            (tokens, B, S, S, P, C):
                tokens: Embedded tokens [B*S, P, C]
                B: Batch size
                S: Sequence length
                S: Same as above (no CP slicing)
                P: Number of tokens per frame (patches + special tokens)
                C: Embedding dimension
        """
        B, S, C_in, H, W = images.shape
        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")
        # Normalize images with ImageNet statistics
        images = (images - self._resnet_mean) / self._resnet_std
        # No CP slicing: S_local == S_global
        S_local = S
        S_global = S
        # Reshape for patch embedding [B*S, C, H, W]
        images = images.view(B * S, C_in, H, W)
        # Patch embedding (DINOv2 returns a dict; conv returns the tensor directly)
        patch_tokens = self.patch_embed(images)
        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]
        _, P_patch, C = patch_tokens.shape
        # Prepare special tokens
        special_tokens = self._prepare_special_tokens(
            B, S_local, S_global, C,
            num_frame_for_scale=num_frame_for_scale
        )
        # Concatenate special tokens + patch tokens
        tokens = torch.cat([special_tokens, patch_tokens], dim=1)
        _, P, C = tokens.shape
        return tokens, B, S_local, S_global, P, C

    @abstractmethod
    def _prepare_special_tokens(self, B: int, S_local: int, S_global: int, C: int, **kwargs) -> torch.Tensor:
        """
        Prepare special tokens (camera, register, optionally scale).
        Subclasses implement mode-specific token preparation.

        Args:
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            C: Embedding dimension
            **kwargs: Mode-specific parameters (e.g., num_frame_for_scale for causal mode)

        Returns:
            Special tokens [B*S, N_special, C]
        """
        pass

    def _get_positions(self, B: int, S: int, H: int, W: int, device) -> Optional[torch.Tensor]:
        """
        Get 2D position embeddings for RoPE.

        Args:
            B: Batch size
            S: Sequence length
            H: Image height
            W: Image width
            device: Device to create positions on

        Returns:
            Position tensor [B*S, P, 2] or None if rope is disabled
        """
        if self.rope is None:
            return None
        # Get patch positions
        pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
        # Add offset for patch tokens (special tokens sit at pos=0)
        if self.patch_start_idx > 0:
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2, dtype=pos.dtype, device=device)
            pos = torch.cat([pos_special, pos], dim=1)
        return pos

    def _process_frame_attention(
        self,
        tokens: torch.Tensor,
        B: int,
        S: int,
        P: int,
        C: int,
        frame_idx: int,
        pos: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process frame attention blocks.

        Frame attention operates independently per frame (no cross-frame communication).
        Tokens stay in shape [B*S, P, C].

        Args:
            tokens: Input tokens [B*S, P, C]
            B: Batch size
            S: Sequence length
            P: Tokens per frame
            C: Embedding dimension
            frame_idx: Current frame block index
            pos: Position embeddings [B*S, P, 2]

        Returns:
            (tokens, frame_idx, intermediates):
                tokens: Output tokens [B*S, P, C]
                frame_idx: Updated frame block index
                intermediates: List of intermediate outputs [B, S, P, C]
        """
        # Ensure correct shape
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B * S, P, C)
        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B * S, P, 2)
        intermediates = []
        # Hoisted out of the loop (was imported once per block)
        from torch.utils.checkpoint import checkpoint
        # Process blocks
        for i in range(self.aa_block_size):
            if self.training and self.use_gradient_checkpoint:
                tokens = checkpoint(
                    self.frame_blocks[frame_idx],
                    tokens,
                    pos,
                    False,  # enable_ulysses_cp (always False)
                    use_reentrant=self.use_reentrant
                )
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos, enable_ulysses_cp=False)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))
        return tokens, frame_idx, intermediates

    @abstractmethod
    def _process_global_attention(
        self,
        tokens: torch.Tensor,
        B: int,
        S_local: int,
        S_global: int,
        P: int,
        C: int,
        global_idx: int,
        pos: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process global (cross-frame) attention blocks.
        Subclasses implement mode-specific attention logic.

        Args:
            tokens: Input tokens
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            P: Tokens per frame
            C: Embedding dimension
            global_idx: Current global block index
            pos: Position embeddings
            **kwargs: Mode-specific parameters

        Returns:
            (tokens, global_idx, intermediates):
                tokens: Output tokens
                global_idx: Updated global block index
                intermediates: List of intermediate outputs
        """
        pass

    def forward(
        self,
        images: torch.Tensor,
        selected_idx: Optional[List[int]] = None,
        # Mode-specific parameters
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Forward pass.

        Args:
            images: Input images [B, S, 3, H, W] in range [0, 1]
            selected_idx: Which block indices to output (None = all)
            num_frame_for_scale: Number of frames for scale estimation (causal mode)
            sliding_window_size: Sliding window size in blocks (causal mode)
            num_frame_per_block: Number of frames per processing block (causal mode)

        Returns:
            (output_list, patch_start_idx):
                output_list: List of block outputs [B, S, P, 2C]
                patch_start_idx: Index where patch tokens start
        """
        B, S_input, _, H, W = images.shape
        # Embed images
        tokens, B, S_local, S_global, P, C = self._embed_images(
            images,
            num_frame_for_scale=num_frame_for_scale,
        )
        # Get position embeddings
        pos_local = self._get_positions(B, S_local, H, W, device=images.device)
        pos_global = self._get_positions(B, S_global, H, W, device=images.device)
        # Alternating attention
        frame_idx = 0
        global_idx = 0
        output_list = []
        for block_group_idx in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S_local, P, C, frame_idx, pos=pos_local
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S_local, S_global, P, C, global_idx,
                        pos=pos_global,
                        num_frame_for_scale=num_frame_for_scale,
                        sliding_window_size=sliding_window_size,
                        num_frame_per_block=num_frame_per_block,
                        image_height=H,
                        image_width=W,
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")
            # Collect outputs (assumes aa_order produced both frame and global intermediates)
            if selected_idx is None or block_group_idx in selected_idx:
                for i in range(len(frame_intermediates)):
                    # Concatenate frame and global intermediates [B, S, P, 2C]
                    concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                    output_list.append(concat_inter)
        return output_list, self.patch_start_idx

View File

@@ -0,0 +1,531 @@
"""
AggregatorStream - Streaming causal aggregator with FlashInfer KV cache.
Provides:
- Temporal causal attention
- Sliding window support
- Scale token for scale estimation frames
- Streaming inference with FlashInfer paged KV cache
"""
import logging
import torch
import torch.nn as nn
from typing import Optional, Tuple, List
from lingbot_map.layers.block import Block, FlashInferBlock, SDPABlock
from lingbot_map.layers.rope import WanRotaryPosEmbed
from lingbot_map.aggregator.base import AggregatorBase, slice_expand_and_flatten
logger = logging.getLogger(__name__)
class AggregatorStream(AggregatorBase):
    """
    Streaming causal aggregator with FlashInfer paged KV cache.

    Features:
    - Temporal causal attention (each frame only attends to past frames)
    - Sliding window support to limit attention scope
    - Scale token for scale estimation frames
    - Streaming inference with FlashInfer KV cache
    """

    def __init__(
        self,
        # Causal-specific parameters
        sliding_window_size: int = -1,
        num_frame_for_scale: int = 1,
        num_random_frames: int = 0,
        attend_to_special_tokens: bool = False,
        attend_to_scale_frames: bool = False,
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # Base class parameters via **kwargs
        **kwargs
    ):
        """
        Initialize AggregatorStream.

        Args:
            sliding_window_size: Sliding window size in blocks (-1 for full causal)
            num_frame_for_scale: Number of scale estimation frames
            num_random_frames: Number of random frames for long-range dependencies
            attend_to_special_tokens: Enable cross-frame special token attention
            attend_to_scale_frames: Include scale frames in attention
            enable_3d_rope: Enable 3D RoPE for temporal dimension in KV cache
            max_frame_num: Maximum number of frames for 3D RoPE
            kv_cache_sliding_window: Sliding window size for KV cache eviction
            kv_cache_scale_frames: Number of scale frames to keep in KV cache
            kv_cache_cross_frame_special: Keep special tokens from evicted frames
            kv_cache_include_scale_frames: Include scale frames in KV cache
            kv_cache_camera_only: Only keep camera tokens from evicted frames
            **kwargs: Base class parameters
        """
        # NOTE: these attributes must be set BEFORE super().__init__(), because
        # the base constructor calls _build_blocks()/_setup_special_tokens(),
        # which read them.
        self.sliding_window_size = sliding_window_size
        self.num_frame_for_scale = num_frame_for_scale
        self.num_random_frames = num_random_frames
        self.attend_to_special_tokens = attend_to_special_tokens
        self.attend_to_scale_frames = attend_to_scale_frames
        self.enable_3d_rope = enable_3d_rope
        self.max_frame_num = max_frame_num
        # KV cache parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
        # Pop kwargs that are passed but not needed by base class
        kwargs.pop('enable_stream_inference', None)
        use_flashinfer = kwargs.pop('use_flashinfer', True)
        kwargs.pop('use_flexflash', None)
        use_sdpa = kwargs.pop('use_sdpa', False)
        # Backend selection: SDPA (no extra deps) or FlashInfer (paged KV cache)
        self.use_sdpa = use_sdpa
        self.use_flashinfer = not use_sdpa  # FlashInfer is default unless SDPA requested
        # Call parent __init__
        super().__init__(**kwargs)
        # Initialize KV cache
        self._init_kv_cache()
        # Initialize 3D RoPE if enabled
        if self.enable_3d_rope:
            self._init_3d_rope()

    def _build_blocks(
        self,
        block_fn,
        depth: int,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        proj_bias: bool,
        ffn_bias: bool,
        init_values: float,
        qk_norm: bool,
    ):
        """Build frame and global blocks for streaming causal mode."""
        block_params = dict(
            dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            init_values=init_values,
            qk_norm=qk_norm,
        )
        # Frame blocks: Standard Block + RoPE
        self.frame_blocks = nn.ModuleList([
            block_fn(**block_params, rope=self.rope)
            for _ in range(depth)
        ])
        # Global blocks: FlashInferBlock (default) or SDPABlock (fallback)
        GlobalBlockCls = SDPABlock if self.use_sdpa else FlashInferBlock
        self.global_blocks = nn.ModuleList([
            GlobalBlockCls(
                **block_params,
                rope=self.rope if not self.disable_global_rope else None,
                kv_cache_sliding_window=self.kv_cache_sliding_window,
                kv_cache_scale_frames=self.kv_cache_scale_frames,
                kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
                kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
                kv_cache_camera_only=self.kv_cache_camera_only,
            )
            for _ in range(depth)
        ])

    def _setup_special_tokens(self):
        """Setup camera, register, and scale tokens for causal mode."""
        # Camera token
        self.camera_token = nn.Parameter(
            torch.randn(1, 2, 1, self.embed_dim)
        )
        # Register tokens (only created when requested)
        if self.num_register_tokens > 0:
            self.register_token = nn.Parameter(
                torch.randn(1, 2, self.num_register_tokens, self.embed_dim)
            )
        # Scale token (causal mode specific).  randn like the others: the
        # values are overwritten by the normal_ init below anyway.
        self.scale_token = nn.Parameter(
            torch.randn(1, 2, 1, self.embed_dim)
        )
        # Initialize
        nn.init.normal_(self.camera_token, std=1e-6)
        if self.num_register_tokens > 0:
            nn.init.normal_(self.register_token, std=1e-6)
        nn.init.normal_(self.scale_token, std=1e-6)
        # Token indexing (includes scale token)
        self.patch_start_idx = 1 + self.num_register_tokens + 1  # camera + register + scale
        self.num_special_tokens = 1 + self.num_register_tokens + 1

    def _init_kv_cache(self):
        """Initialize KV cache for streaming inference."""
        self.kv_cache_manager = None  # FlashInfer (lazy-initialized)
        self.kv_cache = {}  # Dict-based cache for SDPA
        self.total_frames_processed = 0
        self._cached_pos3d = None
        if self.use_sdpa:
            # Dict-based KV cache for SDPA
            if hasattr(self, 'depth'):
                for i in range(self.depth):
                    self.kv_cache[f"k_{i}"] = None
                    self.kv_cache[f"v_{i}"] = None
                    self.kv_cache[f"k_{i}_special"] = None
                    self.kv_cache[f"v_{i}_special"] = None
                logger.info(f"SDPA KV cache initialized with {self.depth} blocks")
        else:
            logger.info("FlashInfer KV cache will be lazily initialized on first forward")

    def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None):
        """Lazily initialize FlashInferKVCacheManager on first use.

        Args:
            device: Device for cache tensors.
            dtype: Data type for cache tensors.
            tokens_per_frame: Actual number of tokens per frame (patches + specials).
                If None, falls back to assuming square images of self.img_size.
        """
        if self.kv_cache_manager is None:
            from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager
            num_heads = self.embed_dim // 64  # head_dim = 64 for ViT-L
            head_dim = 64
            if tokens_per_frame is None:
                tokens_per_frame = (self.img_size // self.patch_size) ** 2 + self.num_special_tokens
            # max_num_frames: scale + window + headroom
            max_num_frames = self.kv_cache_scale_frames + self.kv_cache_sliding_window + 16
            self.kv_cache_manager = FlashInferKVCacheManager(
                num_blocks=self.depth,
                max_num_frames=max_num_frames,
                tokens_per_frame=tokens_per_frame,
                num_heads=num_heads,
                head_dim=head_dim,
                dtype=dtype,
                device=device,
                num_special_tokens=self.num_special_tokens,
                scale_frames=self.kv_cache_scale_frames,
                sliding_window=self.kv_cache_sliding_window,
                max_total_frames=self.max_frame_num + 100,
                force_fp32=getattr(self, 'kv_cache_force_fp32', False),
                fa3=getattr(self, 'kv_cache_fa3', False),
            )
            logger.info(
                f"FlashInfer KV cache manager initialized: {self.depth} blocks, "
                f"max_frames={max_num_frames}, tokens_per_frame={tokens_per_frame}"
            )
        return self.kv_cache_manager

    def clean_kv_cache(self):
        """Clean KV cache (call this when starting a new sequence)."""
        if self.kv_cache_manager is not None:
            self.kv_cache_manager.reset()
        if self.kv_cache:
            for key in list(self.kv_cache.keys()):
                if key == "_skip_append":
                    self.kv_cache[key] = False
                else:
                    self.kv_cache[key] = None
        self.total_frames_processed = 0
        self._cached_pos3d = None
        logger.info("KV cache cleaned")

    def _init_3d_rope(self):
        """Initialize 3D RoPE for streaming inference."""
        if not self.enable_3d_rope:
            self.rope3d = None
            return
        num_heads = 16
        head_dim = self.embed_dim // num_heads
        self.rope3d = WanRotaryPosEmbed(
            attention_head_dim=head_dim,
            patch_size=(1, self.patch_size, self.patch_size),
            max_seq_len=self.max_frame_num,
        )
        logger.info(f"3D RoPE initialized for max {self.max_frame_num} frames, head_dim={head_dim}")

    def _get_3d_positions_streaming(self, num_frames, H, W, device, f_start, f_end):
        """
        Generate 3D RoPE positions for streaming mode with correct global frame indices.

        Args:
            num_frames: Number of frames in current batch
            H, W: Image height and width
            device: Device to create positions on
            f_start: Global start frame index
            f_end: Global end frame index

        Returns:
            pos3d: [1, 1, num_frames * P, head_dim//2] complex tensor
        """
        if self.rope3d is None:
            return None
        pph = H // self.patch_size
        ppw = W // self.patch_size
        pos3d = self.rope3d(
            ppf=num_frames,
            pph=pph,
            ppw=ppw,
            patch_start_idx=self.num_special_tokens,
            device=device,
            f_start=f_start,
            f_end=f_end
        )
        return pos3d

    def _prepare_special_tokens(
        self,
        B: int,
        S_local: int,
        S_global: int,
        C: int,
        num_frame_for_scale: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Prepare camera, register, and scale tokens.

        Args:
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            C: Embedding dimension
            num_frame_for_scale: Number of frames for scale estimation

        Returns:
            Special tokens [B*S_global, N_special, C]
        """
        # Get effective num_frame_for_scale
        scale_frames = self.num_frame_for_scale if num_frame_for_scale is None else num_frame_for_scale
        # Check cache state for both backends
        has_flashinfer_cache = self.kv_cache_manager is not None and self.kv_cache_manager.num_frames > 0
        has_sdpa_cache = self.kv_cache is not None and self.kv_cache.get("k_0") is not None
        # Determine if we're in causal inference mode based on KV cache state
        causal_inference = True
        if causal_inference and has_flashinfer_cache:
            S_cached = self.kv_cache_manager.num_frames
            S_true = S_cached + S_global
        elif causal_inference and has_sdpa_cache:
            _, _, S_cached, _, _ = self.kv_cache["k_0"].shape
            S_true = S_cached + S_global
        else:
            S_true = S_global
        # Expand tokens based on mode
        if causal_inference and S_true > S_global:
            # Streaming mode: expand as if all S_true frames were present, then
            # slice off the trailing S_global rows (the frames of this call).
            effective_scale_frames = min(scale_frames, S_true)

            def _expand(token, first_num_frame=1):
                full = slice_expand_and_flatten(token, B, S_true, first_num_frame=first_num_frame)
                return full[-S_global:, :, :]
        else:
            # Batch mode or first inference: expand directly
            effective_scale_frames = min(scale_frames, S_global)

            def _expand(token, first_num_frame=1):
                return slice_expand_and_flatten(token, B, S_global, first_num_frame=first_num_frame)

        # BUGFIX: register_token only exists when num_register_tokens > 0;
        # previously it was referenced unconditionally and would raise
        # AttributeError for num_register_tokens == 0.
        parts = [_expand(self.camera_token)]
        if self.num_register_tokens > 0:
            parts.append(_expand(self.register_token))
        parts.append(_expand(self.scale_token, first_num_frame=effective_scale_frames))
        special_tokens = torch.cat(parts, dim=1)
        # Verify shape
        expected_shape = (B * S_global, self.num_special_tokens, C)
        assert special_tokens.shape == expected_shape, \
            f"Expected {expected_shape}, got {special_tokens.shape}"
        return special_tokens

    def _process_global_attention(
        self,
        tokens: torch.Tensor,
        B: int,
        S_local: int,
        S_global: int,
        P: int,
        C: int,
        global_idx: int,
        pos: Optional[torch.Tensor] = None,
        # Mode-specific parameters
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        **kwargs,
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process causal global attention via FlashInfer streaming path.

        Args:
            tokens: Input tokens
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            P: Tokens per frame
            C: Embedding dimension
            global_idx: Current global block index
            pos: Position embeddings
            num_frame_for_scale: Number of frames for scale estimation
            sliding_window_size: Sliding window size in blocks
            num_frame_per_block: Number of frames per processing block

        Returns:
            (tokens, global_idx, intermediates)
        """
        # Extract image dimensions from kwargs for 3D RoPE
        image_height = kwargs.get('image_height', self.img_size)
        image_width = kwargs.get('image_width', self.img_size)
        return self._process_causal_stream(
            tokens, B, S_local, S_global, P, C, global_idx, pos,
            num_frame_per_block, sliding_window_size, num_frame_for_scale,
            image_height=image_height, image_width=image_width
        )

    def _process_causal_stream(
        self,
        tokens: torch.Tensor,
        B: int,
        S_local: int,
        S_global: int,
        P: int,
        C: int,
        global_idx: int,
        pos: Optional[torch.Tensor] = None,
        num_frame_per_block: int = 1,
        sliding_window_size: Optional[int] = None,
        num_frame_for_scale: Optional[int] = None,
        image_height: Optional[int] = None,
        image_width: Optional[int] = None,
    ):
        """
        Causal attention for streaming inference using FlashInfer KV cache.

        Args:
            tokens: Input tokens [B*S_local, P, C]
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            P: Number of patches per frame (includes special tokens)
            C: Channel dimension
            global_idx: Starting block index
            pos: Position embeddings [B*S_global, P, 2]
            num_frame_per_block: Number of frames per block
            sliding_window_size: Sliding window size in blocks
            num_frame_for_scale: Number of scale frames
            image_height: Image height for 3D RoPE calculation
            image_width: Image width for 3D RoPE calculation

        Returns:
            (tokens, global_idx, intermediates): Updated tokens, next block index, intermediate outputs
        """
        # Get effective parameters
        scale_frames = num_frame_for_scale if num_frame_for_scale is not None else self.num_frame_for_scale
        # Reshape tokens: [B*S_local, P, C] -> [B, S_local*P, C]
        if tokens.shape != (B, S_local * P, C):
            tokens = tokens.view(B, S_local, P, C).view(B, S_local * P, C)
        # Calculate number of frames for block mask
        num_frames = S_global
        num_patches = P - self.num_special_tokens
        # Check if this is the first block group (3D RoPE positions are computed
        # once per forward and reused by the later block groups)
        is_first_block_group = (global_idx < self.aa_block_size)
        if self.enable_3d_rope and self.rope3d is not None:
            if is_first_block_group:
                f_start = self.total_frames_processed
                f_end = self.total_frames_processed + S_global
                H = image_height if image_height is not None else self.img_size
                W = image_width if image_width is not None else self.img_size
                pos3d = self._get_3d_positions_streaming(
                    S_global, H, W, tokens.device, f_start, f_end
                )
                self._cached_pos3d = pos3d
            else:
                pos3d = self._cached_pos3d
            pos = pos3d
        else:
            # Reshape pos: [B*S_global, P, 2] -> [B, S_global*P, 2]
            if pos is not None and pos.shape != (B, S_global * P, 2):
                pos = pos.view(B, S_global, P, 2).view(B, S_global * P, 2)
        intermediates = []
        # Process blocks with KV cache
        for _ in range(self.aa_block_size):
            if self.use_sdpa:
                # SDPA: dict-based KV cache
                tokens = self.global_blocks[global_idx](
                    tokens,
                    pos=pos,
                    enable_ulysses_cp=False,
                    num_patches=num_patches,
                    num_special=self.num_special_tokens,
                    num_frames=num_frames,
                    enable_3d_rope=self.enable_3d_rope,
                    kv_cache=self.kv_cache,
                    global_idx=global_idx,
                    num_frame_per_block=num_frame_per_block,
                    num_frame_for_scale=scale_frames,
                    num_register_tokens=self.num_register_tokens,
                )
            else:
                # FlashInfer: paged KV cache manager
                manager = self._get_flashinfer_manager(tokens.device, tokens.dtype, tokens_per_frame=P)
                tokens = self.global_blocks[global_idx](
                    tokens,
                    pos=pos,
                    enable_ulysses_cp=False,
                    num_patches=num_patches,
                    num_special=self.num_special_tokens,
                    num_frames=num_frames,
                    enable_3d_rope=self.enable_3d_rope,
                    kv_cache=manager,
                    global_idx=global_idx,
                    num_frame_per_block=num_frame_per_block,
                    num_frame_for_scale=scale_frames,
                    num_register_tokens=self.num_register_tokens,
                )
            global_idx += 1
            intermediates.append(tokens.view(B, S_local, P, C))
        # Update total frames processed counter only on the first block group
        if is_first_block_group and not (isinstance(self.kv_cache, dict) and self.kv_cache.get("_skip_append", False)):
            self.total_frames_processed += S_global
        return tokens, global_idx, intermediates

View File

View File

@@ -0,0 +1,454 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from lingbot_map.layers import Mlp
from lingbot_map.layers.block import Block
from lingbot_map.layers.block import CameraBlock
from lingbot_map.heads.head_act import activate_pose
from lingbot_map.layers.rope import WanRotaryPosEmbed
from functools import partial
from torch.utils.checkpoint import checkpoint
class CameraHead(nn.Module):
    """
    CameraHead predicts camera parameters from token representations using iterative refinement.
    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.

    The camera token is assumed to be the first token of each frame in the aggregated
    token tensor (index 0 along the per-frame token axis).
    """
    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
        enable_ulysses_cp: bool = False,
    ):
        super().__init__()
        # absT_quaR_FoV: 3 translation + 4 quaternion + 2 FoV = 9 values per camera.
        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.enable_ulysses_cp = enable_ulysses_cp
        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
                for _ in range(trunk_depth)
            ]
        )
        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)
        # Learnable empty camera pose token, used to seed the first refinement iteration.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)
        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        # Small MLP mapping trunk features to the 9-D pose-encoding delta.
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list:
        """
        Forward pass to predict camera parameters.
        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]
        # Extract the camera tokens (first token of each frame).
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)
        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
        return pred_pose_enc_list
    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Iteratively refine camera pose predictions.
        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
            num_iterations (int): Number of refinement iterations.
        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape  # S is expected to be 1.
        pred_pose_enc = None
        pred_pose_enc_list = []
        for _ in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)
            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
            # Adaptive layer normalization and modulation (DiT-style), plus residual connection.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens
            # Apply trunk blocks with enable_ulysses_cp
            for block in self.trunk:
                pose_tokens_modulated = block(pose_tokens_modulated, enable_ulysses_cp=self.enable_ulysses_cp)
            # Compute the delta update for the pose encoding and accumulate it.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)
        return pred_pose_enc_list
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Apply an adaptive affine transformation: scale *x* by ``1 + scale`` and add ``shift``.

    Adapted from the DiT reference implementation
    (https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19).
    """
    scaled = x * (scale + 1)
    return scaled + shift
class CameraCausalHead(nn.Module):
    """
    Causal (streaming-capable) variant of CameraHead: predicts camera parameters from
    token representations using iterative refinement, with optional per-iteration KV
    caches so frames can be processed incrementally at inference time.
    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
    """
    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu", # Field of view activations: ensures FOV values are positive.
        num_iterations: int = 4,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        enable_ulysses_cp: bool = False,
        attn_class: str = "flexflashattn_varlen",
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # 3D RoPE parameters
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        rope_theta: float = 10000.0,
    ):
        super().__init__()
        # absT_quaR_FoV: 3 translation + 4 quaternion + 2 FoV = 9 values per camera.
        if pose_encoding_type == "absT_quaR_FoV":
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.sliding_window_size = sliding_window_size
        self.enable_ulysses_cp = enable_ulysses_cp
        self.num_heads = num_heads
        # 3D RoPE for temporal position encoding
        self.enable_3d_rope = enable_3d_rope
        if enable_3d_rope:
            head_dim = dim_in // num_heads
            # For camera head: each frame has 1 token (frame_seqlen=1)
            # patch_size is (max_frames, h=1, w=1) for 3D RoPE
            # NOTE(review): fhw_dim is passed explicitly below, so the fhw_dim=None
            # auto-calculation (h_dim=w_dim=2*(head_dim//6), t_dim=remainder) is NOT used.
            self.rope3d = WanRotaryPosEmbed(
                attention_head_dim=head_dim,
                patch_size=(max_frame_num, 1, 1),
                theta=rope_theta,
                fhw_dim=[40, 44, 44],  # Explicit (t, h, w) rotary dim split — presumably sums to head_dim; verify.
            )
        else:
            self.rope3d = None
        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                CameraBlock(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values, elementwise_attn_output_gate=elementwise_attn_output_gate, sliding_window_size=sliding_window_size, attend_to_scale_frames=attend_to_scale_frames, num_random_frames=num_random_frames, kv_cache_sliding_window=kv_cache_sliding_window, kv_cache_scale_frames=kv_cache_scale_frames, kv_cache_cross_frame_special=kv_cache_cross_frame_special, kv_cache_include_scale_frames=kv_cache_include_scale_frames, kv_cache_camera_only=kv_cache_camera_only)
                for _ in range(trunk_depth)
            ]
        )
        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)
        # Learnable empty camera pose token, used to seed the first refinement iteration.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)
        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)
        self.num_iterations = num_iterations
        # Streaming state: per-iteration KV caches plus a global frame counter.
        self.kv_cache = None
        self.pos_cache = None
        self.frame_idx = 0
        self.cp_size = 1
        ## Get cp size if enable ulysses cp
        if self.enable_ulysses_cp:
            # NOTE(review): only get_ulysses_sequence_parallel_world_size is used here;
            # the other two imported names appear unused.
            from torchtitan.distributed.sequence_parallel import (
                init_sequence_parallel,
                get_ulysses_sequence_parallel_rank,
                get_ulysses_sequence_parallel_world_size,
            )
            self.cp_size = get_ulysses_sequence_parallel_world_size()
    def clean_kv_cache(self):
        """Drop the streaming KV cache and reset the global frame counter."""
        del self.kv_cache
        self.kv_cache = None
        self.frame_idx = 0
    def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
        """
        Forward pass to predict camera parameters.
        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
            sliding_window_size (int, optional): Override the sliding window size for this forward pass.
                If None, use the default self.sliding_window_size.
        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use passed sliding_window_size if provided, otherwise use default
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]
        # Extract the camera tokens (first token of each frame).
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)
        if causal_inference:
            if self.kv_cache is None:
                # Lazily build one cache dict per refinement iteration, each holding
                # per-layer K/V slots keyed "k_{layer}" / "v_{layer}".
                # NOTE(review): sized by self.num_iterations — if the num_iterations
                # argument ever exceeds it, trunk_fn's kv_cache[i] lookup would fail;
                # confirm callers keep the two in sync.
                self.kv_cache = []
                for i in range(self.num_iterations):
                    self.kv_cache.append({"_skip_append": False})
                    for j in range(self.trunk_depth):
                        self.kv_cache[i][f"k_{j}"] = None
                        self.kv_cache[i][f"v_{j}"] = None
        pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size)
        return pred_pose_enc_list
    def trunk_fn(self, pose_tokens: torch.Tensor, mask=None, num_iterations: int=4, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None) -> list:
        """
        Iteratively refine camera pose predictions.
        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            num_iterations (int): Number of refinement iterations.
            sliding_window_size (int, optional): Sliding window size to use.
        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []
        # Check if this is the first call (processing scale frames)
        # Scale frames should use batch mode attention for numerical consistency
        is_scale_frames = (self.kv_cache is not None and self.frame_idx == 0)
        # Generate 3D RoPE positions if enabled
        pos3d = None
        if self.rope3d is not None:
            # For camera tokens: shape [B, S, C] where each frame has 1 token
            # Position for frame f is (f, 0, 0) - temporal varies, spatial fixed
            # In streaming mode with KV cache, use frame_idx to track global position
            # Otherwise, generate positions from 0
            if self.kv_cache is not None:
                f_start = self.frame_idx
                f_end = self.frame_idx + S
            else:
                f_start = 0
                f_end = None # Will use ppf as frame count
            pos3d = self.rope3d(
                ppf=S * self.cp_size, # Total frames (with CP)
                pph=1, # height = 1 (camera token)
                ppw=1, # width = 1 (camera token)
                patch_start_idx=0, # No special tokens before
                device=pose_tokens.device,
                f_start=f_start,
                f_end=f_end,
            ) # Returns [1, 1, S*cp_size, head_dim//2] complex
        for i in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)
            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
            # Adaptive layer normalization and modulation (DiT-style), plus residual connection.
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens
            # Run the trunk; each iteration i uses its own KV cache dict.
            for idx in range(self.trunk_depth):
                pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, pos=pos3d, video_mask=mask, num_frames=S*self.cp_size, frame_seqlen=1, kv_cache=self.kv_cache[i] if self.kv_cache is not None else None, global_idx=idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=sliding_window_size, enable_ulysses_cp=self.enable_ulysses_cp, enable_3d_rope=self.enable_3d_rope, is_scale_frames=is_scale_frames)
            # Compute the delta update for the pose encoding and accumulate it.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)
        # Update frame_idx for streaming mode (KV cache)
        if self.kv_cache is not None:
            self.frame_idx += S
        return pred_pose_enc_list
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Modulate *x* with an adaptive affine transform: ``x * (1 + scale) + shift``.

    Adapted from the DiT reference implementation
    (https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19).
    """
    # addcmul(shift, x, scale + 1) computes shift + x * (scale + 1) in one fused call.
    return torch.addcmul(shift, x, scale + 1)
class CameraDecoder(nn.Module):
    """
    Lightweight transformer decoder that refines per-view token features and projects
    them to an output dimension.

    Args:
        in_dim (int): Channel dimension of the incoming tokens.
        out_dim (int): Channel dimension of the output.
        dec_embed_dim (int): Internal width of the decoder blocks. Default 512.
        depth (int): Number of transformer blocks. Default 5.
        dec_num_heads (int): Attention heads per block. Default 8.
        mlp_ratio (int): MLP expansion ratio inside each block. Default 4.
        rope: Optional rotary position embedding module passed to each block.
        need_project (bool): If True, linearly project inputs to dec_embed_dim;
            otherwise pass them through unchanged (in_dim must equal dec_embed_dim).
        use_checkpoint (bool): If True, use activation checkpointing during training
            to trade compute for memory.
    """
    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()
        # Optionally map input tokens into the decoder width.
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint
        self.blocks = nn.ModuleList([
            Block(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                rope=rope
            ) for _ in range(depth)])
        self.linear_out = nn.Linear(dec_embed_dim, out_dim)
    def forward(self, hidden, xpos=None):
        """
        Decode per-view token features.

        Args:
            hidden (Tensor): Tokens with shape [B, V, P, C_in] (batch, views, tokens, channels).
            xpos: Optional positional information forwarded to each block.

        Returns:
            Tensor: Decoded features with shape [B, V, P, out_dim].
        """
        hidden = self.projects(hidden)
        B, V, P, C = hidden.shape
        # Fold batch and view dims so each block sees [B*V, P, C].
        hidden = hidden.view(B * V, P, C)
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Recompute activations in the backward pass to save memory.
                hidden = checkpoint(blk, hidden, pos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, pos=xpos)
        out = self.linear_out(hidden).view(B, V, P, -1)
        return out

View File

@@ -0,0 +1,679 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
import os
from typing import List, Dict, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from .head_act import activate_head
from .utils import create_uv_grid, position_grid_to_embed
class DPTHead(nn.Module):
    """
    DPT Head for dense prediction tasks.
    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
    backbone and produces dense predictions by fusing multi-scale features.
    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """
    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        # NOTE(review): mutable list defaults are shared across instances; they are
        # only read here, never mutated, but None-sentinels would be safer.
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [0, 1, 2, 3],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx
        self.norm = nn.LayerNorm(dim_in)
        # Projection layers for each output channel from tokens.
        self.projects = nn.ModuleList(
            [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
        )
        # Resize layers for upsampling feature maps (4x, 2x, identity, 0.5x respectively).
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )
        self.scratch = _make_scratch(out_channels, features, expand=False)
        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # Deepest fusion block has no residual input (nothing coarser to fuse with).
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
        head_features_1 = features
        head_features_2 = 32
        if feature_only:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            conv2_in_channels = head_features_1 // 2
            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )
    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or larger than S, all frames are processed at once. Default: 8.
        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: Feature maps with shape [B, S, C, H, W]
                - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
        """
        B, _, _, H, W = images.shape
        S = aggregated_tokens_list[0].shape[1]
        # If frames_chunk_size is not specified or greater than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
        # Otherwise, process frames in chunks to manage memory usage
        assert frames_chunk_size > 0
        # Process frames in batches
        all_preds = []
        all_conf = []
        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
            # Process batch of frames
            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)
        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: int = None,
        frames_end_idx: int = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.
        This method processes a specific chunk of frames from the sequence.
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.
        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
        """
        B, _, _, H, W = images.shape
        patch_h, patch_w = H // self.patch_size, W // self.patch_size
        out = []
        dpt_idx = 0
        for layer_idx in self.intermediate_layer_idx:
            # Drop special tokens; keep only patch tokens.
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]
            # B and S are (re)bound here each iteration; the values used after the
            # loop come from the last layer processed (identical across layers).
            B, S = x.shape[0], x.shape[1]
            x = x.reshape(B * S, -1, x.shape[-1])
            x = self.norm(x)
            # Tokens -> feature map: [B*S, P, C] -> [B*S, C, patch_h, patch_w].
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)
            out.append(x)
            dpt_idx += 1
        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        )
        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)
        if self.feature_only:
            return out.view(B, S, *out.shape[1:])
        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
        preds = preds.view(B, S, *preds.shape[1:])
        conf = conf.view(B, S, *conf.shape[1:])
        return preds, conf
    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.
        The embedding is scaled down by `ratio` so it perturbs rather than dominates the features.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed
    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.
        Args:
            features (List[Tensor]): List of feature maps from different layers.
        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)
        # Fuse coarse-to-fine, freeing each level as soon as it is consumed to cap peak memory.
        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4
        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3
        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2
        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1
        out = self.scratch.output_conv1(out)
        return out
################################################################################
# Modules
################################################################################
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """Build a FeatureFusionBlock with the fixed defaults used by DPTHead."""
    fusion_config = dict(
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )
    return FeatureFusionBlock(features, nn.ReLU(inplace=True), **fusion_config)
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module."""
def __init__(self, features, activation, bn, groups=1):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups = groups
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.norm1 = None
self.norm2 = None
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.norm1 is not None:
out = self.norm1(out)
out = self.activation(out)
out = self.conv2(out)
if self.norm2 is not None:
out = self.norm2(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
    """Feature fusion block.

    Optionally fuses a residual (skip) input into the main path, refines the result,
    upsamples it, and projects it with a 1x1 conv.
    """
    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Init.
        Args:
            features (int): number of features
            activation: activation module forwarded to the residual conv units.
            deconv (bool): stored flag only; not used by this forward pass.
            bn (bool): forwarded to the residual conv units.
            expand (bool): if True, halve the channel count in the output conv.
            align_corners (bool): interpolation flag used when upsampling.
            size: fixed output size used when forward() receives no explicit size.
            has_residual (bool): if False, resConfUnit1 is not created and forward()
                must be called with a single input tensor.
            groups (int): conv groups.
        """
        super(FeatureFusionBlock, self).__init__()
        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2
        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )
        # resConfUnit1 only exists when a residual input will be fused.
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
        # Quantization-friendly elementwise add.
        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size
    def forward(self, *xs, size=None):
        """Forward pass.

        Args:
            *xs: main input at xs[0]; when has_residual is True, the residual
                (skip) input at xs[1].
            size: optional explicit output size; overrides self.size. When both
                are None, the output is upsampled by a factor of 2.
        Returns:
            tensor: output
        """
        output = xs[0]
        if self.has_residual:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
        output = self.resConfUnit2(output)
        # Pick the interpolation target: explicit size > configured size > 2x upscale.
        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}
        output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
        output = self.out_conv(output)
        return output
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Interpolate *x* to the target size, splitting very large inputs along the batch
    dimension to avoid INT_MAX element-count issues in nn.functional.interpolate.

    Exactly one of `size` or `scale_factor` should be provided.
    """
    if size is None:
        # Derive the target spatial size from scale_factor.
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
    INT_MAX = 1610612736
    total_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
    if total_elements <= INT_MAX:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
    # Too big for a single call: interpolate batch chunks independently and reassemble.
    num_chunks = (total_elements // INT_MAX) + 1
    resized_chunks = [
        nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners)
        for chunk in torch.chunk(x, chunks=num_chunks, dim=0)
    ]
    return torch.cat(resized_chunks, dim=0).contiguous()
class DPTHead_Update(nn.Module):
    """
    DPT-style refinement head (Depth-Anything variant).

    Projects four levels of transformer features to conv feature maps, fuses them
    coarse-to-fine through refinenet blocks, and either returns the fused feature
    map together with all intermediate fusion paths, or applies the final output
    convolutions to produce a single-channel prediction.
    """
    def __init__(
        self,
        in_channels,
        features=256,
        use_bn=False,
        # NOTE(review): mutable list default is shared across instances; it is only
        # read here, but a None-sentinel would be safer.
        out_channels=[256, 512, 1024, 1024],
        use_clstoken=False
    ):
        super(DPTHead_Update, self).__init__()
        self.use_clstoken = use_clstoken
        # 1x1 convs mapping the transformer width to each level's channel count.
        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in out_channels
        ])
        # Per-level resizing: 4x up, 2x up, identity, 0.5x down.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])
        if use_clstoken:
            # Readout MLPs that fold the class token back into each patch token.
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))
        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block_slam(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block_slam(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block_slam(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block_slam(features, use_bn)
        head_features_1 = features
        head_features_2 = 32
        self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True),
            nn.Identity(),
        )
    def forward(self, out_features, patch_h, patch_w, return_intermediate=True):
        """
        Args:
            out_features: sequence of four per-level token tensors; when
                use_clstoken is True each entry is a (tokens, cls_token) pair.
            patch_h, patch_w (int): token-grid height and width.
            return_intermediate (bool): if True, return the fused feature map plus
                the four intermediate fusion paths; otherwise apply output_conv2
                and return the final prediction only.
        """
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                # Concatenate the class token onto every patch token, then project back.
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
            # Tokens -> feature map: [B, P, C] -> [B, C, patch_h, patch_w].
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        layer_1, layer_2, layer_3, layer_4 = out
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)
        # Coarse-to-fine fusion.
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
        out = self.scratch.output_conv1(path_1)
        # NOTE(review): the factor 14 hardcodes the ViT patch size — confirm it matches the backbone.
        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        if return_intermediate:
            return out, path_1, path_2, path_3, path_4
        else:
            out = self.scratch.output_conv2(out)
            return out
def _make_fusion_block_slam(features, use_bn, size=None):
    """Create a FeatureFusionBlock_slam configured with the decoder's standard
    settings (no deconv, no channel expansion, aligned-corner upsampling)."""
    fusion_kwargs = dict(
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )
    return FeatureFusionBlock_slam(features, nn.ReLU(False), **fusion_kwargs)
class FeatureFusionBlock_slam(nn.Module):
    """Feature fusion block.

    Adds an optional skip connection through a residual unit, refines the sum,
    upsamples (by an explicit size or by a factor of 2), and projects with a
    1x1 convolution.
    """

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None
    ):
        """Init.

        Args:
            features (int): number of features
            activation: activation module passed to the residual units
            deconv (bool): stored flag; not used by this forward pass
            bn (bool): enable batch norm inside the residual units
            expand (bool): if True, halve the channel count in out_conv
            align_corners (bool): forwarded to the bilinear interpolation
            size: fixed output size used when forward() gets no explicit size
        """
        super(FeatureFusionBlock_slam, self).__init__()
        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = 1
        self.expand = expand
        # Preserve the original `== True` comparison exactly (truthy non-True
        # values such as strings would NOT trigger the halving).
        out_features = features // 2 if self.expand == True else features
        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Args:
            *xs: one or two feature maps; a second entry is treated as a skip
                connection and fused in through resConfUnit1.
            size: explicit output size; falls back to self.size, then to a
                plain 2x upsample.

        Returns:
            tensor: output
        """
        output = xs[0]
        if len(xs) == 2:
            output = self.skip_add.add(output, self.resConfUnit1(xs[1]))
        output = self.resConfUnit2(output)

        # Resolve the interpolation target: explicit size > stored size > 2x.
        if size is not None:
            modifier = {"size": size}
        elif self.size is not None:
            modifier = {"size": self.size}
        else:
            modifier = {"scale_factor": 2}
        output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
        return self.out_conv(output)

View File

@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """
    Activate pose parameters with specified activation functions.

    Args:
        pred_pose_enc: Tensor whose last dim packs [3 translation, 4 quaternion,
            remaining focal-length (or fov) channels]
        trans_act: Activation type for the translation component
        quat_act: Activation type for the quaternion component
        fl_act: Activation type for the focal-length component

    Returns:
        Tensor with the same layout and each component activated independently.
    """
    components = (
        (pred_pose_enc[..., :3], trans_act),    # translation
        (pred_pose_enc[..., 3:7], quat_act),    # quaternion
        (pred_pose_enc[..., 7:], fl_act),       # focal length (or fov)
    )
    activated = [base_pose_act(part, act) for part, act in components]
    return torch.cat(activated, dim=-1)
def base_pose_act(pose_enc, act_type="linear"):
    """
    Apply a named activation to encoded pose parameters.

    Args:
        pose_enc: Tensor containing encoded pose parameters
        act_type: one of "linear", "inv_log", "exp", "relu"

    Returns:
        Activated pose parameters

    Raises:
        ValueError: if act_type is not a recognized name.
    """
    if act_type == "linear":
        return pose_enc
    if act_type == "exp":
        return torch.exp(pose_enc)
    if act_type == "relu":
        return F.relu(pose_enc)
    if act_type == "inv_log":
        return inverse_log_transform(pose_enc)
    raise ValueError(f"Unknown act_type: {act_type}")
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """
    Process network output to extract 3D points and confidence values.

    Args:
        out: Network output tensor (B, C, H, W); the first C-1 channels encode
            the points, the last channel the confidence logit.
        activation: Activation type for the 3D points
        conf_activation: Activation type for the confidence values

    Returns:
        Tuple of (3D points tensor (B, H, W, C-1), confidence tensor (B, H, W))

    Raises:
        ValueError: for an unrecognized activation or conf_activation name.
    """
    # Channels-first -> channels-last: (B, C, H, W) -> (B, H, W, C).
    fmap = out.permute(0, 2, 3, 1)
    xyz, conf = fmap[..., :-1], fmap[..., -1]

    if activation == "norm_exp":
        # Direction * expm1(length): preserves direction, expands magnitude.
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        pts3d = (xyz / d) * torch.expm1(d)
    elif activation == "norm":
        # NOTE(review): unlike "norm_exp" there is no clamp here, so an
        # all-zero vector divides by zero -- confirm whether that is intended.
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        # Expand depth, then scale xy by the recovered depth.
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    conf_fns = {
        "expp1": lambda c: 1 + c.exp(),
        "expp0": lambda c: c.exp(),
        "sigmoid": torch.sigmoid,
    }
    if conf_activation not in conf_fns:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")
    return pts3d, conf_fns[conf_activation](conf)
def inverse_log_transform(y):
    """
    Apply the inverse log transform: sign(y) * (exp(|y|) - 1).

    Args:
        y: Input tensor

    Returns:
        Element-wise transformed tensor; expm1 keeps small magnitudes accurate.
    """
    magnitude = torch.expm1(y.abs())
    return torch.sign(y) * magnitude

109
lingbot_map/heads/utils.py Normal file
View File

@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert a 2D position grid (H, W, 2) to sinusoidal embeddings (H, W, embed_dim).

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension, split evenly between x and y
        omega_0: Base frequency forwarded to make_sincos_pos_embed

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    flat = pos_grid.reshape(-1, grid_dim)  # (H*W, 2)
    half = embed_dim // 2
    # Embed the x and y coordinates independently, then concatenate.
    emb_x = make_sincos_pos_embed(half, flat[:, 0], omega_0=omega_0)
    emb_y = make_sincos_pos_embed(half, flat[:, 1], omega_0=omega_0)
    return torch.cat([emb_x, emb_y], dim=-1).view(H, W, embed_dim)
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    Generate a 1D sine/cosine positional embedding for the given positions.

    Args:
        embed_dim: The embedding dimension (must be even; half sin, half cos).
        pos: The positions to embed; flattened to shape (M,).
        omega_0: Base frequency of the sinusoids.

    Returns:
        Float tensor of shape (M, embed_dim).
    """
    assert embed_dim % 2 == 0
    device = pos.device
    # MPS has no float64 support, so compute in float32 there.
    work_dtype = torch.float32 if device.type == "mps" else torch.double
    freqs = torch.arange(embed_dim // 2, dtype=work_dtype, device=device)
    freqs = freqs / (embed_dim / 2.0)
    freqs = 1.0 / omega_0**freqs  # (D/2,)
    # Outer product positions x frequencies -> (M, D/2) phase angles.
    angles = pos.reshape(-1)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (M, D)
    return emb.float()
# Inspired by https://github.com/microsoft/moge
def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (height, width, 2).

    The grid spans horizontally and vertically according to an aspect ratio,
    ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
    corner is at (x_span, y_span), normalized by the diagonal of the plane.

    Note: because ``torch.meshgrid(..., indexing="xy")`` produces grids of
    shape (height, width), the result is indexed as ``uv_grid[y, x]``. The
    previous docstring's "(width, height, 2)" was wrong for width != height.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (height, width, 2) tensor of UV coordinates.
    """
    # Derive aspect ratio if not explicitly provided
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Normalize spans by the plane diagonal so the corner radius is <= 1.
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Half-pixel inset: endpoints sit at pixel centers rather than pixel edges.
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    # Generate 1D coordinates
    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # indexing="xy": outputs have shape (height, width); stack into UV pairs.
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)
    return uv_grid

View File

@@ -0,0 +1,5 @@
from lingbot_map.layers.mlp import Mlp
from lingbot_map.layers.patch_embed import PatchEmbed
from lingbot_map.layers.block import Block
from lingbot_map.layers.swiglu_ffn import SwiGLUFFN as SwiGLUFFNFused
from lingbot_map.layers.attention import Attention as MemEffAttention

View File

@@ -0,0 +1,766 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
import os
import math
import warnings
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from lingbot_map.layers.rope import apply_rotary_emb
from einops import rearrange
# FlashInfer imports (optional - for paged attention)
try:
import flashinfer
FLASHINFER_AVAILABLE = True
except ImportError:
FLASHINFER_AVAILABLE = False
print("flashinfer not available")
try:
from torchtitan.distributed.sequence_parallel import (
gather_seq_scatter_heads,
gather_heads_scatter_seq,
pad_tensor,
slice_input_tensor_scale_grad,
gather_outputs,
)
except ImportError:
print("torchtitan not available for ulysses cp")
def gather_seq_scatter_heads_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_dim: int, head_dim: int):
    """Gather the sequence dim and scatter the head dim for each of Q, K, V.

    Thin wrapper over torchtitan's gather_seq_scatter_heads, applied to all
    three projections with identical (seq_dim, head_dim) arguments.
    """
    return tuple(gather_seq_scatter_heads(t, seq_dim, head_dim) for t in (q, k, v))
from typing_extensions import List
from typing import Optional, Tuple
class Attention(nn.Module):
    """Multi-head self-attention with optional RoPE, QK-norm, and Ulysses
    context parallelism.

    Uses F.scaled_dot_product_attention when fused_attn is True, otherwise a
    manual softmax(QK^T)V computation with dropout on the attention weights.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn
        self.rope = rope
        # Fused QKV projection plus optional per-head Q/K normalization.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        if enable_ulysses_cp:
            # Sequence-parallel: exchange sequence shards for head shards.
            q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
        if self.rope is not None:
            q, k = self.rope(q, pos), self.rope(k, pos)
        if self.fused_attn:
            drop_p = self.attn_drop.p if self.training else 0.0
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=drop_p)
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            scores = self.attn_drop(scores.softmax(dim=-1))
            x = scores @ v
        if enable_ulysses_cp:
            x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
        x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
        return self.proj_drop(self.proj(x))
class CausalAttention(nn.Module):
    """
    Causal self-attention module with KV cache support for streaming inference.
    Used by CasualBlockCamera in camera_head.py.

    Two execution modes in forward():
      * kv_cache is None  -> batch mode: full-sequence attention under an
        explicit boolean mask built from block_mask / video_mask and an
        optional sliding window.
      * kv_cache is a dict -> streaming mode: current K/V are appended to the
        per-block cache, a sliding-window eviction is applied, and attention
        runs against the (already RoPE-applied) cached keys/values.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
        elementwise_attn_output_gate=False,  # adds a sigmoid gate on the attention output
        # KV cache eviction parameters (matching build_attn_mask)
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,  # If True, only cache camera token (no scale token)
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope
        # Optional element-wise output gate (sigmoid), applied per head-dim element.
        self.gate_proj = nn.Linear(dim, dim, bias=True) if elementwise_attn_output_gate else None
        # Store KV cache eviction parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
    def forward(self, x: Tensor, block_mask=None, pos=None, pos_kv=None, frame_seqlen=None, video_mask=None, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=1, num_frame_for_scale=-1, enable_3d_rope=False, sliding_window_size=-1, attend_to_scale_frames=False, num_random_frames=0, attend_to_special_tokens=False, num_register_tokens=4, enable_ulysses_cp=False, is_scale_frames=False) -> Tensor:
        """Run causal attention in batch mode (kv_cache=None) or streaming mode.

        NOTE(review): when self.fused_attn is False, neither the batch nor the
        streaming branch assigns `x` before the output projection -- the
        non-fused path appears unsupported here; confirm.
        """
        B, N, C = x.shape
        # Calculate special token indices
        camera_token_idx = 0
        scale_token_idx = camera_token_idx + num_register_tokens + 1  # camera + register tokens + scale
        # [3, B, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        if self.gate_proj is not None:
            # Gate scores per head: (B, num_heads, N, head_dim).
            gate_score = self.gate_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        if kv_cache is None:
            # ---------- Batch mode: full-sequence masked attention ----------
            q, k = self.q_norm(q), self.k_norm(k)
            if enable_ulysses_cp:
                q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
                N = q.shape[2]  # Update N after gather
            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif enable_3d_rope and pos is not None:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)
            with torch.no_grad():
                # Crop the precomputed mask to the actual q/k lengths and
                # broadcast it to [B, 1, Nq, Nk].
                block_mask = block_mask.squeeze()[:q.shape[2], :k.shape[2]]
                if block_mask.dim() == 2:
                    block_mask = block_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, N, N]
                block_mask = block_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])
                video_mask = video_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if video_mask is not None else torch.ones_like(block_mask, device=block_mask.device)  # [1, 1, N, N]
                video_mask = video_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])
                # NOTE(review): combined as block_mask OR NOT video_mask; the
                # intended polarity of these masks is not visible here -- the
                # result is fed to SDPA where True means "may attend". Confirm.
                mask = block_mask | ~video_mask
                # Apply sliding window mask if sliding_window_size > 0
                # sliding_window_size is in units of num_frame_per_block
                if sliding_window_size > 0 and frame_seqlen is not None:
                    # Create sliding window mask: each frame can only attend to frames within the window
                    num_frames = N // frame_seqlen
                    sliding_mask = torch.zeros_like(mask, dtype=torch.bool)
                    for i in range(num_frames):
                        q_start = i * frame_seqlen
                        q_end = (i + 1) * frame_seqlen
                        # Calculate the window start: sliding_window_size is in units of num_frame_per_block
                        # So the actual window size in frames is sliding_window_size * num_frame_per_block
                        window_size_in_frames = sliding_window_size * num_frame_per_block
                        window_start_frame = max(0, i - window_size_in_frames + 1)
                        k_start = window_start_frame * frame_seqlen
                        k_end = (i + 1) * frame_seqlen  # Can attend up to current frame (causal)
                        sliding_mask[:, :, q_start:q_end, k_start:k_end] = True
                    # Combine with existing mask: both masks need to allow attention
                    mask = mask & sliding_mask
                    # If attend_to_scale_frames is True, also allow attention to first num_frame_for_scale frames
                    if num_frame_for_scale > 0:
                        for i in range(num_frames):
                            q_start = i * frame_seqlen
                            q_end = (i + 1) * frame_seqlen
                            # Allow attending to first num_frame_for_scale frames (directly set to True, not depending on block_mask)
                            mask[:, :, q_start:q_end, :num_frame_for_scale * frame_seqlen] = True
                    ## global attention for the first num_frame_for_scale frames
                    if num_frame_for_scale > 0:
                        mask[:, :, :num_frame_for_scale * frame_seqlen, :num_frame_for_scale * frame_seqlen] = True
            if self.fused_attn:
                x = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                    attn_mask=mask
                )
        else:
            # ---------- Streaming mode: dict-based KV cache ----------
            # Apply RoPE to current k before caching
            q, k = self.q_norm(q), self.k_norm(k)
            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif enable_3d_rope and pos is not None:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)
            # Check if we should skip appending to cache (non-keyframe in keyframe mode)
            skip_append = kv_cache.get("_skip_append", False)
            # Cache layout: [B, num_heads, frames, tokens_per_frame, head_dim].
            k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
            v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
            if not skip_append:
                # KEYFRAME: store in cache (original behavior)
                if kv_cache[f"k_{global_idx}"] is None:
                    kv_cache[f"k_{global_idx}"] = k_reshaped
                    kv_cache[f"v_{global_idx}"] = v_reshaped
                else:
                    # Re-derive frames-per-block from the cache's tokens-per-frame
                    # so concatenation stays consistent with earlier appends.
                    num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
                    k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
                    v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
                    kv_cache[f"k_{global_idx}"] = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
                    kv_cache[f"v_{global_idx}"] = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
                # Apply sliding window eviction BEFORE attention to match causal_3drope behavior
                # This ensures current frame only attends to frames within the sliding window
                self._apply_kv_cache_eviction_causal(kv_cache, global_idx, camera_token_idx, scale_token_idx)
                # Retrieve full k, v from cache (already RoPE-applied, already evicted)
                k = kv_cache[f"k_{global_idx}"].clone()
                v = kv_cache[f"v_{global_idx}"].clone()
            else:
                # NON-KEYFRAME: attend to [cached + current] without storing in cache
                if kv_cache[f"k_{global_idx}"] is not None:
                    k = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
                    v = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
                else:
                    k = k_reshaped
                    v = v_reshaped
            # Collapse (frames, tokens_per_frame) back into a single sequence dim.
            a, b, c, d, e = k.shape
            k = k.reshape(a, b, c*d, e)
            v = v.reshape(a, b, c*d, e)
            # Prepend special tokens (camera + scale) from evicted frames if they exist
            if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
                special_k = kv_cache[f"k_{global_idx}_special"]  # [B, H, num_evicted_frames, 2, D]
                special_v = kv_cache[f"v_{global_idx}_special"]
                sa, sb, sc, sd, se = special_k.shape
                special_k = special_k.reshape(sa, sb, sc * sd, se)  # [B, H, num_evicted*2, D]
                special_v = special_v.reshape(sa, sb, sc * sd, se)
                # Prepend special tokens (older frames first)
                k = torch.cat([special_k, k], dim=2)
                v = torch.cat([special_v, v], dim=2)
            # Note: k from cache is already RoPE-applied, no need to apply again
            if self.fused_attn:
                # Use mask-based SDPA to ensure same kernel as batch mode
                # The causal constraint is enforced by KV cache contents, not by mask
                mask = torch.ones(B, 1, q.shape[2], k.shape[2], dtype=torch.bool, device=q.device)
                x = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                    attn_mask=mask,
                )
        if self.gate_proj is not None:
            x = x * torch.sigmoid(gate_score)
        if enable_ulysses_cp:
            x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
        # Use actual dimensions from attention output, not original input C
        # x shape: [B, H, seq_len, head_dim] -> [B, seq_len, H*head_dim]
        x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    def _apply_kv_cache_eviction_causal(self, kv_cache, global_idx, camera_token_idx, scale_token_idx):
        """
        Apply sliding window eviction to KV cache BEFORE attention.
        This ensures current frame only attends to frames within the sliding window,
        matching the behavior of causal_3drope's attention mask.

        Frames between the leading `scale_frames` and the trailing
        `sliding_window` are evicted; optionally their special tokens
        (camera, or camera + register + scale) are preserved separately under
        the "k_{idx}_special" / "v_{idx}_special" keys.
        """
        sliding_window_frames = self.kv_cache_sliding_window
        scale_frames = self.kv_cache_scale_frames
        # shape[3] is tokens-per-frame; presumably > 1 distinguishes full-token
        # caches from single-token ones -- TODO confirm against callers.
        if kv_cache[f"k_{global_idx}"].shape[3] > 1:
            num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
            if num_cached_frames > sliding_window_frames + scale_frames:
                # Evict the middle frames: keep the first scale_frames and the
                # most recent sliding_window_frames.
                evict_start = scale_frames
                evict_end = num_cached_frames - sliding_window_frames
                if evict_end > evict_start:
                    evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
                    evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]
                    if self.kv_cache_cross_frame_special:
                        if self.kv_cache_camera_only:
                            # Only keep camera token
                            new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                        else:
                            # Keep ALL special tokens (camera + register + scale) to match attention_mask behavior
                            # Special tokens are in range [camera_token_idx, scale_token_idx+1)
                            new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
                        if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
                            kv_cache[f"k_{global_idx}_special"] = new_special_k
                            kv_cache[f"v_{global_idx}_special"] = new_special_v
                        else:
                            kv_cache[f"k_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
                            kv_cache[f"v_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)
                if self.kv_cache_include_scale_frames:
                    # Keep leading scale frames plus the trailing window.
                    kv_cache[f"k_{global_idx}"] = torch.cat([
                        kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
                        kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                    ], dim=2)
                    kv_cache[f"v_{global_idx}"] = torch.cat([
                        kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
                        kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                    ], dim=2)
                else:
                    # Keep only the trailing window.
                    kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                    kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
class FlashInferAttention(Attention):
    """
    FlashInfer variant of the GCT attention layer.
    Uses FlashInferKVCacheManager for paged KV cache storage and
    FlashInfer attention kernels (BatchPrefillWithPagedKVCacheWrapper).
    Supports the same optimized token layout and KV cache streaming inference.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,
        rope=None,
        # KV cache eviction parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        # Fail fast if the optional dependency is missing.
        if not FLASHINFER_AVAILABLE:
            raise RuntimeError("FlashInfer is not available. Please install flashinfer.")
        super().__init__(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )
        # Store KV cache eviction parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
    def prepare_qkv(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
        """Fused pre-attention ops for single-frame streaming (Phase 2).
        Computes q/k/v from x, applies q_norm/k_norm/RoPE, and converts to
        [tpf, H, D] format ready for append_frame + compute_attention.
        Extracted as a method so torch.compile can capture all pre-attn ops as one
        CUDA graph (qkv linear -> reshape -> unbind -> q_norm -> k_norm -> RoPE x2 ->
        squeeze/permute/contiguous x3).
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # Each: [B, num_heads, N, head_dim]
        q, k = self.q_norm(q), self.k_norm(k)
        if self.rope is not None and not enable_3d_rope:
            q = self.rope(q, pos)
            k = self.rope(k, pos)
        elif self.rope is not None:  # enable_3d_rope=True
            q = apply_rotary_emb(q, pos)
            k = apply_rotary_emb(k, pos)
        # Convert to [tpf, H, D] format for FlashInfer (B=1 in streaming mode)
        q_nhd = q.squeeze(0).permute(1, 0, 2).contiguous()
        k_nhd = k.squeeze(0).permute(1, 0, 2).contiguous()
        v_nhd = v.squeeze(0).permute(1, 0, 2).contiguous()
        return q_nhd, k_nhd, v_nhd
    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
                # KV cache parameters (kv_cache is a FlashInferKVCacheManager or None)
                kv_cache=None, global_idx=0, num_frame_per_block=1,
                num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
        """
        Forward pass with FlashInfer paged KV cache and attention.
        Args:
            x: Input tensor [B, N, C]
            kv_cache: FlashInferKVCacheManager instance or None (batch mode)
            global_idx: Block index for per-block cache access
        """
        # Imported lazily to keep the hard flashinfer dependency out of module import.
        from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager
        B, N, C = x.shape
        # Detect if using optimized layout
        # (patch tokens for all frames first, then all special tokens).
        using_optimized_layout = (num_patches is not None and num_special is not None
                                  and num_frames is not None)
        # ========== Batch Mode (no KV cache manager) ==========
        if not isinstance(kv_cache, FlashInferKVCacheManager):
            # [3, B, num_heads, N, head_dim]
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # Each: [B, num_heads, N, head_dim]
            q, k = self.q_norm(q), self.k_norm(k)
            if enable_ulysses_cp:
                if using_optimized_layout:
                    # Gather/scatter patch tokens and special tokens separately
                    # so the two contiguous segments keep their boundaries.
                    boundary = num_frames * num_patches
                    q_patch, k_patch, v_patch = q[:, :, :boundary, :], k[:, :, :boundary, :], v[:, :, :boundary, :]
                    q_special, k_special, v_special = q[:, :, boundary:, :], k[:, :, boundary:, :], v[:, :, boundary:, :]
                    q_patch, k_patch, v_patch = gather_seq_scatter_heads_qkv(
                        q_patch, k_patch, v_patch, seq_dim=2, head_dim=1
                    )
                    q_special, k_special, v_special = gather_seq_scatter_heads_qkv(
                        q_special, k_special, v_special, seq_dim=2, head_dim=1
                    )
                    q = torch.cat([q_patch, q_special], dim=2)
                    k = torch.cat([k_patch, k_special], dim=2)
                    v = torch.cat([v_patch, v_special], dim=2)
                else:
                    q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif self.rope is not None and enable_3d_rope:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)
            # Batch mode: use SDPA for numerical consistency with SDPA variant
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
            if enable_ulysses_cp:
                if using_optimized_layout:
                    # Undo the segmented gather: recover the global patch/special
                    # boundary from the local sequence length ratio.
                    seq_global = x.shape[2]
                    seq_local = num_frames * (num_patches + num_special)
                    cp_size = seq_global // seq_local
                    boundary_global = num_frames * cp_size * num_patches
                    x_patch = x[:, :, :boundary_global, :]
                    x_special = x[:, :, boundary_global:, :]
                    x_patch = gather_heads_scatter_seq(x_patch, seq_dim=2, head_dim=1)
                    x_special = gather_heads_scatter_seq(x_special, seq_dim=2, head_dim=1)
                    x = torch.cat([x_patch, x_special], dim=2)
                else:
                    x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
            x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
        # ========== Streaming Mode (with FlashInferKVCacheManager) ==========
        else:
            manager = kv_cache  # FlashInferKVCacheManager
            # Phase 1 (scale frames): num_frames > 1 — multi-frame batch
            # Phase 2 (streaming): num_frames == 1 — single frame
            is_multi_frame = (num_frames is not None and num_frames > 1)
            if is_multi_frame:
                # Phase 1: compute full self-attention via SDPA (all frames attend to each other),
                # then append each frame's K/V to the paged cache one at a time.
                qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
                q, k, v = qkv.unbind(0)
                q, k = self.q_norm(q), self.k_norm(k)
                # Apply RoPE before caching (RoPE baked into K before append)
                if self.rope is not None and not enable_3d_rope:
                    q = self.rope(q, pos)
                    k = self.rope(k, pos)
                elif self.rope is not None and enable_3d_rope:
                    q = apply_rotary_emb(q, pos)
                    k = apply_rotary_emb(k, pos)
                x = F.scaled_dot_product_attention(
                    q, k, v,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                )
                x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
                # Append each frame's K/V to the paged cache individually.
                tpf = manager.tokens_per_frame
                k_all = k.squeeze(0).permute(1, 0, 2)  # [num_frames*tpf, H, D]
                v_all = v.squeeze(0).permute(1, 0, 2)
                for f_idx in range(num_frames):
                    s = f_idx * tpf
                    manager.append_frame(global_idx, k_all[s:s+tpf].contiguous(), v_all[s:s+tpf].contiguous())
                manager.evict_frames(
                    block_idx=global_idx,
                    scale_frames=self.kv_cache_scale_frames,
                    sliding_window=self.kv_cache_sliding_window,
                    cross_frame_special=self.kv_cache_cross_frame_special,
                    include_scale_frames=self.kv_cache_include_scale_frames,
                    camera_only=self.kv_cache_camera_only,
                    num_register_tokens=num_register_tokens,
                )
            else:
                # Phase 2: single-frame streaming via FlashInfer paged attention.
                q_nhd, k_nhd, v_nhd = self.prepare_qkv(x, pos=pos, enable_3d_rope=enable_3d_rope)
                # 1. Append to paged cache
                manager.append_frame(global_idx, k_nhd, v_nhd)
                # 2. Apply sliding window eviction
                manager.evict_frames(
                    block_idx=global_idx,
                    scale_frames=self.kv_cache_scale_frames,
                    sliding_window=self.kv_cache_sliding_window,
                    cross_frame_special=self.kv_cache_cross_frame_special,
                    include_scale_frames=self.kv_cache_include_scale_frames,
                    camera_only=self.kv_cache_camera_only,
                    num_register_tokens=num_register_tokens,
                )
                # 3. Compute attention via FlashInfer BatchPrefillWithPagedKVCacheWrapper
                x = manager.compute_attention(global_idx, q_nhd)
                # Convert back: [tpf, H, D] -> [B, tpf, C].
                x = x.reshape(B, q_nhd.shape[0], self.num_heads * self.head_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class SDPAAttention(Attention):
    """
    SDPA variant for streaming inference.
    Uses F.scaled_dot_product_attention with dict-based KV cache.
    No FlashInfer dependency required — works on any CUDA GPU.

    The KV cache is a plain dict holding, per transformer block ``global_idx``:
      - ``k_{i}`` / ``v_{i}``: [B, H, num_frames, tokens_per_frame, D]
      - ``k_{i}_special`` / ``v_{i}_special``: per-frame special tokens
        (camera / register / scale) preserved across window eviction.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        # Standard attention setup (qkv projection, q/k norms, rope) is
        # inherited from Attention; only the eviction policy is added here.
        super().__init__(
            dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
            attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer,
            qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
        )
        # KV-cache eviction policy knobs (consumed by _apply_kv_cache_eviction).
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
                kv_cache=None, global_idx=0, num_frame_per_block=1,
                num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
        """Batch attention when ``kv_cache`` is None, else streaming attention
        that appends this call's K/V to the cache, evicts old frames, and
        attends over cached (+special) keys/values.
        """
        B, N, C = x.shape
        # NOTE(review): computed but currently unused in this method.
        using_optimized_layout = (num_patches is not None and num_special is not None
                                  and num_frames is not None)
        # qkv: [B, N, 3*H*D] -> [3, B, H, N, D]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        # ========== Batch Mode (no KV cache) ==========
        if kv_cache is None:
            # Apply rotary embeddings (2D rope object or precomputed 3D rope).
            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif self.rope is not None and enable_3d_rope:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
            x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
        # ========== Streaming Mode (with KV cache dict) ==========
        else:
            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif self.rope is not None and enable_3d_rope:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)
            # Token layout inside a frame: [camera, reg0..regN-1, scale, patches...]
            camera_token_idx = 0
            scale_token_idx = camera_token_idx + num_register_tokens + 1
            if kv_cache[f"k_{global_idx}"] is None:
                # First call for this block: initialize cache as
                # [B, H, frames, tokens_per_frame, D].
                kv_cache[f"k_{global_idx}"] = k.view(B, self.num_heads, num_frame_per_block,
                                                     N // num_frame_per_block, self.head_dim)
                kv_cache[f"v_{global_idx}"] = v.view(B, self.num_heads, num_frame_per_block,
                                                     N // num_frame_per_block, self.head_dim)
            else:
                # Infer how many new frames this call carries from the cached
                # tokens-per-frame, then append along the frame axis (dim=2).
                num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
                kv_cache[f"k_{global_idx}"] = torch.cat((
                    kv_cache[f"k_{global_idx}"],
                    k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
                ), dim=2)
                kv_cache[f"v_{global_idx}"] = torch.cat((
                    kv_cache[f"v_{global_idx}"],
                    v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
                ), dim=2)
            # Drop frames that fell out of the sliding window (specials may be kept).
            self._apply_kv_cache_eviction(
                kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens
            )
            k_cached = kv_cache[f"k_{global_idx}"].clone()
            v_cached = kv_cache[f"v_{global_idx}"].clone()
            # Flatten (frames, tokens_per_frame) -> one KV sequence axis.
            a, b, c, d, e = k_cached.shape
            k_full = k_cached.reshape(a, b, c * d, e)
            v_full = v_cached.reshape(a, b, c * d, e)
            # Prepend preserved special tokens from evicted frames, if any.
            if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
                special_k = kv_cache[f"k_{global_idx}_special"]
                special_v = kv_cache[f"v_{global_idx}_special"]
                sa, sb, sc, sd, se = special_k.shape
                k_full = torch.cat([special_k.reshape(sa, sb, sc * sd, se), k_full], dim=2)
                v_full = torch.cat([special_v.reshape(sa, sb, sc * sd, se), v_full], dim=2)
            q_seq_len = q.shape[2]
            # Full (non-causal) attention of the new tokens over the whole cache.
            x = F.scaled_dot_product_attention(
                q, k_full, v_full,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )
            x = x.transpose(1, 2).reshape(B, q_seq_len, self.num_heads * self.head_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    def _apply_kv_cache_eviction(self, kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens):
        """Apply sliding window eviction to KV cache.

        Cache layout: [B, H, frames, tokens_per_frame, D]. Keeps the first
        ``kv_cache_scale_frames`` frames (optionally) plus the most recent
        ``kv_cache_sliding_window`` frames; special tokens (camera, or
        camera..scale) of evicted frames can be retained in a side cache.
        """
        sliding_window_frames = self.kv_cache_sliding_window
        scale_frames = self.kv_cache_scale_frames
        # Only evict for patch-bearing caches (tokens_per_frame > 1);
        # single-token caches (e.g. camera-only streams) are left untouched.
        if kv_cache[f"k_{global_idx}"].shape[3] > 1:
            num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
            if num_cached_frames > sliding_window_frames + scale_frames:
                # Frames between the scale prefix and the recent window are evicted.
                evict_start = scale_frames
                evict_end = num_cached_frames - sliding_window_frames
                if evict_end > evict_start:
                    evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
                    evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]
                    if self.kv_cache_cross_frame_special:
                        # Preserve special tokens of evicted frames so later
                        # queries can still attend to them.
                        if self.kv_cache_camera_only:
                            # Camera token only.
                            new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                        else:
                            # Camera + register + scale tokens.
                            new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
                        if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
                            kv_cache[f"k_{global_idx}_special"] = new_special_k
                            kv_cache[f"v_{global_idx}_special"] = new_special_v
                        else:
                            kv_cache[f"k_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
                            kv_cache[f"v_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)
                if self.kv_cache_include_scale_frames:
                    # Keep scale prefix + recent window.
                    kv_cache[f"k_{global_idx}"] = torch.cat([
                        kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
                        kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                    ], dim=2)
                    kv_cache[f"v_{global_idx}"] = torch.cat([
                        kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
                        kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                    ], dim=2)
                else:
                    # Keep only the recent window.
                    kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                    kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]

514
lingbot_map/layers/block.py Normal file
View File

@@ -0,0 +1,514 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings
import math
import torch
from torch import nn, Tensor
from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention
from functools import lru_cache, partial
from torch.nn.attention.flex_attention import BlockMask, create_mask
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
class Block(nn.Module):
    """Pre-norm Transformer block: self-attention + MLP, each with a residual
    connection, optional LayerScale, and stochastic depth (drop path).

    For drop-path rates > 0.1 in training, the batched subset trick
    (`drop_add_residual_stochastic_depth`) runs each residual branch on a
    random subset of the batch and rescales, amortizing the branch cost.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        init_values=None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )
        # LayerScale only when init_values is given; otherwise identity.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path
    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
                                      num_patches=num_patches, num_special=num_special, num_frames=num_frames,
                                      enable_3d_rope=enable_3d_rope))
        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))
        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
            x = drop_add_residual_stochastic_depth(
                x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            # Bug fix: the FFN branch previously reused drop_path1 (the original
            # carried a `# FIXME: drop_path2` note); use the dedicated
            # drop_path2 module so each branch has its own stochastic depth.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
def drop_add_residual_stochastic_depth(
    x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
) -> Tensor:
    """Batched stochastic depth: evaluate the residual branch on a random
    subset of the batch only, then scatter-add the result back, scaled so the
    expected update matches running the branch on the full batch.
    """
    # Pick a random subset of batch rows to keep (at least one).
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:keep]
    subset = x[kept_rows]
    # Evaluate the residual branch, forwarding rope positions when provided.
    if pos is None:
        branch_out = residual_func(subset)
    else:
        branch_out = residual_func(subset, pos=pos[kept_rows])
    # Compensate for the dropped rows so the expectation is unchanged.
    scale = batch / keep
    flat = torch.index_add(
        x.flatten(1), 0, kept_rows, branch_out.flatten(1).to(dtype=x.dtype), alpha=scale
    )
    return flat.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Return (kept row indices, compensating residual scale) for batched
    stochastic depth on a [B, N, D] tensor."""
    batch, _, _ = x.shape
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:keep]
    return kept_rows, batch / keep
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a scaled residual (computed on a row subset) back into ``x``.

    Args:
        x: Full input tensor [B, N, D].
        brange: Batch-row indices the residual rows correspond to.
        residual: Residual branch output for those rows.
        residual_scale_factor: Compensation factor B / subset_size.
        scaling_vector: Optional per-channel LayerScale gamma (shape [D]).

    Returns:
        Tensor with the residual added at rows ``brange``. NOTE: when
        ``scaling_vector`` is None the result is flattened to [B, N*D]
        (callers are expected to ``view_as`` it back); otherwise it keeps
        x's original shape.
    """
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        # Bug fix: the original called `scaled_index_add` — an xformers kernel
        # that is never imported in this file, so this branch raised NameError.
        # Pure-torch equivalent: apply the per-channel LayerScale gamma to the
        # residual, then index_add with the same alpha.
        x_plus_residual = torch.index_add(
            x, 0, brange, (residual * scaling_vector).to(dtype=x.dtype), alpha=residual_scale_factor
        )
    return x_plus_residual
class FlashInferBlock(nn.Module):
    """
    FlashInfer variant of causal block for GCT.
    Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels).
    Supports optimized token layout and KV cache streaming inference.

    The streaming path (Phase 2) is split into named sub-steps so the
    pointwise parts (norm/qkv/proj/ffn) can be captured by torch.compile
    while the paged-cache calls stay eager.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        # Attention with FlashInfer paged KV cache; eviction knobs forwarded.
        self.attn = FlashInferAttention(
            dim=dim,
            num_heads=num_heads,
            qk_norm=qk_norm,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )
        # LayerScale only when init_values is given; otherwise identity.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path
    def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
        """Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit.
        Extracted as a named method so torch.compile can capture norm1 + qkv-linear +
        reshape + q_norm + k_norm + RoPE + format as a single CUDA graph.
        Returns:
            (q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim],
            ready for manager.append_frame + manager.compute_attention.
        """
        return self.attn.prepare_qkv(self.norm1(x), pos=pos, enable_3d_rope=enable_3d_rope)
    def forward(
        self,
        x: Tensor,
        pos=None,
        enable_ulysses_cp=False,
        num_patches=None,
        num_special=None,
        num_frames=None,
        enable_3d_rope=False,
        kv_cache=None,
        global_idx=0,
        num_frame_per_block=1,
        num_frame_for_scale=-1,
        num_register_tokens=4,
    ) -> Tensor:
        # Phase 2 (streaming): single-frame FlashInfer paged attention.
        # Handle inline so attn_pre (norm1+prepare_qkv) can be compiled as one CUDA graph.
        is_streaming = (kv_cache is not None and (num_frames is None or num_frames <= 1))
        if is_streaming:
            # In streaming mode, kv_cache is a FlashInferKVCacheManager.
            manager = kv_cache
            # Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format
            q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope)
            # Eager: write frame K/V to paged cache
            manager.append_frame(global_idx, k_nhd, v_nhd)
            # CPU-only: update eviction state (deque ops, no GPU kernel)
            manager.evict_frames(
                block_idx=global_idx,
                scale_frames=self.attn.kv_cache_scale_frames,
                sliding_window=self.attn.kv_cache_sliding_window,
                cross_frame_special=self.attn.kv_cache_cross_frame_special,
                include_scale_frames=self.attn.kv_cache_include_scale_frames,
                camera_only=self.attn.kv_cache_camera_only,
                num_register_tokens=num_register_tokens,
            )
            # Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper
            attn_x = manager.compute_attention(global_idx, q_nhd)
            # [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output)
            attn_x = attn_x.reshape(x.shape[0], q_nhd.shape[0],
                                    self.attn.num_heads * self.attn.head_dim)
            # Compiled: output projection
            # NOTE(review): self.attn.proj_drop is not applied here (the
            # non-streaming path applies it inside self.attn) — presumably fine
            # because streaming is inference-only (dropout p=0); confirm.
            attn_x = self.attn.proj(attn_x)
            x = x + self.ls1(attn_x)
        else:
            # Phase 1 (multi-frame scale pass) or non-streaming training path
            x = x + self.ls1(self.attn(
                self.norm1(x),
                pos=pos,
                enable_ulysses_cp=enable_ulysses_cp,
                num_patches=num_patches,
                num_special=num_special,
                num_frames=num_frames,
                enable_3d_rope=enable_3d_rope,
                kv_cache=kv_cache,
                global_idx=global_idx,
                num_frame_per_block=num_frame_per_block,
                num_frame_for_scale=num_frame_for_scale,
                num_register_tokens=num_register_tokens,
            ))
        # NOTE(review): no drop_path on either residual branch here, unlike
        # Block — looks intentional for compile/inference; confirm for training.
        x = self.ffn_residual(x)
        return x
    def ffn_residual(self, x: Tensor) -> Tensor:
        """FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in.
        Includes the residual add (x + ...) so torch.compile captures the entire
        ffn branch as one CUDA graph.
        """
        return x + self.ls2(self.mlp(self.norm2(x)))
class CameraBlock(nn.Module):
    """Causal Transformer block for camera-conditioned streams.

    Wraps CausalAttention with a block-wise causal attention mask built via
    flex-attention's `create_mask`, optional sliding-window / scale-frame
    attention, and a KV cache for streaming inference.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = CausalAttention(dim=dim, num_heads=num_heads,
                        qk_norm=qk_norm, qkv_bias=qkv_bias,
                        rope=rope, elementwise_attn_output_gate=elementwise_attn_output_gate,
                        kv_cache_sliding_window=kv_cache_sliding_window,
                        kv_cache_scale_frames=kv_cache_scale_frames,
                        kv_cache_cross_frame_special=kv_cache_cross_frame_special,
                        kv_cache_include_scale_frames=kv_cache_include_scale_frames,
                        kv_cache_camera_only=kv_cache_camera_only)
        self.sliding_window_size = sliding_window_size
        self.attend_to_scale_frames = attend_to_scale_frames
        self.num_random_frames = num_random_frames
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path
        # Cache of block-wise causal masks keyed by
        # (device, num_frames, frame_seqlen, num_frame_per_block).
        # Bug fix: this dict existed but was never used — the O(T^2) mask was
        # rebuilt on every forward call. Masks are deterministic in their key,
        # so caching preserves behavior. Cache is unbounded but key cardinality
        # is tiny in practice (few sequence configurations per run).
        self.masks = {}
    @torch.no_grad()
    def _prepare_blockwise_causal_attn_mask(self,
        device: torch.device | str, num_frames: int = 21,
        frame_seqlen: int = 1560, num_frame_per_block=1
    ) -> Tensor:
        """Build a block-wise causal attention mask.

        The token sequence is divided as
        [1 latent frame] [1 latent frame] ... [1 latent frame]
        and each chunk of `num_frame_per_block` frames may attend to every
        token up to the end of its own chunk (plus the diagonal, which also
        covers the right-padding rows).

        Note: `create_mask` returns a dense boolean Tensor, not a flex
        BlockMask (the original annotation was misleading).
        """
        total_length = num_frames * frame_seqlen
        # we do right padding to get to a multiple of 128
        padded_length = math.ceil(total_length / 128) * 128 - total_length
        ends = torch.zeros(total_length + padded_length,
                           device=device, dtype=torch.long)
        # Block-wise causal mask will attend to all elements that are before the end of the current chunk
        frame_indices = torch.arange(
            start=0,
            end=total_length,
            step=frame_seqlen * num_frame_per_block,
            device=device
        )
        for tmp in frame_indices:
            ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
                frame_seqlen * num_frame_per_block
        def attention_mask(b, h, q_idx, kv_idx):
            return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
            # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
        block_mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
                                 KV_LEN=total_length + padded_length, device=device)
        return block_mask
    def forward(self, x: Tensor, pos=None, video_mask=None, num_frames=0, frame_seqlen=0, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=8, num_frame_for_scale=-1, sliding_window_size=None, enable_ulysses_cp=False, full_attention=False, enable_3d_rope=False, is_scale_frames=False) -> Tensor:
        # Use passed sliding_window_size if provided, otherwise use self.sliding_window_size
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size
        # Fast path for full attention (camera head) - skip mask computation
        if full_attention:
            def attn_residual_func(x: Tensor, pos=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), pos=pos, full_attention=True, enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope))
            def ffn_residual_func(x: Tensor) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))
            if self.training and self.sample_drop_ratio > 0.0:
                x = x + self.drop_path1(attn_residual_func(x, pos=pos))
                x = x + self.drop_path1(ffn_residual_func(x))
            else:
                x = x + attn_residual_func(x, pos=pos)
                x = x + ffn_residual_func(x)
            return x
        # Reuse a cached mask when the sequence configuration repeats; the mask
        # depends only on (device, num_frames, frame_seqlen, num_frame_per_block).
        mask_key = (str(x.device), num_frames, frame_seqlen, num_frame_per_block)
        mask_block = self.masks.get(mask_key)
        if mask_block is None:
            mask_block = self._prepare_blockwise_causal_attn_mask(
                device=x.device, num_frames=num_frames, frame_seqlen=frame_seqlen, num_frame_per_block=num_frame_per_block)
            self.masks[mask_key] = mask_block
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos, block_mask=mask_block, frame_seqlen=frame_seqlen, video_mask=video_mask, current_start=current_start, current_end=current_end, kv_cache=kv_cache, global_idx=global_idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size, attend_to_scale_frames=self.attend_to_scale_frames, num_random_frames=self.num_random_frames,
                                      enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope, is_scale_frames=is_scale_frames))
        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))
        if self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
class SDPABlock(nn.Module):
    """
    SDPA variant for streaming inference. Uses F.scaled_dot_product_attention
    with dict-based KV cache. No FlashInfer dependency required.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()
        # Attention sub-layer with dict-based KV cache support.
        self.norm1 = norm_layer(dim)
        self.attn = SDPAAttention(
            dim=dim, num_heads=num_heads, qk_norm=qk_norm, qkv_bias=qkv_bias,
            proj_bias=proj_bias, attn_drop=attn_drop, proj_drop=drop, rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        # FFN sub-layer.
        self.norm2 = norm_layer(dim)
        self.mlp = ffn_layer(in_features=dim, hidden_features=int(dim * mlp_ratio),
                             act_layer=act_layer, drop=drop, bias=ffn_bias)
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.sample_drop_ratio = drop_path
    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
                kv_cache=None, global_idx=0, num_frame_per_block=1,
                num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
        """Pre-norm residual block: attention branch, then FFN branch."""
        def _attn_branch(inp, pos=None):
            attn_out = self.attn(
                self.norm1(inp), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
                num_patches=num_patches, num_special=num_special, num_frames=num_frames,
                enable_3d_rope=enable_3d_rope, kv_cache=kv_cache, global_idx=global_idx,
                num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale,
                num_register_tokens=num_register_tokens,
            )
            return self.ls1(attn_out)
        def _ffn_branch(inp):
            return self.ls2(self.mlp(self.norm2(inp)))
        stochastic_depth = self.training and self.sample_drop_ratio > 0.0
        if stochastic_depth:
            x = x + self.drop_path1(_attn_branch(x, pos=pos))
            x = x + self.drop_path1(_ffn_branch(x))
        else:
            x = x + _attn_branch(x, pos=pos)
            x = x + _ffn_branch(x)
        return x

View File

@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero entire samples with probability ``drop_prob``
    (stochastic depth). Identity at eval time or when drop_prob is 0."""
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast across all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale surviving samples so the expected activation is unchanged.
        mask.div_(keep_prob)
    return x * mask
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        # Bug fix: the legacy default was None, which crashed later in
        # drop_path() with `1 - None` when the module was default-constructed
        # and used in training. Treat None as "no drop" for backward
        # compatibility with callers that pass it explicitly.
        self.drop_prob = 0.0 if drop_prob is None else drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        # Shown in the module's repr for easier debugging of schedules.
        return f"drop_prob={self.drop_prob:.3f}"

View File

@@ -0,0 +1,582 @@
"""
FlashInfer KV Cache Manager — Two-Stream Paged Design.
Two logical streams sharing one physical page pool per layer:
Patch stream (recyclable):
- page_size = patches_per_frame (256 for 224×224; 972 for 504×378)
- Exactly 1 patch page per frame
- Scale frames → scale_patch_pages (never evicted, maxlen=scale_frames)
- Recent frames → live_window_patch_pages (evicted when > sliding_window)
Special stream (append-only, never recycled):
- num_special_tokens (6) special tokens per frame
- Packed continuously: one special page holds floor(page_size/6) frames
e.g. page_size=256 → 42 frames per special page, 4 slots wasted
- Specials written for EVERY frame (including scale + window), not just evicted ones.
Physical layout per block:
kv_caches[block_idx]: [max_num_pages, 2, page_size, H, D]
Pages 0 .. max_patch_pages-1 : patch page pool (recyclable)
Pages max_patch_pages .. max_pages-1: special page pool (append-only)
dim 1: 0=K 1=V
Attention computation:
visible = scale_patch_pages + live_window_patch_pages + all_special_pages
Special pages placed LAST → paged_kv_last_page_len naturally describes
the partial special-tail without a custom mask.
plan() is called ONCE per frame step (when block_idx == 0).
run() is called per layer, reusing the same plan. All layers at the
same frame step have identical page structures (same page IDs in same
positions), so reusing the plan across layers is correct.
Public API is drop-in compatible with the previous FlashInferKVCacheManager:
append_frame(block_idx, k, v)
evict_frames(block_idx, scale_frames, sliding_window, ...)
compute_attention(block_idx, q) -> out
reset()
"""
import collections
import math
from typing import List
import torch
from torch import Tensor
try:
import flashinfer
FLASHINFER_AVAILABLE = True
except ImportError:
FLASHINFER_AVAILABLE = False
class FlashInferKVCacheManager:
"""
Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only).
Args:
num_blocks: Number of Transformer blocks (one cache per block).
max_num_frames: Maximum frames held in the KV window at once
(scale_frames + sliding_window + headroom).
tokens_per_frame: Total tokens per frame = patches + specials (e.g. 262).
num_heads: Number of KV heads (= QO heads; MHA assumed).
head_dim: Head dimension (64 for ViT-L).
dtype: Storage dtype (bfloat16 / float16).
device: CUDA device.
num_special_tokens: Special tokens per frame: camera + register×N + scale (6).
scale_frames: Number of always-resident scale frames (8).
sliding_window: Sliding window size (64).
max_total_frames: Upper bound on total frames ever processed; used to
pre-allocate the special page pool (default 2048).
"""
def __init__(
self,
num_blocks: int,
max_num_frames: int,
tokens_per_frame: int,
num_heads: int,
head_dim: int,
dtype: torch.dtype,
device: torch.device,
num_special_tokens: int = 6,
scale_frames: int = 8,
sliding_window: int = 64,
max_total_frames: int = 2048,
force_fp32: bool = False,
fa3: bool = False,
):
if not FLASHINFER_AVAILABLE:
raise RuntimeError("FlashInfer is not available. Please install flashinfer.")
self.num_blocks = num_blocks
self.num_special_tokens = num_special_tokens # 6
self.patches_per_frame = tokens_per_frame - num_special_tokens # 256 / 999 / ...
# Use exact page_size = patches_per_frame to eliminate zero-padded slots.
# FA2 (backend="fa2") supports non-power-of-2 page sizes.
# FA3 (sm90) requires power-of-2 page sizes; use next_power_of_2 when fa3=True.
p = self.patches_per_frame
if fa3:
# Round up to next power-of-2 for FA3 SM90 kernel requirement.
# e.g. 999 → 1024 (25 zero-padded slots per patch page)
self.page_size = 1 << (p - 1).bit_length()
else:
self.page_size = p # exact: no zero padding in patch pages
self.scale_frames = scale_frames # 8
self.sliding_window = sliding_window # 64
self.num_heads = num_heads
self.head_dim = head_dim
self.tokens_per_frame = tokens_per_frame
assert self.patches_per_frame > 0, (
f"tokens_per_frame={tokens_per_frame} <= num_special_tokens={num_special_tokens}"
)
assert self.page_size > 0
# force_fp32: bypass FlashInfer FA2 kernel (which only supports fp16/bf16) and
# instead gather paged K/V into a dense tensor and use F.scaled_dot_product_attention
# in fp32 for accuracy comparison. Storage dtype is also kept as fp32 in this mode.
self.force_fp32 = force_fp32
if force_fp32:
self.dtype = torch.float32
else:
if dtype == torch.float32:
dtype = torch.bfloat16
self.dtype = dtype
self.device = device
# ── Page pool sizing ─────────────────────────────────────────────────
# Patch: scale + window + 16 headroom (pages recycled → fixed count)
max_patch_pages = scale_frames + sliding_window + 16 # e.g. 88
# Special: enough for max_total_frames × 6 tokens, plus 16 headroom
max_special_pages = (
math.ceil(max_total_frames * num_special_tokens / self.page_size) + 16
)
self.max_patch_pages = max_patch_pages
self.max_num_pages = max_patch_pages + max_special_pages
# ── Physical paged KV caches ─────────────────────────────────────────
# Shape per block: [max_num_pages, 2, page_size, H, D] (NHD, K=dim0, V=dim1)
self.kv_caches: List[Tensor] = [
torch.zeros(
self.max_num_pages, 2, self.page_size, num_heads, head_dim,
dtype=dtype, device=device,
)
for _ in range(num_blocks)
]
# ── Per-block state ──────────────────────────────────────────────────
# Patch pages (IDs 0 .. max_patch_pages-1)
self.scale_patch_pages: List[collections.deque] = [
collections.deque() for _ in range(num_blocks)
]
self.live_window_patch_pages: List[collections.deque] = [
collections.deque() for _ in range(num_blocks)
]
self.free_patch_pages: List[List[int]] = [
list(range(max_patch_pages)) for _ in range(num_blocks)
]
# Special pages (IDs max_patch_pages .. max_num_pages-1)
self.all_special_pages: List[List[int]] = [[] for _ in range(num_blocks)]
self.free_special_pages: List[List[int]] = [
list(range(max_patch_pages, self.max_num_pages)) for _ in range(num_blocks)
]
self.special_token_count: List[int] = [0] * num_blocks
# Frame counter per block (determines scale vs window routing)
self.frame_count: List[int] = [0] * num_blocks
# ── FlashInfer wrapper ───────────────────────────────────────────────
# plan() is called once per frame step (block_idx == 0).
# run() is called per layer, reusing the same aux structures.
# backend: "fa2" (default) or "fa3" (SM90/H100, requires power-of-2 page_size).
# FA2 supports non-power-of-2 page sizes and avoids a FA3 NaN bug seen in
# FlashInfer 0.2.5 at 518×378 resolution.
_fi_backend = "fa3" if fa3 else "fa2"
self.workspace_buffer = torch.zeros(
128 * 1024 * 1024, dtype=torch.uint8, device=device
)
self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
kv_layout="NHD",
backend=_fi_backend,
)
# plan() inputs (indices/indptr built fresh each step; qo_indptr is fixed)
self._qo_indptr = torch.tensor(
[0, tokens_per_frame], dtype=torch.int32, device=device
)
# =========================================================================
# Public API (drop-in compatible with previous FlashInferKVCacheManager)
# =========================================================================
def append_frame(self, block_idx: int, k: Tensor, v: Tensor) -> None:
"""
Append one frame's K/V tensors to the two-stream cache.
Token layout must be: [camera, reg0, ..., regN, scale, patch0, ..., patchP-1]
i.e. specials come first (matching stream.py's patch_start_idx convention).
Args:
block_idx: Block/layer index (0 … num_blocks-1).
k: [tokens_per_frame, H, D] NHD layout.
v: [tokens_per_frame, H, D] NHD layout.
"""
n = self.num_special_tokens # 6
sp_k = k[:n].to(self.dtype) # [6, H, D]
patch_k = k[n:].to(self.dtype) # [256, H, D]
sp_v = v[:n].to(self.dtype)
patch_v = v[n:].to(self.dtype)
assert patch_k.shape[0] == self.patches_per_frame, (
f"block {block_idx}: expected {self.patches_per_frame} patch tokens, "
f"got {patch_k.shape[0]} (tokens_per_frame={k.shape[0]})"
)
self._write_patch_page(block_idx, patch_k, patch_v)
self._write_special_tokens(block_idx, sp_k, sp_v)
self.frame_count[block_idx] += 1
def evict_frames(
    self,
    block_idx: int,
    scale_frames: int,
    sliding_window: int,
    cross_frame_special: bool = True,
    include_scale_frames: bool = True,
    camera_only: bool = False,
    num_register_tokens: int = 4,
) -> None:
    """
    Recycle window patch pages that fell out of the sliding window.

    Special pages and scale pages are never evicted; only the oldest
    entries of live_window_patch_pages beyond `sliding_window` are moved
    back to the free pool. The remaining parameters are accepted for API
    compatibility and do not influence eviction here.
    """
    window = self.live_window_patch_pages[block_idx]
    pool = self.free_patch_pages[block_idx]
    while len(window) > sliding_window:
        pool.append(window.popleft())
def _gather_kv(self, block_idx: int):
    """
    Concatenate every visible K/V token of one block into dense tensors.

    Used by force_fp32 mode to sidestep the FlashInfer FA2 kernel (which
    only supports fp16/bf16) so attention can instead run through
    F.scaled_dot_product_attention in fp32.

    Returns:
        k_flat: [kv_len, H, D] — all visible K tokens concatenated
        v_flat: [kv_len, H, D] — all visible V tokens concatenated
    """
    page_ids = self.build_visible_page_table(block_idx)
    tail_len = self.compute_last_page_len(block_idx)
    full_len = self.page_size
    cache = self.kv_caches[block_idx]
    last = len(page_ids) - 1
    k_chunks, v_chunks = [], []
    for idx, page in enumerate(page_ids):
        # Only the final page may be partially filled.
        valid = tail_len if idx == last else full_len
        k_chunks.append(cache[page, 0, :valid])  # [valid, H, D]
        v_chunks.append(cache[page, 1, :valid])
    return torch.cat(k_chunks, dim=0), torch.cat(v_chunks, dim=0)
def compute_attention(self, block_idx: int, q: Tensor) -> Tensor:
    """
    Compute cross-frame attention using FlashInfer BatchPrefillWithPagedKVCacheWrapper.

    When self.force_fp32 is True, gathers all visible K/V into dense tensors
    and uses F.scaled_dot_product_attention in fp32 instead of the FA2 kernel.
    This is used for accuracy comparison since FlashInfer FA2 only supports fp16/bf16.

    plan() is called once per frame step (when block_idx == 0).
    All layers at the same step share the same visible page structure,
    so the plan is reused by calling run() with each layer's kv_cache.

    Args:
        block_idx: Block/layer index.
        q: [q_len, H, D] NHD layout (q_len = tokens_per_frame = 262).

    Returns:
        out: [q_len, H, D]
    """
    if self.frame_count[block_idx] == 0:
        # No KV present yet (should not occur in normal usage after append_frame)
        return torch.zeros_like(q)
    if self.force_fp32:
        # ── fp32 gather+SDPA path ─────────────────────────────────────────
        # Gather visible K/V from paged cache and run SDPA in fp32.
        # This bypasses the FlashInfer FA2 kernel (fp16/bf16 only) for accuracy.
        # q_len, H, D → 1, H, q_len, D (SDPA expects BHsD layout)
        import torch.nn.functional as F_nn
        k_flat, v_flat = self._gather_kv(block_idx)
        q_b = q.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, q_len, D]
        k_b = k_flat.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, kv_len, D]
        v_b = v_flat.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, kv_len, D]
        out = F_nn.scaled_dot_product_attention(q_b, k_b, v_b)
        # Drop batch dim, restore NHD layout, and cast back to q's dtype.
        return out.squeeze(0).permute(1, 0, 2).to(q.dtype)  # [q_len, H, D]
    if block_idx == 0:
        # ── Plan once per frame step ──────────────────────────────────────
        # Build visible page table from block 0's state.
        # All blocks have identical page structures, so this plan is valid
        # for all subsequent run() calls (block_idx = 1, 2, ...).
        visible = self.build_visible_page_table(0)
        last_len = self.compute_last_page_len(0)
        assert visible, "visible page table is empty after append_frame"
        assert 1 <= last_len <= self.page_size, (
            f"block 0: last_page_len={last_len} out of [1, {self.page_size}]"
        )
        # Single-sequence batch: indptr covers exactly one page range.
        paged_kv_indices = torch.tensor(visible, dtype=torch.int32, device=self.device)
        paged_kv_indptr = torch.tensor([0, len(visible)], dtype=torch.int32, device=self.device)
        paged_kv_last_page_len = torch.tensor([last_len], dtype=torch.int32, device=self.device)
        self.prefill_wrapper.plan(
            self._qo_indptr,
            paged_kv_indptr,
            paged_kv_indices,
            paged_kv_last_page_len,
            num_qo_heads = self.num_heads,
            num_kv_heads = self.num_heads,  # MHA: query and KV head counts match
            head_dim_qk = self.head_dim,
            page_size = self.page_size,
            causal = False,  # custom page ordering; no causal mask
            pos_encoding_mode = "NONE",  # RoPE applied externally before append
            q_data_type = self.dtype,
        )
    # ── Run attention for this layer ──────────────────────────────────────
    # Cast q to storage dtype (LayerNorm may upcast to float32 under autocast).
    return self.prefill_wrapper.run(
        q = q.to(self.dtype).contiguous(),
        paged_kv_cache = self.kv_caches[block_idx],
    )  # → [q_len, H, D]
def reset(self) -> None:
    """Clear every block's page bookkeeping so a new sequence can start."""
    patch_ids = range(self.max_patch_pages)
    special_ids = range(self.max_patch_pages, self.max_num_pages)
    for blk in range(self.num_blocks):
        self.scale_patch_pages[blk].clear()
        self.live_window_patch_pages[blk].clear()
        self.all_special_pages[blk].clear()
        # Rebuild both free pools to their full capacity.
        self.free_patch_pages[blk] = list(patch_ids)
        self.free_special_pages[blk] = list(special_ids)
        self.special_token_count[blk] = 0
        self.frame_count[blk] = 0
# =========================================================================
# Helper methods
# =========================================================================
def build_visible_page_table(self, block_idx: int) -> List[int]:
    """
    Page IDs in visibility order: scale → window → special.

    Keeping special pages at the tail guarantees that only the very last
    page can be partially filled, so compute_last_page_len() alone is
    enough for FlashInfer — no custom attention mask is required.
    """
    return [
        *self.scale_patch_pages[block_idx],
        *self.live_window_patch_pages[block_idx],
        *self.all_special_pages[block_idx],
    ]
def compute_last_page_len(self, block_idx: int) -> int:
    """
    Number of valid tokens in the final page of the visible sequence.

    - With special pages present, the tail holds
      special_token_count % page_size tokens, or a full page_size when
      that remainder is zero.
    - Without special pages, the tail is a patch page holding exactly
      patches_per_frame real tokens — possibly fewer than page_size when
      page_size was rounded up to a power of two.
    """
    if self.all_special_pages[block_idx]:
        remainder = self.special_token_count[block_idx] % self.page_size
        return remainder if remainder else self.page_size
    # Tail is a patch page: positions patches_per_frame..page_size-1 are zero
    # padding, so report only the real token count — FlashInfer must not
    # attend over the padded slots.
    return self.patches_per_frame
# ── Internal write helpers ────────────────────────────────────────────────
def _write_patch_page(self, block_idx: int, patch_k: Tensor, patch_v: Tensor) -> int:
    """
    Take one page from the free pool and fill it with this frame's patches.

    Writes go straight into kv_caches[block_idx][page_id, 0/1] via slice
    assignment, skipping the Python→C++/CUDA dispatch overhead of
    flashinfer.page.append_paged_kv_cache.
    Cache layout: [max_num_pages, 2, page_size, H, D] (NHD; K=0, V=1).

    The page is routed to scale_patch_pages while the scale quota is
    unfilled, otherwise to live_window_patch_pages.

    Returns:
        page_id: Physical page index used.
    """
    assert self.free_patch_pages[block_idx], (
        f"block {block_idx}: patch page pool exhausted — "
        f"scale={len(self.scale_patch_pages[block_idx])}, "
        f"window={len(self.live_window_patch_pages[block_idx])}, "
        f"free={len(self.free_patch_pages[block_idx])}"
    )
    page_id = self.free_patch_pages[block_idx].pop()
    n_tokens = self.patches_per_frame
    # Slots n_tokens..page_size-1 stay zero (cache is zero-initialised); they
    # exist only when page_size was rounded up for FA3 alignment
    # (e.g. page_size=1024 for patches_per_frame=999).
    self.kv_caches[block_idx][page_id, 0, :n_tokens] = patch_k  # K
    self.kv_caches[block_idx][page_id, 1, :n_tokens] = patch_v  # V
    destination = (
        self.scale_patch_pages[block_idx]
        if len(self.scale_patch_pages[block_idx]) < self.scale_frames
        else self.live_window_patch_pages[block_idx]
    )
    destination.append(page_id)
    return page_id
def _write_special_tokens(self, block_idx: int, sp_k: Tensor, sp_v: Tensor) -> None:
    """
    Append this frame's num_special_tokens (6) specials to the special stream.

    Uses direct tensor slice assignment into kv_caches instead of
    flashinfer.page.append_paged_kv_cache, avoiding Python→C++/CUDA
    dispatch overhead. When the 6 tokens straddle a page boundary
    (rare — page_size=256 >> 6), the write is split across two pages.
    """
    cache = self.kv_caches[block_idx]
    to_write = self.num_special_tokens  # 6
    cursor = 0
    while to_write > 0:
        offset = self.special_token_count[block_idx] % self.page_size
        if offset == 0:
            # Tail page is full (or none exists yet) — grab a fresh one.
            assert self.free_special_pages[block_idx], (
                f"block {block_idx}: special page pool exhausted at "
                f"special_token_count={self.special_token_count[block_idx]}. "
                f"Increase max_total_frames."
            )
            self.all_special_pages[block_idx].append(
                self.free_special_pages[block_idx].pop()
            )
        page = self.all_special_pages[block_idx][-1]
        n = min(to_write, self.page_size - offset)  # tokens that fit in tail page
        # Slice write into [page_size, H, D] at [offset : offset + n].
        cache[page, 0, offset:offset + n] = sp_k[cursor:cursor + n]
        cache[page, 1, offset:offset + n] = sp_v[cursor:cursor + n]
        self.special_token_count[block_idx] += n
        cursor += n
        to_write -= n
# ── Legacy property (used by stream.py) ──────────────────────────────────
@property
def num_frames(self) -> int:
    """Number of frames appended to block 0 (representative)."""
    if not self.frame_count:
        return 0
    return self.frame_count[0]
# =============================================================================
# Sanity check
# =============================================================================
def _sanity_check():
    """
    Minimal smoke test for the two-stream paged KV cache.

    Appends 100 frames per block (2 blocks), evicting past an 8-frame scale
    quota and a 64-frame sliding window, then verifies page bookkeeping,
    last-page length, visible-page ordering, and one plan()/run() forward pass.

    Run with: python -c "from lingbot_map.layers.flashinfer_cache import _sanity_check; _sanity_check()"
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not torch.cuda.is_available():
        # FlashInfer kernels require CUDA; nothing to verify on CPU.
        print("[sanity_check] CUDA not available — skipping.")
        return
    tokens_per_frame = 262  # 256 patch + 6 special (224×224)
    num_special = 6
    patches_per_frame = tokens_per_frame - num_special  # 256
    page_size = patches_per_frame  # 256
    mgr = FlashInferKVCacheManager(
        num_blocks = 2,
        max_num_frames = 88,
        tokens_per_frame = tokens_per_frame,
        num_heads = 16,
        head_dim = 64,
        dtype = torch.bfloat16,
        device = device,
        num_special_tokens = num_special,
        scale_frames = 8,
        sliding_window = 64,
        max_total_frames = 200,
    )
    def make_kv():
        # One frame's random K/V pair in the manager's storage dtype/layout.
        k = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
        v = torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
        return k, v
    def make_q():
        # One frame's random query tensor.
        return torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)
    for block in range(2):
        for t in range(100):
            k, v = make_kv()
            mgr.append_frame(block, k, v)
            mgr.evict_frames(block, scale_frames=8, sliding_window=64)
        # ── Page count checks ───────────────────────────────────────────────
        n_scale = len(mgr.scale_patch_pages[block])
        n_window = len(mgr.live_window_patch_pages[block])
        n_spec = len(mgr.all_special_pages[block])
        sp_count = mgr.special_token_count[block]
        assert n_scale == 8, f"block {block}: scale pages = {n_scale}, expected 8"
        assert n_window == 64, f"block {block}: window pages = {n_window}, expected 64"
        # 100 frames × 6 specials = 600 tokens; ceil(600/256) = 3 pages
        expected_spec_pages = math.ceil(100 * num_special / page_size)
        assert n_spec == expected_spec_pages, (
            f"block {block}: special pages = {n_spec}, expected {expected_spec_pages}"
        )
        assert sp_count == 100 * num_special, (
            f"block {block}: special_token_count = {sp_count}, expected {100*num_special}"
        )
        # ── last_page_len ────────────────────────────────────────────────────
        last_len = mgr.compute_last_page_len(block)
        tail = sp_count % page_size
        expected_len = page_size if tail == 0 else tail
        assert last_len == expected_len, f"block {block}: last_len={last_len}, expected={expected_len}"
        # ── visible page table order ─────────────────────────────────────────
        visible = mgr.build_visible_page_table(block)
        assert len(visible) == n_scale + n_window + n_spec, "visible page count mismatch"
        for pid in visible[:n_scale + n_window]:
            assert pid < mgr.max_patch_pages, f"patch page {pid} out of patch range"
        for pid in visible[n_scale + n_window:]:
            assert pid >= mgr.max_patch_pages, f"special page {pid} not in special range"
        # ── forward pass: plan() once for block 0, run() for both blocks ─────
        if block == 1:
            # Simulate the actual calling pattern: plan on block 0, run on both
            q0 = make_q()
            out0 = mgr.compute_attention(0, q0)  # triggers plan()
            q1 = make_q()
            out1 = mgr.compute_attention(1, q1)  # reuses plan, different kv_cache
            assert out0.shape == (tokens_per_frame, 16, 64)
            assert out1.shape == (tokens_per_frame, 16, 64)
        print(f"[block {block}] PASS: scale={n_scale}, window={n_window}, "
              f"special_pages={n_spec}, special_tokens={sp_count}, "
              f"last_page_len={last_len}")
    mgr.reset()
    assert mgr.frame_count[0] == 0
    print("\n[sanity_check] All assertions passed.")
if __name__ == "__main__":
    # Run the smoke test when this module is executed directly.
    _sanity_check()

View File

@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
from typing import Union
import torch
from torch import Tensor
from torch import nn
class LayerScale(nn.Module):
    """Per-channel learnable scaling (timm-style LayerScale).

    Multiplies the input by a learned `gamma` vector of length `dim`,
    initialised to `init_values`. With `inplace=True` the input tensor
    is scaled in place.
    """

    def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma

40
lingbot_map/layers/mlp.py Normal file
View File

@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 → activation → dropout → fc2 → dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

View File

@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
from torch import Tensor
import torch.nn as nn
def make_2tuple(x):
    """Return x unchanged if it is a 2-tuple; expand an int to (x, x)."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Splits the image into non-overlapping patches via a strided Conv2d
    projection, then optionally normalizes and flattens the result.

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If False, forward returns (B, H', W', D)
            instead of (B, N, D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()
        image_hw = make_2tuple(img_size)
        patch_hw = make_2tuple(patch_size)
        grid = (image_hw[0] // patch_hw[0], image_hw[1] // patch_hw[1])

        self.img_size = image_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Strided conv = linear projection of each patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size
        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H' W'
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B H'W' C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H' W' C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

474
lingbot_map/layers/rope.py Normal file
View File

@@ -0,0 +1,474 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# Implementation of 2D Rotary Position Embeddings (RoPE).
# This module provides a clean implementation of 2D Rotary Position Embeddings,
# which extends the original RoPE concept to handle 2D spatial positions.
# Inspired by:
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
# https://github.com/naver-ai/rope-vit
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple
from typing import List, Optional, Tuple, Union
class PositionGetter:
    """Produces (y, x) grid coordinates for patches, memoized per grid size.

    Attributes:
        position_cache: Maps (height, width) to the precomputed [H*W, 2]
            coordinate tensor so repeated grids are not rebuilt.
    """

    def __init__(self):
        """Start with an empty coordinate cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Return per-batch patch coordinates for a height×width grid.

        Args:
            batch_size: Number of samples in the batch.
            height: Grid height in patches.
            width: Grid width in patches.
            device: Device on which newly created coordinate tensors live.

        Returns:
            Tensor of shape (batch_size, height*width, 2) holding the (y, x)
            coordinate of every grid cell, replicated across the batch.
        """
        key = (height, width)
        if key not in self.position_cache:
            ys = torch.arange(height, device=device)
            xs = torch.arange(width, device=device)
            self.position_cache[key] = torch.cartesian_prod(ys, xs)
        grid = self.position_cache[key]
        return grid.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(nn.Module):
    """Rotary position embeddings over 2D (y, x) patch coordinates.

    The feature dimension is split in half: the first half is rotated
    according to the vertical (y) position and the second half according
    to the horizontal (x) position.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        # Memoizes (cos, sin) tables keyed by (dim, seq_len, device, dtype).
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build cached (cos, sin) tables of shape [seq_len, dim].

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors for frequency components.
        """
        key = (dim, seq_len, device, dtype)
        if key not in self.frequency_cache:
            # Frequency bands: 1 / base^(2i/dim) for even i.
            band = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**band)
            # Angle per (position, band) pair.
            pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            theta = torch.einsum("i,j->ij", pos, inv_freq)
            theta = theta.to(dtype)
            theta = torch.cat((theta, theta), dim=-1)
            self.frequency_cache[key] = (theta.cos().to(dtype), theta.sin().to(dtype))
        return self.frequency_cache[key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Map the two halves [a, b] of the last dim to [-b, a]."""
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Rotate `tokens` by the angles looked up at `positions`.

        Args:
            tokens: Input token features.
            positions: Position indices.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]
        return tokens * cos + self._rotate_features(tokens) * sin

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Apply 2D rotary position embeddings.

        Args:
            tokens: (batch_size, n_heads, n_tokens, dim); dim must be
                divisible by 4.
            positions: (batch_size, n_tokens, 2) integer (y, x) coordinates.

        Returns:
            Tensor shaped like `tokens` with 2D RoPE applied.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Each spatial direction rotates half of the features.
        half_dim = tokens.size(-1) // 2
        table_len = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(half_dim, table_len, tokens.device, tokens.dtype)

        y_feats, x_feats = tokens.chunk(2, dim=-1)
        y_feats = self._apply_1d_rope(y_feats, positions[..., 0], cos_comp, sin_comp)
        x_feats = self._apply_1d_rope(x_feats, positions[..., 1], cos_comp, sin_comp)
        return torch.cat((y_feats, x_feats), dim=-1)
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Compute the 1D rotary position embedding (RoPE) frequency tensor.

    RoPE encodes positions as rotations so relative offsets are preserved:
    for position m and dimension pair i the angle is m * theta^(-2i/d),
    where theta is the base frequency (default 10000).

    Args:
        dim: Feature dimension; must be even (dimensions rotate in pairs).
        pos: Either an int (positions 0..pos-1 are generated) or a position
            array of shape [S].
        theta: Base frequency controlling the rotation period.
        use_real: Return separate (cos, sin) tensors instead of complex form.
        linear_factor: Linear frequency scaling factor (context extension).
        ntk_factor: NTK-aware scaling factor for longer sequences.
        repeat_interleave_real: With use_real=True, interleave each pair's
            values ([c0, c0, c1, c1, ...]) instead of concatenating two copies.
        freqs_dtype: dtype for the intermediate frequency computation.

    Returns:
        Complex form: [S, D/2] complex tensor representing e^(i*m*theta_j).
        Real form: two [S, D] float tensors (cos, sin).
    """
    assert dim % 2 == 0

    # Normalise positions to a torch tensor.
    if isinstance(pos, int):
        pos = torch.arange(pos)  # [0, 1, ..., pos-1]
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # [S]

    # NTK-aware scaling stretches the base frequency for unseen lengths.
    theta = theta * ntk_factor

    # Per-pair inverse frequencies: 1 / (theta^(2i/d) * linear_factor), [D/2].
    inv_freq = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )

    # Outer product pos ⊗ inv_freq gives the angle per (position, pair): [S, D/2].
    angles = torch.outer(pos, inv_freq)

    if not use_real:
        # Complex form e^(i*angle) via Euler's formula (lumina):
        # torch.polar(1, θ) = cos(θ) + i·sin(θ).
        return torch.polar(torch.ones_like(angles), angles)  # complex64: [S, D/2]
    if repeat_interleave_real:
        # Interleaved layout [c0, c0, c1, c1, ...] (flux, hunyuan-dit, cogvideox).
        freqs_cos = angles.cos().repeat_interleave(2, dim=1, output_size=angles.shape[1] * 2).float()  # [S, D]
        freqs_sin = angles.sin().repeat_interleave(2, dim=1, output_size=angles.shape[1] * 2).float()  # [S, D]
    else:
        # Concatenated layout [c0..cn, c0..cn] (stable audio, allegro).
        freqs_cos = torch.cat([angles.cos(), angles.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([angles.sin(), angles.sin()], dim=-1).float()  # [S, D]
    return freqs_cos, freqs_sin
class WanRotaryPosEmbed(nn.Module):
    """
    3D rotary position embedding (3D RoPE) module.

    Extends RoPE to three axes — time/frame, height, width — for video-like
    data. Each axis gets an independent 1D RoPE over its own slice of the
    head dimension, and the per-axis tables are concatenated so that
    dim_f + dim_h + dim_w == attention_head_dim.
    """

    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int = 1024,
        theta: float = 10000.0,
        fhw_dim: Optional[Tuple[int, int, int]] = (20, 22, 22),
    ):
        """
        Args:
            attention_head_dim: Total per-head feature dimension.
            patch_size: Patch size as (patch_f, patch_h, patch_w).
            max_seq_len: Maximum index precomputed per axis.
            theta: Base RoPE frequency.
            fhw_dim: Explicit (dim_f, dim_h, dim_w) split; must sum to
                attention_head_dim. If None, h and w each take roughly a
                third and f takes the remainder.
                Fix: the default was previously a mutable list ([20, 22, 22])
                annotated as a tuple; it is now an immutable tuple with the
                same values.
        """
        super().__init__()
        self.attention_head_dim = attention_head_dim  # total per-head dim
        self.patch_size = patch_size  # (patch_f, patch_h, patch_w)
        self.max_seq_len = max_seq_len  # table length precomputed per axis
        # Step 1: split the head dimension across the three axes.
        if fhw_dim is not None:
            # Use the explicit split when given.
            assert attention_head_dim == sum(
                fhw_dim
            ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
            t_dim, h_dim, w_dim = fhw_dim
        else:
            # Auto split: h and w each get 2*(dim//6), t gets the remainder
            # (e.g. attention_head_dim=64 -> h_dim=w_dim=20, t_dim=24).
            h_dim = w_dim = 2 * (attention_head_dim // 6)
            t_dim = attention_head_dim - h_dim - w_dim
        # Remember the split so forward() can slice the table consistently.
        self.fhw_dim = (t_dim, h_dim, w_dim)
        # Step 2: precompute complex RoPE frequency tables per axis.
        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            # Complex-form table of shape [max_seq_len, dim // 2].
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        # Concatenate along features: [max_seq_len, (t_dim + h_dim + w_dim) // 2].
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
        """
        Generate rotary frequencies for frames of special tokens + patches.

        Args:
            ppf (int): Number of frames (used when f_end is None).
            pph (int): Patches per frame along height.
            ppw (int): Patches per frame along width.
            patch_start_idx (int): Special tokens per frame, placed before
                the patches.
            device: Target device (CPU/GPU).
            f_start (int): First frame index for causal mode. Default 0.
            f_end (Optional[int]): Exclusive end frame index for causal mode;
                when given, ppf is recomputed as f_end - f_start.

        Returns:
            freqs: [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim // 2]
                complex frequency tensor.

        Token order per frame:
            [frame0_special_token_0, ..., frame0_special_token_N,
             frame0_patch_0, ..., frame0_patch_M,
             frame1_special_token_0, ..., frame1_special_token_N,
             frame1_patch_0, ..., frame1_patch_M,
             ...]

        Modes:
            - Non-causal (f_end=None): frames 0..ppf-1.
            - Causal (f_end given): frames in [f_start, f_end); ppf recomputed.
        """
        # Step 1: move the precomputed table to the target device and split it
        # into the three per-axis tables.
        self.freqs = self.freqs.to(device)
        # Recover the per-axis dimension split.
        if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
            t_dim, h_dim, w_dim = self.fhw_dim
        else:
            # Auto-split fallback (mirrors __init__).
            h_dim = w_dim = 2 * (self.attention_head_dim // 6)
            t_dim = self.attention_head_dim - h_dim - w_dim
        # Each axis owns half of its dim in the complex table.
        freqs = self.freqs.split_with_sizes(
            [
                t_dim // 2,  # temporal axis
                h_dim // 2,  # height axis
                w_dim // 2,  # width axis
            ],
            dim=1,
        )
        # Causal mode: restrict to the [f_start, f_end) frame range.
        if f_end is not None:
            ppf = f_end - f_start
            frame_slice = slice(f_start, f_end)
        else:
            # Non-causal: frames 0..ppf-1.
            frame_slice = slice(0, ppf)
        # Step 2: build frequencies for the special tokens, if any.
        ## For other tokens
        if patch_start_idx > 0:
            # 2.1 Special tokens sit on the diagonal (f, i, i) so each gets a
            # unique position: camera (f,0,0), register_0 (f,1,1), ..., scale (f,5,5).
            # Shape: (ppf, patch_start_idx, dim)
            freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1)  # frame axis varies
            freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # height = 0,1,2,...
            freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # width = 0,1,2,...
            freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1)  # concat the three axes
            freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim)
            # 2.2 Patches sit at (f, patch_start_idx + h, patch_start_idx + w):
            # offsetting h and w by patch_start_idx keeps patches clear of the
            # special-token positions while treating h and w symmetrically.
            # Shape: (ppf, pph, ppw, dim)
            freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # frame axis
            freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # height from patch_start_idx
            freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # width from patch_start_idx
            freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1)  # concat the three axes
            freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1)  # flatten spatial dims
            # Step 3: within each frame, specials come first, then patches.
            # Shape: (ppf, patch_start_idx + pph * ppw, dim)
            freqs = torch.cat([freqs_special, freqs_patches], dim=1)
            # Step 4: flatten frames and add batch/head dims.
            freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
            freqs = freqs.unsqueeze(0).unsqueeze(0)  # (1, 1, ppf * (patch_start_idx + pph * ppw), dim)
            return freqs
        # No special tokens (patch_start_idx == 0): patches at (f, 0:pph, 0:ppw).
        freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # frame axis
        freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # height from 0
        freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # width from 0
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)  # (1, 1, ppf * pph * ppw, dim)
        return freqs
def apply_rotary_emb(x, freqs):
    """Apply rotary position embeddings to ``x`` via complex multiplication.

    Each consecutive pair of channels in the last dimension is treated as a
    complex number ``(real, imag)`` and rotated by the precomputed unit
    complex frequencies ``freqs`` (``e^{i*theta}``).  Rotating feature pairs
    this way encodes absolute position while preserving relative offsets,
    and is equivalent to multiplying each pair by the 2x2 rotation matrix
    ``[[cos t, -sin t], [sin t, cos t]]``.

    Args:
        x: input features of shape [batch, heads, seq_len, head_dim].
        freqs: complex rotation factors broadcastable to
            [1, 1, seq_len, head_dim // 2].

    Returns:
        Rotated features with the same shape and dtype as ``x``.
    """
    b, h, s = x.shape[0], x.shape[1], x.shape[2]
    # Pair up channels: [..., head_dim] -> [..., head_dim // 2, 2].  The math
    # is done in float64 to match the precision of the precomputed freqs.
    paired = x.to(torch.float64).reshape(b, h, s, -1, 2)
    # View each (real, imag) pair as one complex number and rotate it by
    # multiplying with e^{i*theta}.
    rotated = torch.view_as_complex(paired) * freqs
    # Back to interleaved real layout, flatten pairs into head_dim, and
    # restore the caller's dtype.
    return torch.view_as_real(rotated).flatten(3).to(x.dtype)

View File

@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import os
from typing import Callable, Optional
import warnings
from torch import Tensor, nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: ``w3(silu(x W1) * (x W2))``.

    The two input projections are fused into a single linear layer ``w12``
    whose output is split in half along the channel dimension; the first half
    is gated with SiLU and multiplied element-wise with the second half.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,  # unused, kept for API parity
        drop: float = 0.0,  # unused, kept for API parity
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Default both hidden and output widths to the input width.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Fused gate/value projection, then the output projection.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
# True unless the user explicitly sets XFORMERS_DISABLED in the environment.
# NOTE(review): this flag is currently unused — the xFormers import below is
# commented out, so the pure-PyTorch SwiGLUFFN is always selected.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
# if XFORMERS_ENABLED:
# from xformers.ops import SwiGLU
# XFORMERS_AVAILABLE = True
# warnings.warn("xFormers is available (SwiGLU)")
# else:
# warnings.warn("xFormers is disabled (SwiGLU)")
# raise ImportError
# except ImportError:
# Fallback path taken unconditionally: alias the pure-PyTorch implementation.
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
# warnings.warn("xFormers is not available (SwiGLU)")
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width is scaled by 2/3 and padded up to a
    multiple of 8 (the xFormers fused-kernel sizing convention)."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,  # unused, kept for API parity
        drop: float = 0.0,  # unused, kept for API parity
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Shrink by 2/3 (SwiGLU uses two input projections, so this keeps the
        # parameter count comparable to a plain MLP), then round up to a
        # multiple of 8 for hardware-friendly shapes.
        scaled = int(hidden_features * 2 / 3)
        hidden_features = (scaled + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )

View File

@@ -0,0 +1,411 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.nn.init import trunc_normal_
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention#, NestedTensorBlock as Block
# TODO: Check this
# We replace NestedTensorBlock with Block
from .block import Block
logger = logging.getLogger("dinov2")
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively invoke ``fn(module=..., name=...)`` on ``module`` and all
    of its descendants.

    Args:
        fn: callback receiving the module and its dotted name as keywords.
        module: root module to traverse.
        name: dotted-path prefix used for ``module`` itself.
        depth_first: if True, children are visited before the current module.
        include_root: whether ``fn`` is also applied to ``module`` itself.

    Returns:
        The root module (possibly mutated by ``fn``).
    """
    # Pre-order visit of the current node when requested.
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        # Descendants always get include_root=True so fn runs on each of them.
        named_apply(fn=fn, module=child, name=full_name, depth_first=depth_first, include_root=True)
    # Post-order visit of the current node when requested.
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
class BlockChunk(nn.ModuleList):
    """A ModuleList that runs its members sequentially when called.

    Used to group consecutive transformer blocks into one unit (e.g. for
    FSDP wrapping) while behaving like a plain sequential stack.
    """

    def forward(self, x):
        # Feed the output of each block into the next.
        for block in self:
            x = block(x)
        return x
class DinoVisionTransformer(nn.Module):
    """DINOv2 Vision Transformer backbone.

    Tokenizes an image into patches, optionally prepends a class token and
    "register" tokens, adds (bicubically interpolated) positional embeddings,
    and runs the token sequence through a stack of transformer blocks.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None, # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        drop_cls_token=False,
        qk_norm=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
            drop_cls_token: (bool) if True, no class token is created and pos_embed covers patches only
            qk_norm: (bool) enable query/key normalization inside the attention blocks
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
        # One extra token slot (the cls token) unless it is dropped.
        self.num_tokens = 1 if not drop_cls_token else 0
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.use_reentrant = False # hardcoded to False
        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.drop_cls_token = drop_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if not drop_cls_token else None
        # Learned positional embedding for [optional cls] + patch tokens.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )
        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")
            def f(*args, **kwargs):
                return nn.Identity()
            ffn_layer = f
        else:
            raise NotImplementedError
        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)
        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()
        # Token substituted for masked patches (masked-image-modeling mode).
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
        self.init_weights()

    def init_weights(self):
        """Initialize pos_embed, cls/register tokens, and all Linear layers
        (timm-style trunc-normal weights / zero biases via named_apply)."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6) if self.cls_token is not None else None
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Resize the learned positional embedding to the current image size.

        Args:
            x: token sequence (used only for its length, feature dim, dtype).
            w, h: input image width and height in pixels.

        Returns:
            Positional embedding matching x's token layout, cast to x's dtype.
        """
        previous_dtype = x.dtype
        # NOTE(review): npatch subtracts 1 even when drop_cls_token is True
        # (i.e. x carries no cls token); confirm this is intended for that
        # configuration, since the fast path below compares it against N.
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1 if not self.drop_cls_token else self.pos_embed.shape[1]
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        if not self.drop_cls_token:
            class_pos_embed = pos_embed[:, 0]
            patch_pos_embed = pos_embed[:, 1:]
        else:
            patch_pos_embed = pos_embed
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N)) # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        if not self.drop_cls_token:
            x = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
        else:
            x = patch_pos_embed
        return x.to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patchify, apply mask tokens, prepend cls/register tokens, add
        positional encoding.

        Args:
            x: image batch unpacked as [B, nc, w, h] below.
            masks: optional boolean patch mask; masked patches are replaced
                with the learned mask_token.
        """
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.cls_token is not None else x
        x = x + self.interpolate_pos_encoding(x, w, h)
        if self.register_tokens is not None:
            # NOTE(review): registers are inserted after position 0; when
            # drop_cls_token is True that slot is the first *patch* token —
            # confirm intended for that configuration.
            x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)
        return x

    def forward_features_list(self, x_list, masks_list):
        """Run the backbone over a list of (images, masks) pairs.

        Blocks are gradient-checkpointed while in training mode. Returns one
        output dict per input, keyed like forward_features.
        """
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)
        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the backbone and return normalized cls/register/patch tokens
        plus the pre-norm features; dispatches to forward_features_list for
        list inputs."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)
        x = self.prepare_tokens_with_masks(x, masks)
        for blk in self.blocks:
            if self.training:
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)
        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        """Collect outputs of selected blocks (flat, non-chunked layout)."""
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        """Collect outputs of selected blocks (chunked BlockChunk layout)."""
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]: # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1, # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch features from the last ``n`` (or listed) blocks.

        Args:
            x: input images.
            n: number of trailing blocks, or an explicit list of block indices.
            reshape: if True, reshape patch tokens into [B, C, w/p, h/p] maps.
            return_class_token: if True, pair each output with its cls token.
            norm: apply the final LayerNorm to each collected output.
        """
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Strip cls + register tokens, keeping patch tokens only.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=True, **kwargs):
        """Forward pass.

        Returns the full feature dict when ``is_training`` is True (default);
        otherwise runs the (identity) head on the normalized cls token.
        """
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """Initialize a module the timm/ViT way (for reproducibility).

    Linear layers get truncated-normal weights (std=0.02) and zero biases;
    every other module type is left untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-S/16-style DINOv2 backbone (embed dim 384, 12 blocks, 6 heads)."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        # Memory-efficient attention inside the standard transformer block.
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-B/16-style DINOv2 backbone (embed dim 768, 12 blocks, 12 heads)."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        # Memory-efficient attention inside the standard transformer block.
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-L/16-style DINOv2 backbone (embed dim 1024, 24 blocks, 16 heads)."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        # Memory-efficient attention inside the standard transformer block.
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        # Memory-efficient attention inside the standard transformer block.
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )

View File

View File

@@ -0,0 +1,359 @@
"""
GCTBase - Base class for GCT model implementations.
Provides shared functionality:
- Prediction heads (camera, depth, point)
- Forward pass structure
- Model hub mixin (PyTorchModelHubMixin)
"""
import logging
import numpy as np
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Union
from huggingface_hub import PyTorchModelHubMixin
from lingbot_map.heads.dpt_head import DPTHead
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
from lingbot_map.utils.geometry import closed_form_inverse_se3
logger = logging.getLogger(__name__)
class GCTBase(nn.Module, PyTorchModelHubMixin, ABC):
    """
    Base class for GCT model implementations.
    Handles shared components:
    - Prediction heads (camera, depth, point)
    - Forward pass structure
    - Input normalization
    Subclasses must implement:
    - _build_aggregator(): Create mode-specific aggregator
    - _build_camera_head(): Create mode-specific camera head
    """
    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Camera head sliding window
        enable_camera_sliding_window: bool = False,
        # 3D RoPE
        enable_3d_rope: bool = False,
        # Context Parallelism (kept for checkpoint compatibility but not used)
        enable_ulysses_cp: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        pred_normalization_detach_scale: bool = False,
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
    ):
        """Store configuration and build the aggregator and prediction heads.

        The aggregator and camera head are created through the abstract
        factory methods, so subclass attributes they read must be set before
        super().__init__() is called by the subclass.
        """
        super().__init__()
        # Store configuration
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embed = patch_embed
        self.disable_global_rope = disable_global_rope
        # NOTE: the enable_ulysses_cp argument is deliberately ignored here.
        self.enable_ulysses_cp = False # CP disabled in standalone package
        self.enable_normalize = enable_normalize
        self.pred_normalization = pred_normalization
        self.pred_normalization_detach_scale = pred_normalization_detach_scale
        self.use_gradient_checkpoint = use_gradient_checkpoint
        # Head flags
        self.enable_camera = enable_camera
        self.enable_point = enable_point
        self.enable_local_point = enable_local_point
        self.enable_depth = enable_depth
        self.enable_track = enable_track
        self.enable_camera_sliding_window = enable_camera_sliding_window
        self.enable_3d_rope = enable_3d_rope
        # Build aggregator (subclass-specific)
        self.aggregator = self._build_aggregator()
        # Build prediction heads (subclass-specific); disabled heads are None.
        self.camera_head = self._build_camera_head() if enable_camera else None
        self.point_head = self._build_point_head() if enable_point else None
        self.local_point_head = self._build_local_point_head() if enable_local_point else None
        self.depth_head = self._build_depth_head() if enable_depth else None
    @abstractmethod
    def _build_aggregator(self) -> nn.Module:
        """Create the mode-specific feature aggregator."""
        pass
    @abstractmethod
    def _build_camera_head(self) -> nn.Module:
        """Create the mode-specific camera pose head."""
        pass
    def _build_depth_head(self) -> nn.Module:
        """DPT head predicting depth (+confidence): 2 output channels."""
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=2,
            activation="exp",
            conf_activation="expp1"
        )
    def _build_point_head(self) -> nn.Module:
        """DPT head predicting world-frame 3D points (+confidence): 4 channels."""
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1"
        )
    def _build_local_point_head(self) -> nn.Module:
        """DPT head predicting camera-frame 3D points (+confidence): 4 channels."""
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1"
        )
    def _normalize_input(self, images: torch.Tensor, query_points=None):
        """Ensure a leading batch dimension on images [S,3,H,W]->[1,S,3,H,W]
        and on query points [N,2]->[1,N,2]."""
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)
        return images, query_points
    @abstractmethod
    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        view_graphs: Optional[torch.Tensor] = None,
        causal_graphs: Optional[Union[torch.Tensor, List[np.ndarray]]] = None,
        ordered_video: Optional[torch.Tensor] = None,
        is_cp_sliced: bool = False,
    ) -> tuple:
        """Run the aggregator; returns (aggregated_tokens_list, patch_start_idx)."""
        pass
    def _predict_camera(
        self,
        aggregated_tokens_list: list,
        mask: Optional[torch.Tensor] = None,
        causal_inference: bool = False,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Run the camera head; returns pose_enc (last iteration) and the
        full pose_enc_list, or {} when the head is disabled."""
        if self.camera_head is None:
            return {}
        # Heads run in full fp32 with autocast disabled for numerical stability.
        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        # Sliding window only applies when explicitly enabled; -1 disables it.
        camera_sliding_window = sliding_window_size if self.enable_camera_sliding_window else -1
        with torch.amp.autocast('cuda', enabled=False):
            pose_enc_list = self.camera_head(
                aggregated_tokens_list_fp32,
                mask=mask,
                causal_inference=causal_inference,
                num_frame_for_scale=num_frame_for_scale if num_frame_for_scale is not None else -1,
                sliding_window_size=camera_sliding_window,
                num_frame_per_block=num_frame_per_block,
            )
        return {
            "pose_enc": pose_enc_list[-1],
            "pose_enc_list": pose_enc_list,
        }
    def _predict_depth(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Run the depth head in fp32; returns depth and depth_conf, or {}."""
        if self.depth_head is None:
            return {}
        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        images_fp32 = images.float()
        with torch.amp.autocast('cuda', enabled=False):
            depth, depth_conf = self.depth_head(
                aggregated_tokens_list_fp32,
                images=images_fp32,
                patch_start_idx=patch_start_idx
            )
        return {"depth": depth, "depth_conf": depth_conf}
    def _predict_points(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Run the world-point head in fp32; returns world_points and
        world_points_conf, or {}."""
        if self.point_head is None:
            return {}
        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        images_fp32 = images.float()
        with torch.amp.autocast('cuda', enabled=False):
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list_fp32,
                images=images_fp32,
                patch_start_idx=patch_start_idx
            )
        return {"world_points": pts3d, "world_points_conf": pts3d_conf}
    def _predict_local_points(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Run the camera-frame point head in fp32; returns cam_points and
        cam_points_conf, or {}."""
        if self.local_point_head is None:
            return {}
        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        images_fp32 = images.float()
        with torch.amp.autocast('cuda', enabled=False):
            pts3d, pts3d_conf = self.local_point_head(
                aggregated_tokens_list_fp32,
                images=images_fp32,
                patch_start_idx=patch_start_idx
            )
        return {"cam_points": pts3d, "cam_points_conf": pts3d_conf}
    def _unproject_depth_to_world(
        self,
        depth: torch.Tensor,
        pose_enc: torch.Tensor,
    ) -> torch.Tensor:
        """Unproject depth maps to world-frame 3D points via a pinhole model.

        Args:
            depth: depth maps [B, S, H, W, 1].
            pose_enc: camera pose encoding convertible to extrinsics/intrinsics.

        Returns:
            World-frame points [B, S, H, W, 3].
        """
        B, S, H, W, _ = depth.shape
        device = depth.device
        dtype = depth.dtype
        image_size_hw = (H, W)
        extrinsics, intrinsics = pose_encoding_to_extri_intri(
            pose_enc, image_size_hw=image_size_hw, build_intrinsics=True
        )
        # Promote the 3x4 extrinsics to homogeneous 4x4 and invert to get
        # camera-to-world transforms.
        extrinsics_flat = extrinsics.view(B * S, 3, 4)
        extrinsics_4x4 = torch.zeros(B * S, 4, 4, device=device, dtype=dtype)
        extrinsics_4x4[:, :3, :] = extrinsics_flat
        extrinsics_4x4[:, 3, 3] = 1.0
        c2w = closed_form_inverse_se3(extrinsics_4x4).view(B, S, 4, 4)
        # Homogeneous pixel grid (x, y, 1) for every pixel.
        y_grid, x_grid = torch.meshgrid(
            torch.arange(H, device=device, dtype=dtype),
            torch.arange(W, device=device, dtype=dtype),
            indexing='ij'
        )
        pixel_coords = torch.stack([x_grid, y_grid, torch.ones_like(x_grid)], dim=-1)
        # Back-project rays with K^{-1}, scale by depth, then lift to world.
        intrinsics_inv = torch.inverse(intrinsics)
        camera_coords = torch.einsum('bsij,hwj->bshwi', intrinsics_inv, pixel_coords)
        camera_points = camera_coords * depth
        ones = torch.ones_like(camera_points[..., :1])
        camera_points_h = torch.cat([camera_points, ones], dim=-1)
        world_points_h = torch.einsum('bsij,bshwj->bshwi', c2w, camera_points_h)
        return world_points_h[..., :3]
    def forward(
        self,
        images: torch.Tensor,
        query_points: Optional[torch.Tensor] = None,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        mask: Optional[torch.Tensor] = None,
        causal_inference: bool = False,
        ordered_video: Optional[torch.Tensor] = None,
        gather_outputs: bool = True,
        point_masks: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the GCT model.
        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
            query_points: Optional query points [N, 2] or [B, N, 2]
        Returns:
            Dictionary containing predictions:
            - pose_enc: Camera pose encoding [B, S, 9]
            - depth: Depth maps [B, S, H, W, 1]
            - depth_conf: Depth confidence [B, S, H, W]
            - world_points: 3D world coordinates [B, S, H, W, 3]
            - world_points_conf: Point confidence [B, S, H, W]
        Note:
            mask, point_masks and extra kwargs are accepted for API
            compatibility but are not used in this base implementation;
            ordered_video is forwarded to the camera head as its mask.
        """
        images, query_points = self._normalize_input(images, query_points)
        aggregated_tokens_list, patch_start_idx = self._aggregate_features(
            images,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )
        predictions = {}
        predictions.update(self._predict_camera(
            aggregated_tokens_list,
            mask=ordered_video,
            causal_inference=causal_inference,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
            gather_outputs=gather_outputs,
        ))
        predictions.update(self._predict_depth(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))
        predictions.update(self._predict_points(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))
        predictions.update(self._predict_local_points(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))
        # Keep the (batched) inputs around for downstream visualization.
        if not self.training:
            predictions["images"] = images
        return predictions

View File

@@ -0,0 +1,444 @@
"""
GCTStream - Streaming GCT with KV cache for online inference.
Provides streaming inference functionality:
- Temporal causal attention with KV cache
- Sliding window support
- Efficient frame-by-frame processing
- 3D RoPE support for temporal consistency
"""
import logging
import torch
import torch.nn as nn
from typing import Optional, Dict, Any, List
from tqdm.auto import tqdm
from lingbot_map.heads.camera_head import CameraCausalHead
from lingbot_map.models.gct_base import GCTBase
from lingbot_map.aggregator.stream import AggregatorStream
logger = logging.getLogger(__name__)
class GCTStream(GCTBase):
"""
Streaming GCT model with KV cache for efficient online inference.
Features:
- AggregatorStream with KV cache support (FlashInfer backend)
- CameraCausalHead for pose refinement
- Sliding window attention for memory efficiency
- Frame-by-frame streaming inference
"""
    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        pretrained_path: str = '',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        # Stream-specific parameters
        sliding_window_size: int = -1,
        num_frame_for_scale: int = 1,
        num_random_frames: int = 0,
        attend_to_special_tokens: bool = False,
        attend_to_scale_frames: bool = False,
        enable_stream_inference: bool = True, # Default to True for streaming
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        # Camera head 3D RoPE (separate from aggregator 3D RoPE)
        enable_camera_3d_rope: bool = False,
        camera_rope_theta: float = 10000.0,
        # Scale token configuration (kept for checkpoint compat, ignored)
        use_scale_token: bool = True,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # Backend selection
        use_sdpa: bool = False, # If True, use SDPA (no flashinfer needed); default: FlashInfer
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
    ):
        """
        Initialize GCTStream.
        Args:
            img_size: Input image size
            patch_size: Patch size for embedding
            embed_dim: Embedding dimension
            patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.)
            pretrained_path: Path to pretrained DINOv2 weights
            disable_global_rope: Disable RoPE in global attention
            enable_camera/point/depth/track: Enable prediction heads
            enable_normalize: Enable normalization
            sliding_window_size: Sliding window size in blocks (-1 for full causal)
            num_frame_for_scale: Number of scale estimation frames
            num_random_frames: Number of random frames for long-range dependencies
            attend_to_special_tokens: Enable cross-frame special token attention
            attend_to_scale_frames: Whether to attend to scale frames
            enable_stream_inference: Enable streaming inference with KV cache
            enable_3d_rope: Enable 3D RoPE for temporal consistency
            max_frame_num: Maximum number of frames for 3D RoPE
            enable_camera_3d_rope: Enable 3D RoPE inside the camera head only
            camera_rope_theta: RoPE base frequency for the camera head
            use_scale_token: Kept for checkpoint compatibility, ignored
            kv_cache_sliding_window: Sliding window size for KV cache eviction
            kv_cache_scale_frames: Number of scale frames to keep in KV cache
            kv_cache_cross_frame_special: Keep special tokens from evicted frames
            kv_cache_include_scale_frames: Include scale frames in KV cache
            kv_cache_camera_only: Only keep camera tokens from evicted frames
            use_sdpa: Use the SDPA attention backend instead of FlashInfer
            use_gradient_checkpoint: Enable gradient checkpointing in the aggregator
        """
        # Store stream-specific parameters before calling super().__init__()
        # (the base __init__ calls _build_aggregator/_build_camera_head, which
        # read these attributes).
        self.pretrained_path = pretrained_path
        self.sliding_window_size = sliding_window_size
        self.num_frame_for_scale = num_frame_for_scale
        self.num_random_frames = num_random_frames
        self.attend_to_special_tokens = attend_to_special_tokens
        self.attend_to_scale_frames = attend_to_scale_frames
        self.enable_stream_inference = enable_stream_inference
        self.enable_3d_rope = enable_3d_rope
        self.max_frame_num = max_frame_num
        # Camera head 3D RoPE settings
        self.enable_camera_3d_rope = enable_camera_3d_rope
        self.camera_rope_theta = camera_rope_theta
        # KV cache parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
        self.use_sdpa = use_sdpa
        # Call base class __init__ (will call _build_aggregator)
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            patch_embed=patch_embed,
            disable_global_rope=disable_global_rope,
            enable_camera=enable_camera,
            enable_point=enable_point,
            enable_local_point=enable_local_point,
            enable_depth=enable_depth,
            enable_track=enable_track,
            enable_normalize=enable_normalize,
            pred_normalization=pred_normalization,
            enable_3d_rope=enable_3d_rope,
            use_gradient_checkpoint=use_gradient_checkpoint,
        )
def _build_aggregator(self) -> nn.Module:
"""
Build streaming aggregator with KV cache support (FlashInfer backend).
Returns:
AggregatorStream module
"""
return AggregatorStream(
img_size=self.img_size,
patch_size=self.patch_size,
embed_dim=self.embed_dim,
patch_embed=self.patch_embed,
pretrained_path=self.pretrained_path,
disable_global_rope=self.disable_global_rope,
sliding_window_size=self.sliding_window_size,
num_frame_for_scale=self.num_frame_for_scale,
num_random_frames=self.num_random_frames,
attend_to_special_tokens=self.attend_to_special_tokens,
attend_to_scale_frames=self.attend_to_scale_frames,
enable_stream_inference=self.enable_stream_inference,
enable_3d_rope=self.enable_3d_rope,
max_frame_num=self.max_frame_num,
# Backend: FlashInfer (default) or SDPA (fallback)
use_flashinfer=not self.use_sdpa,
use_sdpa=self.use_sdpa,
kv_cache_sliding_window=self.kv_cache_sliding_window,
kv_cache_scale_frames=self.kv_cache_scale_frames,
kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
kv_cache_camera_only=self.kv_cache_camera_only,
use_gradient_checkpoint=self.use_gradient_checkpoint,
)
def _build_camera_head(self) -> nn.Module:
"""
Build causal camera head for streaming inference.
Returns:
CameraCausalHead module or None
"""
return CameraCausalHead(
dim_in=2 * self.embed_dim,
sliding_window_size=self.sliding_window_size,
attend_to_scale_frames=self.attend_to_scale_frames,
# KV cache parameters
kv_cache_sliding_window=self.kv_cache_sliding_window,
kv_cache_scale_frames=self.kv_cache_scale_frames,
kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
kv_cache_camera_only=self.kv_cache_camera_only,
# Camera head 3D RoPE parameters
enable_3d_rope=self.enable_camera_3d_rope,
max_frame_num=self.max_frame_num,
rope_theta=self.camera_rope_theta,
)
def _aggregate_features(
    self,
    images: torch.Tensor,
    num_frame_for_scale: Optional[int] = None,
    sliding_window_size: Optional[int] = None,
    num_frame_per_block: int = 1,
    selected_idx: tuple = (4, 11, 17, 23),
    **kwargs,
) -> tuple:
    """
    Run the aggregator to obtain multi-scale features.

    Args:
        images: Input images [B, S, 3, H, W].
        num_frame_for_scale: Number of frames used for scale estimation.
        sliding_window_size: Optional override of the sliding window size.
        num_frame_per_block: Number of frames processed per block.
        selected_idx: Indices of the aggregator layers whose outputs are
            returned. Previously hard-coded as [4, 11, 17, 23]; the default
            preserves that behavior while allowing callers to select other
            layers.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        (aggregated_tokens_list, patch_start_idx)
    """
    aggregated_tokens_list, patch_start_idx = self.aggregator(
        images,
        selected_idx=list(selected_idx),
        num_frame_for_scale=num_frame_for_scale,
        sliding_window_size=sliding_window_size,
        num_frame_per_block=num_frame_per_block,
    )
    return aggregated_tokens_list, patch_start_idx
def clean_kv_cache(self):
    """
    Clean the KV caches in the aggregator and camera head.

    Call this method when starting a new video sequence to clear
    cached key-value pairs from previous sequences.
    """
    if hasattr(self.aggregator, 'clean_kv_cache'):
        self.aggregator.clean_kv_cache()
    else:
        logger.warning("Aggregator does not support KV cache cleaning")
    # The camera head is optional (see _set_skip_append); skip when absent.
    if self.camera_head is None:
        return
    # Bug fix: previously this tested hasattr(camera_head, 'kv_cache') but
    # then called clean_kv_cache(), which raised AttributeError on heads
    # exposing a cache without a cleaning method. Test for the method we call.
    if hasattr(self.camera_head, 'clean_kv_cache'):
        self.camera_head.clean_kv_cache()
    else:
        logger.warning("Camera head does not support KV cache cleaning")
def _set_skip_append(self, skip: bool):
    """Set the _skip_append flag on every KV cache (aggregator + camera head).

    With skip=True, attention layers still attend to [cached_kv + current_kv]
    but do NOT store the current frame's KV in the cache. Used for
    non-keyframe frames in keyframe-based streaming inference.

    Args:
        skip: If True, subsequent forward passes leave the caches untouched.
    """
    agg_cache = getattr(self.aggregator, 'kv_cache', None)
    if agg_cache is not None:
        agg_cache["_skip_append"] = skip
    head = self.camera_head
    if head is None:
        return
    head_caches = getattr(head, 'kv_cache', None)
    if head_caches is not None:
        # The camera head keeps one cache dict per layer.
        for cache_dict in head_caches:
            cache_dict["_skip_append"] = skip
def get_kv_cache_info(self) -> Dict[str, Any]:
    """
    Get information about the current KV cache state.

    Returns:
        Dictionary with cache statistics:
            - num_cached_blocks: Number of blocks with cached KV
            - cache_memory_mb: Approximate memory usage in MB
    """
    if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None:
        return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}

    kv_cache = self.aggregator.kv_cache
    # Count key entries per block; '_special' keys are auxiliary, not blocks.
    num_cached = sum(
        1 for k in kv_cache.keys()
        if k.startswith('k_') and not k.endswith('_special')
    )

    # Measure actual bytes via element_size() instead of assuming bfloat16
    # (2 bytes/element): caches may hold fp16/fp32 tensors depending on the
    # attention backend.
    total_bytes = sum(
        v.numel() * v.element_size()
        for v in kv_cache.values()
        if v is not None and torch.is_tensor(v)
    )
    cache_memory_mb = total_bytes / (1024 * 1024)

    return {
        "num_cached_blocks": num_cached,
        "cache_memory_mb": round(cache_memory_mb, 2),
    }
@torch.no_grad()
def inference_streaming(
    self,
    images: torch.Tensor,
    num_scale_frames: Optional[int] = None,
    keyframe_interval: int = 1,
    output_device: Optional[torch.device] = None,
) -> Dict[str, torch.Tensor]:
    """
    Streaming inference: process scale frames first, then frame-by-frame.

    This method enables efficient online inference by:
    1. Processing initial scale frames together (bidirectional attention via scale token)
    2. Processing remaining frames one-by-one with KV cache (causal streaming)

    Keyframe mode (keyframe_interval > 1):
        - Every keyframe_interval-th frame (after scale frames) is a keyframe
        - Keyframes: KV is stored in cache (normal behavior)
        - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard)
        - All frames produce full predictions regardless of keyframe status
        - Reduces KV cache memory growth by ~1/keyframe_interval

    Args:
        images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
        num_scale_frames: Number of initial frames for scale estimation.
                          If None, uses self.num_frame_for_scale.
        keyframe_interval: Every N-th frame (after scale frames) is a keyframe
                           whose KV persists in cache. 1 = every frame is a
                           keyframe (default, same as original behavior).
        output_device: Device to store output predictions on. If None, keeps on
                       the same device as the model. Set to torch.device('cpu')
                       to offload predictions per-frame and avoid GPU OOM on
                       long sequences.

    Returns:
        Dictionary containing predictions for all frames:
            - pose_enc: [B, S, 9]
            - depth: [B, S, H, W, 1]
            - depth_conf: [B, S, H, W]
            - world_points: [B, S, H, W, 3]
            - world_points_conf: [B, S, H, W]
            - images: the (possibly offloaded) input images, for visualization
    """
    # Normalize input to batched form [B, S, 3, H, W]
    if len(images.shape) == 4:
        images = images.unsqueeze(0)
    B, S, C, H, W = images.shape

    # Determine number of scale frames
    scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale
    scale_frames = min(scale_frames, S)  # Cap to available frames

    # Helper to move a tensor to the output device (no-op when not offloading)
    def _to_out(t: torch.Tensor) -> torch.Tensor:
        if output_device is not None:
            return t.to(output_device)
        return t

    # Clean KV caches before starting a new sequence
    self.clean_kv_cache()

    # Phase 1: Process scale frames together.
    # These frames get bidirectional attention among themselves via scale token.
    logger.info(f'Processing {scale_frames} scale frames...')
    scale_images = images[:, :scale_frames]
    scale_output = self.forward(
        scale_images,
        num_frame_for_scale=scale_frames,
        num_frame_per_block=scale_frames,  # Process all scale frames as one block
        causal_inference=True,
    )

    # Initialize output lists with scale-frame predictions (offload if needed).
    # Optional heads are only collected when the first forward produced them.
    all_pose_enc = [_to_out(scale_output["pose_enc"])]
    all_depth = [_to_out(scale_output["depth"])] if "depth" in scale_output else []
    all_depth_conf = [_to_out(scale_output["depth_conf"])] if "depth_conf" in scale_output else []
    all_world_points = [_to_out(scale_output["world_points"])] if "world_points" in scale_output else []
    all_world_points_conf = [_to_out(scale_output["world_points_conf"])] if "world_points_conf" in scale_output else []
    del scale_output  # release GPU memory eagerly

    # Phase 2: Process remaining frames one-by-one (causal, KV-cached)
    pbar = tqdm(
        range(scale_frames, S),
        desc='Streaming inference',
        initial=scale_frames,
        total=S,
    )
    for i in pbar:
        frame_image = images[:, i:i+1]

        # Determine if this frame is a keyframe; non-keyframes attend to the
        # cache but their own KV is discarded after the forward pass.
        is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0)
        if not is_keyframe:
            self._set_skip_append(True)

        frame_output = self.forward(
            frame_image,
            num_frame_for_scale=scale_frames,  # Keep same for scale token logic
            num_frame_per_block=1,  # Single frame per block
            causal_inference=True,
        )

        if not is_keyframe:
            self._set_skip_append(False)  # restore normal caching for later frames

        all_pose_enc.append(_to_out(frame_output["pose_enc"]))
        if "depth" in frame_output:
            all_depth.append(_to_out(frame_output["depth"]))
        if "depth_conf" in frame_output:
            all_depth_conf.append(_to_out(frame_output["depth_conf"]))
        if "world_points" in frame_output:
            all_world_points.append(_to_out(frame_output["world_points"]))
        if "world_points_conf" in frame_output:
            all_world_points_conf.append(_to_out(frame_output["world_points_conf"]))
        del frame_output

    # Free GPU memory before concatenation
    if output_device is not None:
        # Move images to output device, then free GPU copy
        images_out = _to_out(images)
        del images
        # Clean KV cache (no longer needed after inference)
        self.clean_kv_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    else:
        images_out = images

    # Concatenate all predictions along the sequence dimension; each list is
    # deleted right after use to keep peak memory low on long sequences.
    predictions = {
        "pose_enc": torch.cat(all_pose_enc, dim=1),
    }
    del all_pose_enc
    if all_depth:
        predictions["depth"] = torch.cat(all_depth, dim=1)
        del all_depth
    if all_depth_conf:
        predictions["depth_conf"] = torch.cat(all_depth_conf, dim=1)
        del all_depth_conf
    if all_world_points:
        predictions["world_points"] = torch.cat(all_world_points, dim=1)
        del all_world_points
    if all_world_points_conf:
        predictions["world_points_conf"] = torch.cat(all_world_points_conf, dim=1)
        del all_world_points_conf

    # Store images for visualization
    predictions["images"] = images_out

    # Apply prediction normalization if enabled
    if self.pred_normalization:
        predictions = self._normalize_predictions(predictions)

    return predictions

View File

View File

@@ -0,0 +1,774 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import numpy as np
from scipy.spatial.transform import Rotation as R
from scipy.spatial.transform import Rotation
try:
from lietorch import SE3, Sim3
except ImportError:
SE3 = Sim3 = None
import torch.nn.functional as F
try:
from lingbot_map.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion
except ImportError:
apply_distortion = iterative_undistortion = single_undistortion = None
def unproject_depth_map_to_point_map(
    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
) -> np.ndarray:
    """
    Unproject a batch of depth maps to 3D world coordinates.

    Args:
        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)

    Returns:
        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
    """
    # Accept torch tensors transparently by converting to numpy first.
    def _as_numpy(arr):
        return arr.cpu().numpy() if isinstance(arr, torch.Tensor) else arr

    depth_map = _as_numpy(depth_map)
    extrinsics_cam = _as_numpy(extrinsics_cam)
    intrinsics_cam = _as_numpy(intrinsics_cam)

    per_frame_points = []
    for frame_idx in range(depth_map.shape[0]):
        world_pts, _, _ = depth_to_world_coords_points(
            depth_map[frame_idx].squeeze(-1),
            extrinsics_cam[frame_idx],
            intrinsics_cam[frame_idx],
        )
        per_frame_points.append(world_pts)

    return np.stack(per_frame_points, axis=0)
def depth_to_world_coords_points(
    depth_map: np.ndarray,
    extrinsic: np.ndarray,
    intrinsic: np.ndarray,
    eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert a depth map to world coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4),
            OpenCV convention (cam from world).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
        eps: Depth values <= eps are flagged as invalid in the mask.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3),
            camera coordinates (H, W, 3), and valid-depth mask (H, W).
    """
    if depth_map is None:
        return None, None, None

    valid_mask = depth_map > eps

    # Lift pixels into the camera frame first.
    cam_points = depth_to_cam_coords_points(depth_map, intrinsic)

    # closed_form_inverse_se3 is batched (N, 4, 4); wrap/unwrap a singleton batch
    # to obtain the world-from-cam transform.
    cam_to_world = closed_form_inverse_se3(extrinsic[None])[0]
    rot_cw = cam_to_world[:3, :3]
    trans_cw = cam_to_world[:3, 3]

    # (H, W, 3) @ (3, 3)^T + (3,) -> (H, W, 3)
    world_points = cam_points @ rot_cw.T + trans_cw
    return world_points, cam_points, valid_mask
def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
    """
    Convert a depth map to camera-frame 3D coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Zero-skew camera intrinsic matrix of shape (3, 3).

    Returns:
        np.ndarray: Camera coordinates of shape (H, W, 3), float32.
            (Return annotation fixed: a single array has always been returned,
            not the tuple the previous annotation declared.)

    Raises:
        AssertionError: If the intrinsic matrix is not 3x3 or has non-zero skew.
    """
    H, W = depth_map.shape
    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"

    # Intrinsic parameters
    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
    cu, cv = intrinsic[0, 2], intrinsic[1, 2]

    # Pixel grid: u indexes columns (x direction), v indexes rows (y direction)
    u, v = np.meshgrid(np.arange(W), np.arange(H))

    # Unproject through the pinhole model
    x_cam = (u - cu) * depth_map / fu
    y_cam = (v - cv) * depth_map / fv
    z_cam = depth_map

    return np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    Uses the closed form inv([R t; 0 1]) = [R^T, -R^T t; 0 1], avoiding a
    general matrix inverse.

    If `R` and `T` are provided, they must correspond to the rotation and
    translation components of `se3`; otherwise they are extracted from it.

    Args:
        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        R (optional): Nx3x3 array or tensor of rotation matrices.
        T (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Nx4x4 inverted SE3 matrices with the same type, dtype, and device as `se3`.

    Raises:
        ValueError: If `se3` does not have trailing shape (4, 4) or (3, 4).
    """
    # Check if se3 is a numpy array or a torch tensor
    is_numpy = isinstance(se3, np.ndarray)

    # Validate shapes (message now also mentions the accepted (N,3,4) form)
    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
        raise ValueError(f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}.")

    # Extract R and T if not provided
    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    if is_numpy:
        R_transposed = np.transpose(R, (0, 2, 1))
        top_right = -np.matmul(R_transposed, T)  # -R^T t
        # Match the input dtype (previously always promoted to float64 via the
        # np.eye default), keeping the numpy and torch branches consistent.
        inverted_matrix = np.tile(np.eye(4, dtype=R.dtype), (len(R), 1, 1))
    else:
        R_transposed = R.transpose(1, 2)  # (N,3,3)
        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
        # Allocate directly with the right dtype/device instead of building a
        # CPU float32 identity and converting afterwards.
        inverted_matrix = torch.eye(4, dtype=R.dtype, device=R.device)[None].repeat(len(R), 1, 1)

    inverted_matrix[:, :3, :3] = R_transposed
    inverted_matrix[:, :3, 3:] = top_right
    return inverted_matrix
def closed_form_inverse_se3_general(se3, R=None, T=None):
    """
    Closed-form SE3 inverse supporting arbitrary batch dimensions.

    Args:
        se3: (..., 4, 4) or (..., 3, 4) torch tensor of SE3 matrices.
        R (optional): (..., 3, 3) rotation component of `se3`.
        T (optional): (..., 3, 1) translation component of `se3`.

    Returns:
        (..., 4, 4) tensor of inverted SE3 matrices.
    """
    leading_dims = se3.shape[:-2]
    if R is None:
        R = se3[..., :3, :3]
    if T is None:
        T = se3[..., :3, 3:]

    # inv([R t; 0 1]) = [R^T, -R^T t; 0 1]
    rot_inv = R.transpose(-2, -1)
    trans_inv = -rot_inv @ T

    # Start from a broadcast identity, then fill in the inverse blocks.
    identity = torch.eye(4, 4, dtype=R.dtype, device=R.device)
    result = identity.expand(*leading_dims, 4, 4).clone()
    result[..., :3, :3] = rot_inv
    result[..., :3, 3:] = trans_inv
    return result
# TODO: this code can be further cleaned up
def project_world_points_to_camera_points_batch(world_points, cam_extrinsics):
    """
    Transform world-frame 3D points into camera-frame 3D points.

    Args:
        world_points (torch.Tensor): 3D points of shape (B, S, H, W, 3).
        cam_extrinsics (torch.Tensor): Extrinsic parameters of shape (B, S, 3, 4).

    Returns:
        torch.Tensor: Camera-frame points of shape (B, S, H, W, 3).
    """
    # TODO: merge this into project_world_points_to_cam
    # Homogenize: (B, S, H, W, 3) -> (B, S, H, W, 4)
    homo_points = torch.cat(
        [world_points, torch.ones_like(world_points[..., :1])], dim=-1
    )
    # Extrinsics: (B, S, 3, 4) -> (B, S, 1, 1, 3, 4) so they broadcast over H, W
    ext = cam_extrinsics.unsqueeze(2).unsqueeze(3)
    # (B, S, 1, 1, 3, 4) @ (B, S, H, W, 4, 1) -> (B, S, H, W, 3, 1) -> (..., 3)
    return torch.matmul(ext, homo_points.unsqueeze(-1)).squeeze(-1)
def project_world_points_to_cam(
    world_points,
    cam_extrinsics,
    cam_intrinsics=None,
    distortion_params=None,
    default=0,
    only_points_cam=False,
):
    """
    Project world-frame 3D points into camera frame and (optionally) pixels.

    Args:
        world_points (torch.Tensor): 3D points of shape (P, 3).
        cam_extrinsics (torch.Tensor): Extrinsic parameters of shape (B, 3, 4).
        cam_intrinsics (torch.Tensor): Intrinsic parameters of shape (B, 3, 3);
            required unless only_points_cam=True.
        distortion_params (torch.Tensor): Optional radial-distortion parameters
            of shape (B, N).
        default: Value substituted for NaNs in the pixel output.
        only_points_cam: If True, skip the intrinsic projection and return
            (None, cam_points).

    Returns:
        (image_points, cam_points): pixel coordinates (B, P, 2) (or None when
        only_points_cam=True) and camera-frame points (B, 3, P).
    """
    device = world_points.device
    # Run in full precision: the projection is numerically sensitive under autocast.
    with torch.autocast(device_type=device.type, enabled=False):
        num_cams = cam_extrinsics.shape[0]

        # (P, 3) -> (P, 4) homogeneous, then broadcast over all cameras: (B, P, 4)
        homo = torch.cat(
            [world_points, torch.ones_like(world_points[..., 0:1])], dim=1
        )
        homo = homo.unsqueeze(0).expand(num_cams, -1, -1)

        # Step 1: world -> camera frame for all cameras at once: (B, 3, P)
        cam_points = torch.bmm(cam_extrinsics, homo.transpose(-1, -2))

        if only_points_cam:
            return None, cam_points

        # Step 2: apply intrinsics and (optional) distortion
        image_points = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default)
        return image_points, cam_points
def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0):
    """
    Apply intrinsics (and optional distortion) to camera-frame 3D points.

    Args:
        cam_intrinsics (torch.Tensor): Intrinsic parameters of shape (B, 3, 3).
        cam_points (torch.Tensor): Camera-frame points of shape (B, 3, N).
        distortion_params (torch.Tensor, optional): Distortion parameters of
            shape (B, k), where k can be 1, 2, or 4.
        default (float, optional): Replacement value for NaNs in the output.

    Returns:
        torch.Tensor: Pixel coordinates of shape (B, N, 2).
    """
    # Perspective divide -> normalized image-plane coordinates
    normalized = cam_points / cam_points[:, 2:3, :]
    xy = normalized[:, :2, :]

    # Optionally warp the normalized coordinates with the distortion model
    if distortion_params is not None:
        x_d, y_d = apply_distortion(distortion_params, xy[:, 0], xy[:, 1])
        xy = torch.stack([x_d, y_d], dim=1)

    # Back to homogeneous so the 3x3 intrinsics apply in a single bmm: (B, 3, N)
    homo = torch.cat((xy, torch.ones_like(xy[:, :1, :])), dim=1)
    pixels = torch.bmm(cam_intrinsics, homo)[:, :2, :]  # (B, 2, N)

    # Replace NaNs (e.g. from z == 0 points) with the default value
    pixels = torch.nan_to_num(pixels, nan=default)
    return pixels.transpose(1, 2)  # (B, N, 2)
def cam_from_img(pred_tracks, intrinsics, extra_params=None):
    """
    Normalize predicted 2D tracks into camera coordinates via the intrinsics.

    Args:
        pred_tracks (torch.Tensor): Predicted tracks of shape
            [batch_size, num_tracks, 2], in pixels.
        intrinsics (torch.Tensor): Camera intrinsics of shape [batch_size, 3, 3].
        extra_params (torch.Tensor, optional): Distortion parameters of shape
            BxN, where N can be 1, 2, or 4.

    Returns:
        torch.Tensor: Normalized tracks tensor of shape [batch_size, num_tracks, 2].
    """
    # We deliberately avoid intrinsics_inv = torch.inverse(intrinsics):
    # for a zero-skew pinhole matrix, the inverse is just a shift by the
    # principal point followed by a division by the focal lengths.
    principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2)
    focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2)
    tracks_normalized = (pred_tracks - principal_point) / focal_length

    if extra_params is not None:
        # Undo lens distortion; fall back to the cheaper single-step variant if
        # the iterative solver fails. Bug fix: this was a bare `except:`, which
        # also swallowed KeyboardInterrupt/SystemExit.
        try:
            tracks_normalized = iterative_undistortion(
                extra_params, tracks_normalized
            )
        except Exception:
            tracks_normalized = single_undistortion(
                extra_params, tracks_normalized
            )
    return tracks_normalized
## Droid SLAM Part
# Points closer than this depth are treated as degenerate by the projection code.
MIN_DEPTH = 0.2

def extract_intrinsics(intrinsics):
    """Split an intrinsics tensor (..., 4) into broadcastable fx, fy, cx, cy.

    Two singleton dims are inserted before the last axis so each component
    broadcasts against (..., H, W) image grids.
    """
    expanded = intrinsics[..., None, None, :]
    return expanded.unbind(dim=-1)
def projective_transform(
    poses, depths, intrinsics, ii, jj, jacobian=False, return_depth=False
):
    """Map pixels of frames ii into frames jj: inverse-project, transform, project.

    Args:
        poses: lietorch group elements (SE3/Sim3), indexed per frame along dim 1.
        depths: per-frame disparity/depth maps, indexed as depths[:, ii].
        intrinsics: per-frame pinhole intrinsics (fx, fy, cx, cy).
        ii, jj: index tensors selecting source and target frames for each edge.
        jacobian: if True, also return Jacobians (Ji, Jj, Jz) for optimization.
        return_depth: forwarded to proj() to append depth to the projected coords.

    Returns:
        (x1, valid), plus (Ji, Jj, Jz) when jacobian=True. `valid` masks points
        that lie in front of both cameras (depth > MIN_DEPTH).
    """
    # Inverse project (pinhole): lift pixels of frames ii to homogeneous points.
    X0, Jz = iproj(depths[:, ii], intrinsics[:, ii], jacobian=jacobian)

    # Relative transform taking frame ii into frame jj.
    Gij = poses[:, jj] * poses[:, ii].inv()

    # (kept from upstream, intentionally disabled)
    # Gij.data[:, ii == jj] = torch.as_tensor(
    #     [-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device="cuda"
    # )
    X1, Ja = actp(Gij, X0, jacobian=jacobian)

    # Project (pinhole) into frame jj.
    x1, Jp = proj(X1, intrinsics[:, jj], jacobian=jacobian, return_depth=return_depth)

    # Exclude points too close to either camera.
    valid = ((X1[..., 2] > MIN_DEPTH) & (X0[..., 2] > MIN_DEPTH)).float()
    valid = valid.unsqueeze(-1)

    if jacobian:
        # Ji transforms according to dual adjoint
        Jj = torch.matmul(Jp, Ja)
        Ji = -Gij[:, :, None, None, None].adjT(Jj)

        Jz = Gij[:, :, None, None] * Jz
        Jz = torch.matmul(Jp, Jz.unsqueeze(-1))

        return x1, valid, (Ji, Jj, Jz)

    return x1, valid
def induced_flow(poses, disps, intrinsics, ii, jj):
    """Optical flow induced purely by camera motion between frames ii and jj.

    Returns:
        (flow, valid): flow is the displacement of each pixel of frames ii when
        reprojected into frames jj; valid masks points in front of both cameras.
    """
    ht, wd = disps.shape[2:]
    grid_y, grid_x = torch.meshgrid(
        torch.arange(ht, device=disps.device, dtype=torch.float),
        torch.arange(wd, device=disps.device, dtype=torch.float),
        indexing="ij",
    )
    # Source pixel coordinates as (x, y) pairs.
    base_coords = torch.stack([grid_x, grid_y], dim=-1)

    target_coords, valid = projective_transform(poses, disps, intrinsics, ii, jj, False)
    return target_coords[..., :2] - base_coords, valid
def all_pairs_distance_matrix(poses, beta=2.5):
    """Compute the pairwise geodesic distance matrix between poses.

    Args:
        poses: array-like of 7-vectors (t, q); translations are scaled by
            `beta` to balance rotation vs. translation magnitudes.
        beta: translation weighting factor.

    Returns:
        (N, N) numpy array of pose distances.
    """
    pose_arr = np.array(poses, dtype=np.float32)
    pose_arr[:, :3] *= beta  # scale to balance rot + trans
    group = SE3(torch.from_numpy(pose_arr))

    # Relative transform between every ordered pair, then its log-norm.
    rel = (group[:, None].inv() * group[None, :]).log()
    return rel.norm(dim=-1).cpu().numpy()
def pose_matrix_to_quaternion(pose):
    """Convert a 4x4 pose matrix to a (t, q) 7-vector: (tx, ty, tz, qx, qy, qz, qw)."""
    translation = pose[..., :3, 3]
    quat = Rotation.from_matrix(pose[..., :3, :3]).as_quat()
    return np.concatenate([translation, quat], axis=-1)
def compute_distance_matrix_flow(poses, disps, intrinsics):
    """Compute the motion-induced flow magnitude between all pairs of frames.

    Requires CUDA: numpy inputs are moved to .cuda() and the pair indices are
    always placed on the GPU.

    Args:
        poses: numpy pose vectors (inverted into a lietorch SE3 batch below)
            or an already-constructed SE3 batch.
        disps: numpy disparity maps, one per frame.
        intrinsics: numpy per-frame intrinsics (fx, fy, cx, cy).

    Returns:
        (N, N) numpy matrix of mean flow magnitudes; pairs with fewer than
        70% valid pixels are set to +inf.
    """
    if not isinstance(poses, SE3):
        # NOTE(review): poses are inverted when wrapped here — confirm this
        # matches the caller's pose convention (cam-from-world vs world-from-cam).
        poses = torch.from_numpy(poses).float().cuda()[None]
        poses = SE3(poses).inv()

        disps = torch.from_numpy(disps).float().cuda()[None]
        intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]

    N = poses.shape[1]

    # All ordered frame pairs (i, j); meshgrid defaults to 'ij' indexing here.
    ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
    ii = ii.reshape(-1).cuda()
    jj = jj.reshape(-1).cuda()

    MAX_FLOW = 100.0
    matrix = np.zeros((N, N), dtype=np.float32)

    s = 2048  # pair-chunk size, bounds peak GPU memory
    for i in range(0, ii.shape[0], s):
        # Forward (i->j) and backward (j->i) induced flow for this chunk.
        flow1, val1 = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
        flow2, val2 = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s])

        flow = torch.stack([flow1, flow2], dim=2)
        val = torch.stack([val1, val2], dim=2)

        # Mean flow magnitude over valid pixels only, clamped to MAX_FLOW.
        mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
        mag = mag.view(mag.shape[1], -1)
        val = val.view(val.shape[1], -1)

        mag = (mag * val).mean(-1) / val.mean(-1)
        # Reject pairs with too little covisibility.
        mag[val.mean(-1) < 0.7] = np.inf

        i1 = ii[i:i+s].cpu().numpy()
        j1 = jj[i:i+s].cpu().numpy()
        matrix[i1, j1] = mag.cpu().numpy()

    return matrix
def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4):
    """Compute flow magnitude between all pairs of frames (weighted variant:
    translation-only flow plus beta * full flow).

    NOTE(review): this function appears unfinished/unused as written:
      - it calls induced_flow(..., tonly=True), but induced_flow (defined
        above) does not accept a `tonly` keyword, so these calls would raise
        TypeError at runtime;
      - flow2b is computed with (ii, jj) while the symmetric counterpart of
        flow2a suggests (jj, ii) — looks like a copy-paste slip;
      - val1 multiplies val1a * val2b; val1a * val1b seems intended.
    Confirm intent before relying on this function.
    """
    # (input conversion disabled upstream; inputs are expected pre-converted)
    # if not isinstance(poses, SE3):
    #     poses = torch.from_numpy(poses).float().cuda()[None]
    #     poses = SE3(poses).inv()
    #     disps = torch.from_numpy(disps).float().cuda()[None]
    #     intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]

    N = poses.shape[1]

    # All ordered frame pairs (i, j).
    ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
    ii = ii.reshape(-1)
    jj = jj.reshape(-1)

    MAX_FLOW = 128.0
    matrix = np.zeros((N, N), dtype=np.float32)

    s = 2048  # pair-chunk size
    for i in range(0, ii.shape[0], s):
        flow1a, val1a = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True)
        flow1b, val1b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
        flow2a, val2a = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True)
        flow2b, val2b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])

        # Combine translation-only and full flow per direction.
        flow1 = flow1a + beta * flow1b
        val1 = val1a * val2b

        flow2 = flow2a + beta * flow2b
        val2 = val2a * val2b

        flow = torch.stack([flow1, flow2], dim=2)
        val = torch.stack([val1, val2], dim=2)

        mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
        mag = mag.view(mag.shape[1], -1)
        val = val.view(val.shape[1], -1)

        mag = (mag * val).mean(-1) / val.mean(-1)
        mag[val.mean(-1) < 0.8] = np.inf

        i1 = ii[i:i+s].cpu().numpy()
        j1 = jj[i:i+s].cpu().numpy()
        matrix[i1, j1] = mag.cpu().numpy()

    return matrix
def coords_grid(ht, wd, **kwargs):
    """Return an (ht, wd, 2) grid of (x, y) pixel coordinates.

    Extra keyword arguments (e.g. device) are forwarded to torch.arange.
    """
    rows = torch.arange(ht, dtype=torch.float, **kwargs)
    cols = torch.arange(wd, dtype=torch.float, **kwargs)
    y, x = torch.meshgrid(rows, cols, indexing="ij")
    return torch.stack([x, y], dim=-1)
def iproj(disps, intrinsics, jacobian=False):
    """Pinhole camera inverse projection.

    Lifts a disparity map to homogeneous points [X, Y, 1, disp] in normalized
    camera coordinates.

    Args:
        disps: disparity maps of shape (B, N, H, W).
        intrinsics: per-frame (fx, fy, cx, cy), split via extract_intrinsics.
        jacobian: if True, also return d(pts)/d(disp), which is non-zero only
            in the disparity slot.

    Returns:
        (pts, J): pts has shape (..., 4); J is None unless jacobian=True.
    """
    ht, wd = disps.shape[2:]
    fx, fy, cx, cy = extract_intrinsics(intrinsics)

    y, x = torch.meshgrid(
        torch.arange(ht, device=disps.device, dtype=torch.float),
        torch.arange(wd, device=disps.device, dtype=torch.float),
        indexing="ij",
    )

    ones = torch.ones_like(disps)
    pts = torch.stack([(x - cx) / fx, (y - cy) / fy, ones, disps], dim=-1)

    if not jacobian:
        return pts, None

    # Only the disparity component varies with disps.
    J = torch.zeros_like(pts)
    J[..., -1] = 1.0
    return pts, J
def proj(Xs, intrinsics, jacobian=False, return_depth=False):
    """Pinhole camera projection of homogeneous points [X, Y, Z, D].

    Args:
        Xs: points of shape (B, N, H, W, 4), as produced by iproj/actp.
        intrinsics: per-frame (fx, fy, cx, cy).
        jacobian: if True, also return the 2x4 projection Jacobian per point.
        return_depth: if True, append D/Z as a third output channel.

    Returns:
        (coords, jac): coords is (..., 2) or (..., 3); jac is None unless
        jacobian=True.
    """
    fx, fy, cx, cy = extract_intrinsics(intrinsics)
    X, Y, Z, D = Xs.unbind(dim=-1)

    # Guard the divide: points at/behind the camera get Z replaced by 1 so the
    # division stays finite (the caller masks them out via its validity check).
    Z = torch.where(Z < 0.5 * MIN_DEPTH, torch.ones_like(Z), Z)
    d = 1.0 / Z

    x = fx * (X * d) + cx
    y = fy * (Y * d) + cy
    if return_depth:
        coords = torch.stack([x, y, D * d], dim=-1)
    else:
        coords = torch.stack([x, y], dim=-1)

    if jacobian:
        B, N, H, W = d.shape
        o = torch.zeros_like(d)
        # Rows: d(x)/d[X,Y,Z,D] and d(y)/d[X,Y,Z,D]; the depth row is
        # intentionally omitted (kept below as a comment from upstream).
        proj_jac = torch.stack(
            [
                fx * d,
                o,
                -fx * X * d * d,
                o,
                o,
                fy * d,
                -fy * Y * d * d,
                o,
                # o, o, -D*d*d, d,
            ],
            dim=-1,
        ).view(B, N, H, W, 2, 4)

        return coords, proj_jac

    return coords, None
def actp(Gij, X0, jacobian=False):
    """Apply the group action Gij to a homogeneous point cloud.

    Args:
        Gij: lietorch SE3 or Sim3 relative transforms, one per edge.
        X0: points of shape (B, N, H, W, 4) = [X, Y, Z, d].
        jacobian: if True, also return the Jacobian of the transformed points
            w.r.t. the group generators (6 columns for SE3, 7 for Sim3).

    Returns:
        (X1, Ja): X1 are the transformed points; Ja is None unless
        jacobian=True.
    """
    X1 = Gij[:, :, None, None] * X0

    if jacobian:
        # Jacobian entries are expressed in terms of the *transformed* point.
        X, Y, Z, d = X1.unbind(dim=-1)
        o = torch.zeros_like(d)
        B, N, H, W = d.shape

        if isinstance(Gij, SE3):
            # 4x6 Jacobian per point w.r.t. the se(3) generators
            # (3 translations, 3 rotations); the last point row is all zeros.
            Ja = torch.stack(
                [d, o, o, o, Z, -Y,
                 o, d, o, -Z, o, X,
                 o, o, d, Y, -X, o,
                 o, o, o, o, o, o],
                dim=-1,
            ).view(B, N, H, W, 4, 6)

        elif isinstance(Gij, Sim3):
            # 4x7 Jacobian per point w.r.t. the sim(3) generators
            # (3 translations, 3 rotations, 1 scale).
            Ja = torch.stack(
                [d, o, o, o, Z, -Y, X,
                 o, d, o, -Z, o, X, Y,
                 o, o, d, Y, -X, o, Z,
                 o, o, o, o, o, o, o],
                dim=-1,
            ).view(B, N, H, W, 4, 7)

        return X1, Ja

    return X1, None
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Numerically robust variant (as in the PyTorch3D implementation): four
    candidate quaternions are computed from the diagonal trace combinations,
    and per element the candidate with the largest magnitude component is
    selected.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: If the trailing two dimensions are not (3, 3).
    """
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # |2*q_w|, |2*q_x|, |2*q_y|, |2*q_z| candidates from the trace combinations.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # One candidate quaternion per row, each assuming a different dominant
    # component (w, x, y, z respectively).
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # Clamp the divisor away from zero to keep near-degenerate candidates finite.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # Select, per element, the candidate row with the largest |component|.
    out = quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))
    return standardize_quaternion(out)
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Normalize quaternions to unit length and flip sign so the real (first)
    part is non-negative.

    Args:
        quaternions: Quaternions with real part first, shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    unit = F.normalize(quaternions, p=2, dim=-1)
    needs_flip = unit[..., 0:1] < 0
    # q and -q represent the same rotation; pick the non-negative-real one.
    return torch.where(needs_flip, -unit, unit)
def umeyama(X, Y):
    """
    Estimate the Sim(3) transformation between point sets `X` and `Y`,
    i.e. c, R, t such that c * R @ X + t ~ Y.

    Parameters
    ----------
    X : numpy.array
        (m, n) array: m is the point dimension, n the number of points.
    Y : numpy.array
        (m, n) array, index-aligned with `X` (Y[:, i] corresponds to X[:, i]).

    Returns
    -------
    c : float
        Scale factor.
    R : numpy.array
        (3, 3) shaped rotation matrix.
    t : numpy.array
        (3, 1) shaped translation vector.
    """
    centroid_x = X.mean(axis=1, keepdims=True)
    centroid_y = Y.mean(axis=1, keepdims=True)
    Xc = X - centroid_x
    Yc = Y - centroid_y

    var_x = np.square(Xc).sum(axis=0).mean()
    cov_xy = (Yc @ Xc.T) / X.shape[1]

    U, D, VH = np.linalg.svd(cov_xy)
    # Reflection correction: flip the last singular direction if the recovered
    # rotation would otherwise have determinant -1.
    S = np.eye(X.shape[0])
    if np.linalg.det(U) * np.linalg.det(VH) < 0:
        S[-1, -1] = -1

    scale = np.trace(np.diag(D) @ S) / var_x
    rotation = U @ S @ VH
    translation = centroid_y - scale * rotation @ centroid_x
    return scale, rotation, translation

View File

@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from PIL import Image
from torchvision import transforms as TF
import numpy as np
def load_and_preprocess_images_square(image_path_list, target_size=1024):
    """
    Load and preprocess images by center padding to square and resizing to target size.
    Also returns the position information of original pixels after transformation.

    Args:
        image_path_list (list): List of paths to image files
        target_size (int, optional): Target size for both width and height. Defaults to 1024.
            (Docstring fix: previously stated 518, which did not match the default.)

    Returns:
        tuple: (
            torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
            torch.Tensor: Tensor of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image
                (Docstring fix: six values per image, not five.)
        )

    Raises:
        ValueError: If the input list is empty
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    images = []
    original_coords = []  # Renamed from position_info to be more descriptive
    to_tensor = TF.ToTensor()

    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)

        # Convert to RGB
        img = img.convert("RGB")

        # Get original dimensions
        width, height = img.size

        # Make the image square by padding the shorter dimension
        max_dim = max(width, height)

        # Calculate padding
        left = (max_dim - width) // 2
        top = (max_dim - height) // 2

        # Calculate scale factor for resizing
        scale = target_size / max_dim

        # Calculate final coordinates of original image in target space
        x1 = left * scale
        y1 = top * scale
        x2 = (left + width) * scale
        y2 = (top + height) * scale

        # Store original image coordinates and scale
        original_coords.append(np.array([x1, y1, x2, y2, width, height]))

        # Create a new black square image and paste original
        square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
        square_img.paste(img, (left, top))

        # Resize to target size
        square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)

        # Convert to tensor
        img_tensor = to_tensor(square_img)
        images.append(img_tensor)

    # Stack all images
    images = torch.stack(images)
    original_coords = torch.from_numpy(np.array(original_coords)).float()

    # Add additional dimension if single image to ensure correct shape
    # NOTE(review): after torch.stack the batch dimension already exists
    # (images is 4-D), so this branch looks like dead code — confirm before
    # removing.
    if len(image_path_list) == 1:
        if images.dim() == 3:
            images = images.unsqueeze(0)
            original_coords = original_coords.unsqueeze(0)

    return images, original_coords
def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
    Args:
        image_path_list (list): List of paths to image files
        fx, fy, cx, cy (list or None): Optional per-image intrinsics.
            NOTE: these lists are mutated IN PLACE — entries are rescaled to
            pixel units of the preprocessed images.
            NOTE(review): the code multiplies them by width/height first, so it
            assumes they arrive normalized by image size — confirm with callers.
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to ``image_size`` px and center crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension ``image_size`` px
              and padding the smaller dimension to reach a square shape.
        image_size (int, optional): Target size in pixels for the controlling dimension (default 512).
        patch_size (int, optional): Resized dimensions are rounded to multiples of this (default 16).
    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W).
        When ``fx`` is provided, the (mutated) fx, fy, cx, cy lists are returned as well.
    Raises:
        ValueError: If the input list is empty or if mode is invalid
    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": The function ensures width=image_size px while maintaining aspect ratio
          and height is center-cropped if larger than image_size px
        - When mode="pad": The function ensures the largest dimension is image_size px while maintaining
          aspect ratio and the smaller dimension is padded to reach a square shape
        - Dimensions are adjusted to be divisible by patch_size for compatibility with model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")
    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")
    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = image_size
    # First process all images and collect their shapes
    for i, image_path in enumerate(image_path_list):
        # Open image
        img = Image.open(image_path)
        # If there's an alpha channel, blend onto white background:
        if img.mode == "RGBA":
            # Create white background
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            # Alpha composite onto the white background
            img = Image.alpha_composite(background, img)
        # Now convert to "RGB" (this step assigns white for transparent areas)
        img = img.convert("RGB")
        width, height = img.size
        # Convert intrinsics to pixel units of the ORIGINAL image (in-place).
        if fx is not None:
            fx[i] = fx[i] * width
            fy[i] = fy[i] * height
            cx[i] = cx[i] * width
            cy[i] = cy[i] * height
        if mode == "pad":
            # Make the largest dimension target_size px while maintaining aspect ratio
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / patch_size) * patch_size  # Make divisible by patch_size
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / patch_size) * patch_size  # Make divisible by patch_size
        else:  # mode == "crop"
            # Original behavior: set width to target_size px
            new_width = target_size
            # Calculate height maintaining aspect ratio, divisible by patch_size
            new_height = round(height * (new_width / width) / patch_size) * patch_size
        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)
        # Center crop height if it's larger than target_size (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]
        # Rescale focal lengths to the resized image; principal point is reset
        # to the center of the (possibly cropped) image.
        if fx is not None:
            fx[i] = fx[i] * new_width / width
            fy[i] = fy[i] * new_height / height
            cx[i] = img.shape[2] / 2
            cy[i] = img.shape[1] / 2
        # For pad mode, pad to make a square of target_size x target_size
        # NOTE(review): cx/cy were set BEFORE this padding, so in pad mode they are
        # not shifted by (pad_left, pad_top) — verify whether callers expect the
        # center of the padded square instead.
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                # Pad with white (value=1.0)
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)
    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)
        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images
    images = torch.stack(images)  # concatenate images
    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)
    if fx is not None:
        return images, fx, fy, cx, cy
    return images

View File

@@ -0,0 +1,331 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .rotation import quat_to_mat, mat_to_quat
import os
import torch
import numpy as np
import gzip
import json
import random
import logging
import warnings
from lingbot_map.utils.geometry import closed_form_inverse_se3, closed_form_inverse_se3_general
def extri_intri_to_pose_encoding(
    extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV"  # e.g., (256, 512)
):
    """Encode camera extrinsics and intrinsics into a compact BxSx9 pose vector.

    Args:
        extrinsics (torch.Tensor): BxSx3x4 camera-from-world [R|t] matrices in
            the OpenCV convention (x-right, y-down, z-forward).
        intrinsics (torch.Tensor): BxSx3x3 pinhole matrices in pixels,
            [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
        image_size_hw (tuple): (height, width) in pixels, e.g. (256, 512);
            required to derive the fields of view.
        pose_encoding_type (str): Only "absT_quaR_FoV" is supported.

    Returns:
        torch.Tensor: BxSx9 float encoding — [:3] translation T,
        [3:7] rotation quaternion (scalar-last), [7:] (fov_h, fov_w).

    Raises:
        NotImplementedError: For any other pose_encoding_type.
    """
    if pose_encoding_type != "absT_quaR_FoV":
        raise NotImplementedError
    rotation = extrinsics[:, :, :3, :3]  # BxSx3x3
    translation = extrinsics[:, :, :3, 3]  # BxSx3
    quaternion = mat_to_quat(rotation)
    # Note the order: height pairs with fy, width with fx.
    height, width = image_size_hw
    fov_h = 2 * torch.atan((height / 2) / intrinsics[..., 1, 1])
    fov_w = 2 * torch.atan((width / 2) / intrinsics[..., 0, 0])
    return torch.cat(
        [translation, quaternion, fov_h[..., None], fov_w[..., None]], dim=-1
    ).float()
def pose_encoding_to_extri_intri(
    pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True  # e.g., (256, 512)
):
    """Convert a pose encoding back to camera extrinsics and intrinsics.

    Inverse of extri_intri_to_pose_encoding.

    Args:
        pose_encoding (torch.Tensor): BxSx9 (or BxSx7 for "absT_quaR") encoding:
            - [:3] = absolute translation vector T
            - [3:7] = rotation quaternion (scalar-last)
            - [7:] = (fov_h, fov_w), only for "absT_quaR_FoV"
        image_size_hw (tuple): (height, width) in pixels, e.g. (256, 512).
            Required only when build_intrinsics is True.
        pose_encoding_type (str): "absT_quaR_FoV" or "absT_quaR".
        build_intrinsics (bool): Whether to reconstruct the intrinsics matrix
            (only meaningful for "absT_quaR_FoV"); otherwise intrinsics is None.

    Returns:
        tuple: (extrinsics, intrinsics)
            - extrinsics (torch.Tensor): BxSx3x4 camera-from-world [R|t]
              matrices in the OpenCV convention.
            - intrinsics (torch.Tensor or None): BxSx3x3 pinhole matrices with
              the principal point assumed at the image center (W/2, H/2).

    Raises:
        NotImplementedError: If pose_encoding_type is not recognized.
    """
    intrinsics = None
    if pose_encoding_type == "absT_quaR_FoV":
        T = pose_encoding[..., :3]
        quat = pose_encoding[..., 3:7]
        fov_h = pose_encoding[..., 7]
        fov_w = pose_encoding[..., 8]
        R = quat_to_mat(quat)
        extrinsics = torch.cat([R, T[..., None]], dim=-1)
        if build_intrinsics:
            H, W = image_size_hw
            # Invert fov = 2*atan((size/2) / focal)  =>  focal = (size/2) / tan(fov/2)
            fy = (H / 2.0) / torch.tan(fov_h / 2.0)
            fx = (W / 2.0) / torch.tan(fov_w / 2.0)
            intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
            intrinsics[..., 0, 0] = fx
            intrinsics[..., 1, 1] = fy
            intrinsics[..., 0, 2] = W / 2
            intrinsics[..., 1, 2] = H / 2
            intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1
    elif pose_encoding_type == "absT_quaR":
        T = pose_encoding[..., :3]
        quat = pose_encoding[..., 3:7]
        R = quat_to_mat(quat)
        extrinsics = torch.cat([R, T[..., None]], dim=-1)
        intrinsics = None
    else:
        # BUGFIX: an unknown type previously fell through to `return extrinsics, ...`
        # with `extrinsics` unbound (NameError). Fail explicitly, consistent with
        # extri_intri_to_pose_encoding.
        raise NotImplementedError(f"Unknown pose_encoding_type: {pose_encoding_type}")
    return extrinsics, intrinsics
def convert_pt3d_RT_to_opencv(Rot, Trans):
    """
    Convert Point3D extrinsic matrices to OpenCV convention.

    Args:
        Rot: 3x3 rotation matrix in Point3D format
        Trans: 3-vector translation in Point3D format

    Returns:
        extri_opencv: 3x4 extrinsic matrix [R|t] in OpenCV format
    """
    rotation = np.array(Rot)
    translation = np.array(Trans)
    # PyTorch3D and OpenCV disagree on the x/y axis directions; negate both.
    translation[:2] = -translation[:2]
    rotation[:, :2] = -rotation[:, :2]
    # PyTorch3D stores row-vector rotations; transpose to column-vector form.
    return np.hstack((rotation.T, translation[:, None]))
def build_pair_index(N, B=1):
    """
    Build flattened index pairs (i, j) with i < j for every frame pair,
    replicated across the batch with per-batch offsets.

    Args:
        N: Number of frames
        B: Batch size

    Returns:
        i1, i2: 1-D index tensors of length B * N*(N-1)/2
    """
    pairs = torch.combinations(torch.arange(N), 2, with_replacement=False)
    first, second = pairs.unbind(-1)
    # Shift each batch's indices by b * N so they address the flattened tensor.
    offsets = torch.arange(B)[:, None] * N
    i1 = (first[None] + offsets).reshape(-1)
    i2 = (second[None] + offsets).reshape(-1)
    return i1, i2
def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15):
    """
    Angular error in degrees between ground-truth and predicted rotations.

    Args:
        rot_gt: Ground truth rotation matrices
        rot_pred: Predicted rotation matrices
        batch_size: If given, reshape the result to (batch_size, -1)
        eps: Clamp floor to avoid arccos of values outside [-1, 1]

    Returns:
        Rotation angle error in degrees
    """
    quat_pred = mat_to_quat(rot_pred)
    quat_gt = mat_to_quat(rot_gt)
    # 1 - <q1, q2>^2 is invariant to the q / -q sign ambiguity.
    loss = (1 - (quat_pred * quat_gt).sum(dim=1) ** 2).clamp(min=eps)
    err_deg = torch.arccos(1 - 2 * loss) * 180 / np.pi
    return err_deg if batch_size is None else err_deg.reshape(batch_size, -1)
def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True):
    """
    Angular error in degrees between ground-truth and predicted translations.

    Args:
        tvec_gt: Ground truth translation vectors
        tvec_pred: Predicted translation vectors
        batch_size: If given, reshape the result to (batch_size, -1)
        ambiguity: Fold errors above 90 degrees back (direction sign ambiguity)

    Returns:
        Translation angle error in degrees
    """
    err_deg = compare_translation_by_angle(tvec_gt, tvec_pred) * 180.0 / np.pi
    if ambiguity:
        # A translation direction and its negation are indistinguishable.
        err_deg = torch.min(err_deg, (180 - err_deg).abs())
    if batch_size is not None:
        err_deg = err_deg.reshape(batch_size, -1)
    return err_deg
def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6):
    """
    Normalize both translation sets and return the angle between directions.

    Args:
        t_gt: Ground truth translation vectors, shape (N, 3)
        t: Predicted translation vectors, shape (N, 3)
        eps: Small value to avoid division by zero / acos domain errors
        default_err: Substituted wherever the angle is NaN or infinite

    Returns:
        Angular error between translation vectors in radians
    """
    pred_dir = t / (torch.norm(t, dim=1, keepdim=True) + eps)
    gt_dir = t_gt / (torch.norm(t_gt, dim=1, keepdim=True) + eps)
    # sin^2 of the angle, floored at eps for numerical safety.
    sin_sq = torch.clamp_min(1.0 - torch.sum(pred_dir * gt_dir, dim=1) ** 2, eps)
    angle = torch.acos(torch.sqrt(1 - sin_sq))
    angle[torch.isnan(angle) | torch.isinf(angle)] = default_err
    return angle
def calculate_auc_np(r_error, t_error, max_threshold=30):
    """
    Area Under the Curve (AUC) of the pose-accuracy-vs-threshold curve.

    Args:
        r_error: numpy array of rotation errors (degrees)
        t_error: numpy array of translation errors (degrees)
        max_threshold: Maximum threshold (degrees) for binning

    Returns:
        (auc, normalized_histogram): mean cumulative accuracy over 1-degree
        bins, and the per-bin fraction of pairs.
    """
    # A pair is accurate at threshold tau only if BOTH errors are below tau.
    worst = np.max(np.concatenate((r_error[:, None], t_error[:, None]), axis=1), axis=1)
    edges = np.arange(max_threshold + 1)
    counts, _ = np.histogram(worst, bins=edges)
    frac_per_bin = counts.astype(float) / float(len(worst))
    # Cumulative sum gives accuracy at each integer threshold; its mean is the AUC.
    return np.mean(np.cumsum(frac_per_bin)), frac_per_bin
def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames):
    """
    Rotation and translation errors over all frame pairs.

    Inputs are assumed to be world-to-camera (w2c) transformations.

    Args:
        pred_se3: Predicted SE(3) transformations (w2c), shape (N, 4, 4)
        gt_se3: Ground truth SE(3) transformations (w2c), shape (N, 4, 4)
        num_frames: Number of frames (N)

    Returns:
        (rel_rangle_deg, rel_tangle_deg): angle errors in degrees per pair
    """
    idx_a, idx_b = build_pair_index(num_frames)
    # Relative transform mapping frame b's camera to frame a's: T_a @ inv(T_b).
    rel_gt = gt_se3[idx_a].bmm(closed_form_inverse_se3(gt_se3[idx_b]))
    rel_pred = pred_se3[idx_a].bmm(closed_form_inverse_se3(pred_se3[idx_b]))
    r_err_deg = rotation_angle(rel_gt[:, :3, :3], rel_pred[:, :3, :3])
    t_err_deg = translation_angle(rel_gt[:, :3, 3], rel_pred[:, :3, 3])
    return r_err_deg, t_err_deg
def colmap_to_opencv_intrinsics(K):
    """
    Shift the principal point from the Colmap to the OpenCV pixel convention.

    The center of the top-left pixel is (0.5, 0.5) in Colmap but (0, 0) in
    OpenCV, so cx and cy are reduced by half a pixel. The input is not modified.
    """
    K_cv = K.copy()
    K_cv[..., 0, 2] = K_cv[..., 0, 2] - 0.5
    K_cv[..., 1, 2] = K_cv[..., 1, 2] - 0.5
    return K_cv
def read_camera_parameters(filename):
    """Read an MVSNet-style camera file.

    Expected layout: a 4x4 extrinsic matrix on lines [1, 5) and a 3x3
    intrinsic matrix on lines [7, 10) (0-indexed), whitespace-separated.

    Args:
        filename: Path to the camera parameter text file.

    Returns:
        (intrinsics, extrinsics): float32 arrays of shape (3, 3) and (4, 4).
    """
    with open(filename) as f:
        lines = [line.rstrip() for line in f]

    def _parse(rows, shape):
        # np.fromstring(sep=' ') is deprecated (removed in newer NumPy);
        # split the whitespace-separated numbers explicitly instead.
        return np.array(" ".join(rows).split(), dtype=np.float32).reshape(shape)

    # extrinsics: line [1,5), 4x4 matrix
    extrinsics = _parse(lines[1:5], (4, 4))
    # intrinsics: line [7,10), 3x3 matrix
    intrinsics = _parse(lines[7:10], (3, 3))
    return intrinsics, extrinsics

View File

@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
import torch
import numpy as np
import torch.nn.functional as F
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert scalar-last quaternions (XYZW / ijkr order) to rotation matrices.

    Args:
        quaternions: quaternions with real part last, as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    x, y, z, w = torch.unbind(quaternions, -1)
    # 2 / |q|^2 — normalizes non-unit quaternions on the fly.
    scale = 2.0 / (quaternions * quaternions).sum(-1)
    entries = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.
    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last
    Raises:
        ValueError: If the trailing two dimensions are not 3x3.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)
    # |2r|, |2i|, |2j|, |2k| — magnitudes of twice each quaternion component,
    # obtained from the four trace combinations of the matrix.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
        )
    )
    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            # `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )
    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))
    # Convert from rijk to ijkr
    out = out[..., [1, 2, 3, 0]]
    # Canonicalize the sign so the scalar component is non-negative.
    out = standardize_quaternion(out)
    return out
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert a unit quaternion to a standard form: one in which the real
part is non negative.
Args:
quaternions: Quaternions with real part last,
as tensor of shape (..., 4).
Returns:
Standardized quaternions as tensor of shape (..., 4).
"""
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)

View File

@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
GCT Visualization Module
This module provides visualization utilities for 3D reconstruction results:
- PointCloudViewer: Interactive point cloud viewer with camera visualization
- viser_wrapper: Quick visualization wrapper for predictions
- predictions_to_glb: Export predictions to GLB 3D format
- Colorization and utility functions
Usage:
from lingbot_map.vis import PointCloudViewer, viser_wrapper, predictions_to_glb
# Interactive visualization
viewer = PointCloudViewer(pred_dict=predictions, port=8080)
viewer.run()
# Quick visualization
viser_wrapper(predictions, port=8080)
# Export to GLB
scene = predictions_to_glb(predictions)
scene.export("output.glb")
"""
from lingbot_map.vis.point_cloud_viewer import PointCloudViewer
from lingbot_map.vis.viser_wrapper import viser_wrapper
from lingbot_map.vis.utils import CameraState, colorize, colorize_np, get_vertical_colorbar
from lingbot_map.vis.sky_segmentation import (
apply_sky_segmentation,
download_skyseg_model,
load_or_create_sky_masks,
segment_sky,
)
from lingbot_map.vis.glb_export import predictions_to_glb
# Public API of the visualization package, grouped by purpose.
__all__ = [
    # Main viewer
    "PointCloudViewer",
    # Quick visualization
    "viser_wrapper",
    # GLB export
    "predictions_to_glb",
    # Utilities
    "CameraState",
    "colorize",
    "colorize_np",
    "get_vertical_colorbar",
    # Sky segmentation
    "apply_sky_segmentation",
    "segment_sky",
    "download_skyseg_model",
    "load_or_create_sky_masks",
]

View File

@@ -0,0 +1,509 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
GLB 3D export utilities for GCT predictions.
"""
import os
import copy
from typing import Optional, Tuple
import numpy as np
import cv2
import matplotlib
from scipy.spatial.transform import Rotation
from lingbot_map.vis.sky_segmentation import (
_SKYSEG_INPUT_SIZE,
_SKYSEG_SOFT_THRESHOLD,
_mask_to_float,
_mask_to_uint8,
_result_map_to_non_sky_conf,
)
try:
import trimesh
except ImportError:
trimesh = None
print("trimesh not found. GLB export will not work.")
def predictions_to_glb(
    predictions: dict,
    conf_thres: float = 50.0,
    filter_by_frames: str = "all",
    mask_black_bg: bool = False,
    mask_white_bg: bool = False,
    show_cam: bool = True,
    mask_sky: bool = False,
    target_dir: Optional[str] = None,
    prediction_mode: str = "Predicted Pointmap",
) -> "trimesh.Scene":
    """
    Converts GCT predictions to a 3D scene represented as a GLB file.
    Args:
        predictions: Dictionary containing model predictions with keys:
            - world_points: 3D point coordinates (S, H, W, 3)
            - world_points_conf: Confidence scores (S, H, W)
            - images: Input images (S, H, W, 3) or (S, 3, H, W)
            - extrinsic: Camera extrinsic matrices (S, 3, 4)
        conf_thres: Percentile (0-100) of low-confidence points to filter out;
            None falls back to 10.0
        filter_by_frames: Frame filter specification ("all" or "idx" / "idx: label")
        mask_black_bg: Mask out near-black background pixels
        mask_white_bg: Mask out near-white background pixels
        show_cam: Include camera visualization
        mask_sky: Apply sky segmentation mask (requires target_dir)
        target_dir: Output directory for intermediate files
        prediction_mode: "Predicted Pointmap" or "Predicted Depthmap"
    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras
    Raises:
        ValueError: If input predictions structure is invalid
        ImportError: If trimesh is not available
    """
    if trimesh is None:
        raise ImportError("trimesh is required for GLB export. Install with: pip install trimesh")
    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")
    if conf_thres is None:
        conf_thres = 10.0
    print("Building GLB scene")
    # Parse frame filter: accepts "idx" or "idx: label"; anything unparsable means "all"
    selected_frame_idx = None
    if filter_by_frames != "all" and filter_by_frames != "All":
        try:
            selected_frame_idx = int(filter_by_frames.split(":")[0])
        except (ValueError, IndexError):
            pass
    # Select prediction source
    if "Pointmap" in prediction_mode:
        print("Using Pointmap Branch")
        if "world_points" in predictions:
            pred_world_points = predictions["world_points"]
            pred_world_points_conf = predictions.get(
                "world_points_conf", np.ones_like(pred_world_points[..., 0])
            )
        else:
            print("Warning: world_points not found, falling back to depth-based points")
            pred_world_points = predictions["world_points_from_depth"]
            pred_world_points_conf = predictions.get(
                "depth_conf", np.ones_like(pred_world_points[..., 0])
            )
    else:
        print("Using Depthmap and Camera Branch")
        pred_world_points = predictions["world_points_from_depth"]
        pred_world_points_conf = predictions.get(
            "depth_conf", np.ones_like(pred_world_points[..., 0])
        )
    images = predictions["images"]
    camera_matrices = predictions["extrinsic"]
    # Apply sky segmentation if enabled
    if mask_sky and target_dir is not None:
        pred_world_points_conf = _apply_sky_mask(
            pred_world_points_conf, target_dir, images
        )
    # Apply frame filter ([None] keeps a leading frame axis of length 1)
    if selected_frame_idx is not None:
        pred_world_points = pred_world_points[selected_frame_idx][None]
        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        camera_matrices = camera_matrices[selected_frame_idx][None]
    # Prepare vertices and colors
    vertices_3d = pred_world_points.reshape(-1, 3)
    # Handle different image formats
    if images.ndim == 4 and images.shape[1] == 3:  # NCHW format
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:
        colors_rgb = images
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
    # Apply confidence filtering (percentile threshold; near-zero conf always dropped)
    conf = pred_world_points_conf.reshape(-1)
    conf_threshold = np.percentile(conf, conf_thres) if conf_thres > 0 else 0.0
    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
    # Apply background masking
    if mask_black_bg:
        black_bg_mask = colors_rgb.sum(axis=1) >= 16
        conf_mask = conf_mask & black_bg_mask
    if mask_white_bg:
        white_bg_mask = ~(
            (colors_rgb[:, 0] > 240) &
            (colors_rgb[:, 1] > 240) &
            (colors_rgb[:, 2] > 240)
        )
        conf_mask = conf_mask & white_bg_mask
    vertices_3d = vertices_3d[conf_mask]
    colors_rgb = colors_rgb[conf_mask]
    # Handle empty point cloud with a single white placeholder point
    if vertices_3d is None or np.asarray(vertices_3d).size == 0:
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1
    else:
        # Robust scene extent: norm of the 5th-to-95th per-axis percentile span
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)
        scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
    # Gives each camera a distinct color (only read when show_cam is True)
    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
    # Build scene
    scene_3d = trimesh.Scene()
    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
    scene_3d.add_geometry(point_cloud_data)
    # Prepare camera matrices (promote 3x4 [R|t] to homogeneous 4x4)
    num_cameras = len(camera_matrices)
    extrinsics_matrices = np.zeros((num_cameras, 4, 4))
    extrinsics_matrices[:, :3, :4] = camera_matrices
    extrinsics_matrices[:, 3, 3] = 1
    # Add cameras
    if show_cam:
        for i in range(num_cameras):
            world_to_camera = extrinsics_matrices[i]
            camera_to_world = np.linalg.inv(world_to_camera)
            rgba_color = colormap(i / num_cameras)
            current_color = tuple(int(255 * x) for x in rgba_color[:3])
            integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)
    # Align scene so the first camera sits at a canonical pose
    scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
    print("GLB Scene built")
    return scene_3d
def _apply_sky_mask(
    conf: np.ndarray,
    target_dir: str,
    images: np.ndarray
) -> np.ndarray:
    """Zero out confidence in sky regions using cached or freshly computed sky masks.

    Args:
        conf: Per-pixel confidence, shape (S, H, W).
        target_dir: Directory containing an ``images/`` subfolder; sky masks are
            cached under ``target_dir/sky_masks/``.
        images: Input images, used only as a shape fallback.

    Returns:
        Confidence with sky pixels zeroed (binary mask applied). Returned
        unchanged (best-effort) if onnxruntime or the images directory is
        unavailable.
    """
    try:
        import onnxruntime
    except ImportError:
        print("Warning: onnxruntime not available, skipping sky masking")
        return conf
    target_dir_images = os.path.join(target_dir, "images")
    if not os.path.exists(target_dir_images):
        print(f"Warning: Images directory not found at {target_dir_images}")
        return conf
    image_list = sorted(os.listdir(target_dir_images))
    S, H, W = conf.shape if hasattr(conf, "shape") else (len(images), images.shape[1], images.shape[2])
    skyseg_model_path = "skyseg.onnx"
    if not os.path.exists(skyseg_model_path):
        # BUGFIX: the original called download_file_from_url(), which is neither
        # defined nor imported in this module (NameError at runtime). Fetch the
        # model with the standard library instead.
        print("Downloading skyseg.onnx...")
        import urllib.request
        urllib.request.urlretrieve(
            "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
            skyseg_model_path,
        )
    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_mask_list = []
    for image_name in image_list[:S]:
        image_filepath = os.path.join(target_dir_images, image_name)
        mask_filepath = os.path.join(target_dir, "sky_masks", image_name)
        if os.path.exists(mask_filepath):
            # Reuse a cached mask when one exists on disk.
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        else:
            sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)
        # Resize mask to match the confidence map resolution if needed.
        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_LINEAR)
        sky_mask_list.append(_mask_to_float(sky_mask))
    sky_mask_array = np.array(sky_mask_list)
    # Binarize the soft mask: 1 keeps the pixel, 0 suppresses sky.
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    return conf * sky_mask_binary
def integrate_camera_into_scene(
    scene: "trimesh.Scene",
    transform: np.ndarray,
    face_colors: Tuple[int, int, int],
    scene_scale: float,
    frustum_thickness: float = 1.0,
):
    """
    Integrates a camera mesh into the 3D scene.
    Args:
        scene: The 3D scene to add the camera model
        transform: Transformation matrix (camera-to-world) for camera positioning
        face_colors: RGB color tuple for the camera
        scene_scale: Scale of the scene; camera size is proportional to it
        frustum_thickness: Multiplier for frustum edge thickness (>1 = thicker)
    """
    # Camera frustum size is proportional to the overall scene extent.
    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1
    # Twist the 4-sided cone 45 deg about z so its base reads as an upright
    # square frustum, and shift it back so the apex sits at the camera origin.
    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    rot_45_degree[2, 3] = -cam_height
    opengl_transform = get_opengl_conversion_matrix()
    # Applied right-to-left: 45-deg twist -> OpenGL axis flip -> camera pose.
    complete_transform = transform @ opengl_transform @ rot_45_degree
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
    # Build thicker frustum by stacking rotated copies
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
    shell_scales = [1.0, 0.95]
    shell_transforms = [np.eye(4), slight_rotation]
    # Add extra shells for thickness
    if frustum_thickness > 1.0:
        n_extra = max(1, int(frustum_thickness - 1))
        for k in range(1, n_extra + 1):
            # Progressively rotated and scaled copies (one twisted each way per step)
            angle = 2.0 + k * 2.0
            scale = 1.0 + k * 0.02
            rot = np.eye(4)
            rot[:3, :3] = Rotation.from_euler("z", angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot)
            rot_neg = np.eye(4)
            rot_neg[:3, :3] = Rotation.from_euler("z", -angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot_neg)
    vertices_parts = []
    for s, t_mat in zip(shell_scales, shell_transforms):
        vertices_parts.append(
            transform_points(t_mat, s * camera_cone_shape.vertices)
        )
    # All shells share one vertex buffer; face indices address shell offsets.
    vertices_combined = np.concatenate(vertices_parts)
    vertices_transformed = transform_points(complete_transform, vertices_combined)
    mesh_faces = compute_camera_faces_multi(camera_cone_shape, len(shell_scales))
    camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)
def apply_scene_alignment(
    scene_3d: "trimesh.Scene",
    extrinsics_matrices: np.ndarray
) -> "trimesh.Scene":
    """
    Re-orients the scene so the first camera sits at a canonical pose.

    Args:
        scene_3d: The 3D scene to be aligned
        extrinsics_matrices: Camera extrinsic matrices (4x4, world-to-camera)

    Returns:
        Aligned 3D scene
    """
    # 180-degree rotation about the y axis.
    flip_y_180 = np.eye(4)
    flip_y_180[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()
    # Compose: world-from-camera of frame 0, then OpenCV->OpenGL flip, then y-flip.
    world_from_cam0 = np.linalg.inv(extrinsics_matrices[0])
    alignment = world_from_cam0 @ get_opengl_conversion_matrix() @ flip_y_180
    scene_3d.apply_transform(alignment)
    return scene_3d
def get_opengl_conversion_matrix() -> np.ndarray:
    """Return the 4x4 OpenCV-to-OpenGL conversion matrix (negates Y and Z axes)."""
    return np.diag([1.0, -1.0, -1.0, 1.0])
def transform_points(
    transformation: np.ndarray,
    points: np.ndarray,
    dim: Optional[int] = None
) -> np.ndarray:
    """
    Applies a 4x4 homogeneous transformation to a set of points.

    Args:
        transformation: Transformation matrix
        points: Points to be transformed (..., D)
        dim: Number of output coordinates to keep (defaults to input dim)

    Returns:
        Transformed points with shape (..., dim)
    """
    pts = np.asarray(points)
    lead_shape = pts.shape[:-1]
    out_dim = dim or pts.shape[-1]
    # Row-vector convention: p' = p @ M_linear^T + t^T, done via the transpose.
    mat_t = transformation.swapaxes(-1, -2)
    transformed = pts @ mat_t[..., :-1, :] + mat_t[..., -1:, :]
    return transformed[..., :out_dim].reshape(*lead_shape, out_dim)
def compute_camera_faces(cone_shape: "trimesh.Trimesh") -> np.ndarray:
    """Computes the faces for the camera mesh (three stacked vertex shells)."""
    shell_size = len(cone_shape.vertices)
    tris = []
    for tri in cone_shape.faces:
        # Faces touching vertex 0 are excluded.
        if 0 in tri:
            continue
        a, b, c = tri
        a1, b1, c1 = tri + shell_size
        a2, b2, c2 = tri + 2 * shell_size
        tris.extend([
            (a, b, b1),
            (a, a1, c),
            (c1, b, c),
            (a, b, b2),
            (a, a2, c),
            (c2, b, c),
        ])
    # Duplicate every face with reversed winding order.
    tris.extend([(w, v, u) for u, v, w in tris])
    return np.array(tris)
def compute_camera_faces_multi(cone_shape: "trimesh.Trimesh", num_shells: int) -> np.ndarray:
    """Computes faces for a camera mesh with multiple shells (for thicker frustums).
    Connects each consecutive pair of vertex shells to form the frustum edges.
    """
    shell_size = len(cone_shape.vertices)
    tris = []
    for inner in range(num_shells - 1):
        base_in = inner * shell_size
        base_out = base_in + shell_size
        for tri in cone_shape.faces:
            # Faces touching vertex 0 are excluded.
            if 0 in tri:
                continue
            a, b, c = tri
            tris.append((a + base_in, b + base_in, b + base_out))
            tris.append((a + base_in, a + base_out, c + base_in))
            tris.append((c + base_out, b + base_in, c + base_in))
    # Duplicate every face with reversed winding order.
    tris.extend([(w, v, u) for u, v, w in tris])
    return np.array(tris)
def segment_sky(
    image_path: str,
    onnx_session,
    mask_filename: str
) -> np.ndarray:
    """
    Segments sky from an image using an ONNX model.
    Args:
        image_path: Path to input image
        onnx_session: ONNX runtime session with loaded model
        mask_filename: Path to save the output mask
    Returns:
        Continuous non-sky confidence map in [0, 1]
    Raises:
        ValueError: If the image cannot be read from image_path.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Failed to read image: {image_path}")
    result_map = run_skyseg(onnx_session, _SKYSEG_INPUT_SIZE, image)
    result_map_original = cv2.resize(
        result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR
    )
    output_mask = _result_map_to_non_sky_conf(result_map_original)
    mask_dir = os.path.dirname(mask_filename)
    if mask_dir:
        # Guard: os.makedirs("") raises when mask_filename has no directory part.
        os.makedirs(mask_dir, exist_ok=True)
    cv2.imwrite(mask_filename, _mask_to_uint8(output_mask))
    return output_mask
def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray
) -> np.ndarray:
    """
    Runs sky segmentation inference using ONNX model.
    Args:
        onnx_session: ONNX runtime session
        input_size: Target size for model input (width, height)
        image: Input image in BGR format
    Returns:
        Segmentation mask as uint8 scores in [0, 255]
    """
    # cv2.resize takes dsize as (width, height) and allocates a new array,
    # so the caller's image is never modified (no defensive copy needed).
    resize_image = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    # ImageNet normalization.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    x = x.transpose(2, 0, 1)
    # NCHW layout: height is input_size[1], width is input_size[0].
    # (The previous reshape swapped H and W, which only worked for square inputs.)
    x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})
    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    # Guard against a constant output map (max == min) to avoid div-by-zero.
    denom = max(max_value - min_value, 1e-8)
    onnx_result = (onnx_result - min_value) / denom
    onnx_result *= 255
    return onnx_result.astype("uint8")
def download_file_from_url(url: str, filename: str):
    """Downloads a file from a URL, handling manual 302 redirects.

    Args:
        url: Source URL. A 302 response's Location header is followed once;
            a direct 200 response is downloaded as-is.
        filename: Local path to write the downloaded bytes to.
    """
    import requests
    try:
        response = requests.get(url, allow_redirects=False, stream=True)
        response.raise_for_status()
        if response.status_code == 302:
            redirect_url = response.headers["Location"]
            response = requests.get(redirect_url, stream=True)
            response.raise_for_status()
        elif response.status_code != 200:
            # raise_for_status() only raises for 4xx/5xx; other redirect or
            # informational codes land here and are not downloadable content.
            print(f"Unexpected status code: {response.status_code}")
            return
        with open(filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Downloaded {filename} successfully.")
    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,473 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Sky segmentation utilities for filtering sky points from point clouds.
"""
import glob
import os
from typing import Optional, Tuple
import numpy as np
import cv2
from tqdm.auto import tqdm
try:
import onnxruntime
except ImportError:
onnxruntime = None
print("onnxruntime not found. Sky segmentation may not work.")
_SKYSEG_INPUT_SIZE = (320, 320)
_SKYSEG_SOFT_THRESHOLD = 0.1
_SKYSEG_CACHE_VERSION = "imagenet_norm_softmap_inverted_v3"
def _get_cache_version_path(sky_mask_dir: str) -> str:
return os.path.join(sky_mask_dir, ".skyseg_cache_version")
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool:
    """Ensure the sky-mask cache directory exists and validate its format version.

    Returns True when cached masks must be regenerated (no marker file, or a
    stale version string), False when the existing cache can be reused.
    """
    if sky_mask_dir is None:
        return False
    os.makedirs(sky_mask_dir, exist_ok=True)
    marker_path = _get_cache_version_path(sky_mask_dir)
    stale = True
    if os.path.exists(marker_path):
        with open(marker_path, "r", encoding="utf-8") as marker:
            stale = marker.read().strip() != _SKYSEG_CACHE_VERSION
    if not stale:
        return False
    print(
        f"Sky mask cache at {sky_mask_dir} uses an older format; "
        "regenerating masks with ImageNet-normalized skyseg input"
    )
    # Stamp the current version so the next run can reuse the cache.
    with open(marker_path, "w", encoding="utf-8") as marker:
        marker.write(_SKYSEG_CACHE_VERSION)
    return True
def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray,
) -> np.ndarray:
    """
    Run ONNX sky segmentation on a BGR image and return an 8-bit score map.
    """
    width, height = input_size
    resized = cv2.resize(image, dsize=(width, height))
    rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32)
    # ImageNet channel statistics.
    imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    tensor = (rgb / 255.0 - imagenet_mean) / imagenet_std
    # HWC -> NCHW with a leading batch dimension.
    tensor = tensor.transpose(2, 0, 1).reshape(-1, 3, height, width).astype("float32")
    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    raw = np.array(onnx_session.run([output_name], {input_name: tensor})).squeeze()
    lo = np.min(raw)
    hi = np.max(raw)
    # Avoid division by zero when the output map is constant.
    scale = max(hi - lo, 1e-8)
    return ((raw - lo) / scale * 255.0).astype(np.uint8)
def _mask_to_float(mask: np.ndarray) -> np.ndarray:
mask = mask.astype(np.float32)
if mask.size == 0:
return mask
return np.clip(mask, 0.0, 1.0)
def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
mask = np.asarray(mask)
if mask.dtype == np.uint8:
return mask
mask = mask.astype(np.float32)
if mask.size > 0 and mask.max() <= 1.0:
mask = mask * 255.0
return np.clip(mask, 0.0, 255.0).astype(np.uint8)
def _result_map_to_non_sky_conf(result_map: np.ndarray) -> np.ndarray:
    """Invert a skyseg score map into non-sky confidence in [0, 1]."""
    # The raw skyseg map scores sky high; non-sky confidence is its complement.
    sky_conf = _mask_to_float(result_map)
    return 1.0 - sky_conf
def segment_sky_from_array(
    image: np.ndarray,
    skyseg_session,
    target_h: int,
    target_w: int
) -> np.ndarray:
    """
    Segment sky from an image array using ONNX model.
    Args:
        image: Input image as numpy array (H, W, 3) or (3, H, W), values in [0, 1] or [0, 255]
        skyseg_session: ONNX runtime inference session
        target_h: Target output height
        target_w: Target output width
    Returns:
        Continuous non-sky confidence map in [0, 1].
    """
    # run_skyseg expects a BGR frame; normalize the array then swap channels.
    rgb = _image_to_rgb_uint8(image)
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    score_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, bgr)
    score_map = cv2.resize(score_map, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    return _result_map_to_non_sky_conf(score_map)
def segment_sky(
    image_path: str,
    skyseg_session,
    output_path: Optional[str] = None
) -> np.ndarray:
    """
    Segment sky from an image using ONNX model.
    Args:
        image_path: Path to the input image
        skyseg_session: ONNX runtime inference session
        output_path: Optional path to save the mask
    Returns:
        Continuous non-sky confidence map in [0, 1].
    Raises:
        ValueError: If the image cannot be read from image_path.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Failed to read image: {image_path}")
    height, width = bgr.shape[:2]
    score_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, bgr)
    score_map = cv2.resize(score_map, (width, height), interpolation=cv2.INTER_LINEAR)
    non_sky_conf = _result_map_to_non_sky_conf(score_map)
    if output_path is not None:
        parent = os.path.dirname(output_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        cv2.imwrite(output_path, _mask_to_uint8(non_sky_conf))
    return non_sky_conf
def _list_image_files(image_folder: str) -> list[str]:
image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
return [f for f in image_files if os.path.splitext(f.lower())[1] in image_extensions]
def _image_to_rgb_uint8(image: np.ndarray) -> np.ndarray:
if image.ndim == 3 and image.shape[0] == 3 and image.shape[-1] != 3:
image = image.transpose(1, 2, 0)
if image.ndim != 3 or image.shape[2] != 3:
raise ValueError(f"Expected image with shape (H, W, 3) or (3, H, W), got {image.shape}")
if image.dtype != np.uint8:
image = image.astype(np.float32)
if image.max() <= 1.0:
image = image * 255.0
image = np.clip(image, 0.0, 255.0).astype(np.uint8)
return image
def _get_mask_filename(image_paths: Optional[list[str]], index: int) -> str:
if image_paths is not None and index < len(image_paths):
return os.path.basename(image_paths[index])
return f"frame_{index:06d}.png"
def _save_sky_mask_visualization(
    image: np.ndarray,
    sky_mask: np.ndarray,
    output_path: str,
) -> None:
    """Write a side-by-side panel (image | mask | overlay) for visual inspection."""
    rgb = _image_to_rgb_uint8(image)
    mask = sky_mask
    if mask.shape[:2] != rgb.shape[:2]:
        # Nearest-neighbor keeps mask values unblended when resizing.
        mask = cv2.resize(
            mask,
            (rgb.shape[1], rgb.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )
    mask_panel = np.repeat(_mask_to_uint8(mask)[..., None], 3, axis=2)
    # Tint pixels with low non-sky confidence red in the overlay.
    overlay = rgb.astype(np.float32).copy()
    is_sky = _mask_to_float(mask) <= _SKYSEG_SOFT_THRESHOLD
    tint = np.array([255, 64, 64], dtype=np.float32)
    overlay[is_sky] = overlay[is_sky] * 0.35 + tint * 0.65
    overlay = np.clip(overlay, 0.0, 255.0).astype(np.uint8)
    panel = np.concatenate([rgb, mask_panel, overlay], axis=1)
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    cv2.imwrite(output_path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
def load_or_create_sky_masks(
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
    target_shape: Optional[Tuple[int, int]] = None,
    num_frames: Optional[int] = None,
) -> Optional[np.ndarray]:
    """
    Load cached sky masks or generate them with the ONNX model.
    Args:
        image_folder: Folder containing input images.
        image_paths: Optional explicit image file list, in the exact order to process.
        images: Optional image array with shape (S, 3, H, W) or (S, H, W, 3).
        skyseg_model_path: Path to the sky segmentation ONNX model.
        sky_mask_dir: Optional directory for cached raw masks.
        sky_mask_visualization_dir: Optional directory for side-by-side visualizations.
        target_shape: Optional output mask shape (H, W) after resizing.
        num_frames: Optional maximum number of frames to process.
    Returns:
        Sky masks with shape (S, H, W), or None if sky segmentation could not run.
    """
    # Early exits: no inference backend, or no image source of any kind.
    if onnxruntime is None:
        print("Warning: onnxruntime not available, skipping sky segmentation")
        return None
    if image_folder is None and image_paths is None and images is None:
        print("Warning: Neither image_folder/image_paths nor images provided, skipping sky segmentation")
        return None
    # Fetch the model lazily on first use; give up (best-effort) on failure.
    if not os.path.exists(skyseg_model_path):
        print(f"Sky segmentation model not found at {skyseg_model_path}, downloading...")
        try:
            download_skyseg_model(skyseg_model_path)
        except Exception as e:
            print(f"Warning: Failed to download sky segmentation model: {e}")
            return None
    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_masks = []
    if sky_mask_visualization_dir is not None:
        os.makedirs(sky_mask_visualization_dir, exist_ok=True)
        print(f"Saving sky mask visualizations to {sky_mask_visualization_dir}")
    if images is not None:
        # --- Branch 1: segment from the in-memory image array ---
        # image_paths (if resolvable) is only used to derive cache filenames here.
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)
        num_images = images.shape[0]
        if num_frames is not None:
            num_images = min(num_images, num_frames)
        if image_paths is not None:
            image_paths = image_paths[:num_images]
        if sky_mask_dir is None and image_folder is not None:
            sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
        refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
        print("Generating sky masks from image array...")
        for i in tqdm(range(num_images)):
            image_rgb = _image_to_rgb_uint8(images[i])
            image_h, image_w = image_rgb.shape[:2]
            image_name = _get_mask_filename(image_paths, i)
            # mask_filepath is None when no cache directory could be derived.
            mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None
            if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath):
                # Cache hit: validate the stored mask before trusting it.
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                if sky_mask is None:
                    # Unreadable cache file: regenerate and overwrite it.
                    print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
                    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
                elif sky_mask.shape[:2] != (image_h, image_w):
                    # Cache was produced at a different resolution: regenerate.
                    print(
                        f"Cached sky mask shape {sky_mask.shape[:2]} does not match resized image "
                        f"shape {(image_h, image_w)} for {image_name}; regenerating it"
                    )
                    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
            else:
                sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                if mask_filepath is not None:
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
            if sky_mask_visualization_dir is not None:
                _save_sky_mask_visualization(
                    image_rgb,
                    sky_mask,
                    os.path.join(sky_mask_visualization_dir, image_name),
                )
            # Resize last so the cache keeps native-resolution masks.
            if target_shape is not None and sky_mask.shape[:2] != target_shape:
                sky_mask = cv2.resize(
                    sky_mask,
                    (target_shape[1], target_shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )
            sky_masks.append(_mask_to_float(sky_mask))
    else:
        # --- Branch 2: segment from image files on disk ---
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)
        # NOTE: `images is None` always holds in this branch; the check only
        # guards against image_paths still being unresolved.
        if images is None and image_paths is not None:
            if len(image_paths) == 0:
                print("Warning: No image files provided, skipping sky segmentation")
                return None
            if num_frames is not None:
                image_paths = image_paths[:num_frames]
            if sky_mask_dir is None:
                if image_folder is None:
                    image_folder = os.path.dirname(image_paths[0])
                sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
            refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)
            print("Generating sky masks from image files...")
            for image_path in tqdm(image_paths):
                image_name = os.path.basename(image_path)
                mask_filepath = os.path.join(sky_mask_dir, image_name)
                if not refresh_cache and os.path.exists(mask_filepath):
                    sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                    if sky_mask is None:
                        # Unreadable cache file: regenerate (segment_sky also rewrites it).
                        print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
                        sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
                else:
                    sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
                if sky_mask is None:
                    print(f"Warning: Failed to produce sky mask for {image_path}, skipping frame")
                    continue
                if sky_mask_visualization_dir is not None:
                    # Re-read the source frame just for the visualization panel.
                    image_bgr = cv2.imread(image_path)
                    if image_bgr is not None:
                        image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                        _save_sky_mask_visualization(
                            image_rgb,
                            sky_mask,
                            os.path.join(sky_mask_visualization_dir, image_name),
                        )
                if target_shape is not None and sky_mask.shape[:2] != target_shape:
                    sky_mask = cv2.resize(
                        sky_mask,
                        (target_shape[1], target_shape[0]),
                        interpolation=cv2.INTER_LINEAR,
                    )
                sky_masks.append(_mask_to_float(sky_mask))
    if len(sky_masks) == 0:
        print("Warning: No sky masks generated, skipping sky segmentation")
        return None
    try:
        return np.stack(sky_masks, axis=0)
    except ValueError:
        # Masks of differing shapes cannot be stacked; np.stack raises
        # ValueError, so fall back to a ragged object array.
        return np.array(sky_masks, dtype=object)
def apply_sky_segmentation(
    conf: np.ndarray,
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.
    Args:
        conf: Confidence scores with shape (S, H, W)
        image_folder: Path to the folder containing input images (optional if images provided)
        image_paths: Optional explicit image file list in processing order
        images: Image array with shape (S, 3, H, W) or (S, H, W, 3) (optional if image_folder provided)
        skyseg_model_path: Path to the sky segmentation ONNX model
        sky_mask_dir: Optional directory for cached raw masks
        sky_mask_visualization_dir: Optional directory for side-by-side mask visualization images
    Returns:
        Updated confidence scores with sky regions zeroed out; the input is
        returned unchanged when masks could not be generated.
    """
    S, H, W = conf.shape
    sky_mask_array = load_or_create_sky_masks(
        image_folder=image_folder,
        image_paths=image_paths,
        images=images,
        skyseg_model_path=skyseg_model_path,
        sky_mask_dir=sky_mask_dir,
        sky_mask_visualization_dir=sky_mask_visualization_dir,
        target_shape=(H, W),
        num_frames=S,
    )
    if sky_mask_array is None:
        return conf
    if sky_mask_array.shape[0] < S:
        print(
            f"Warning: Only {sky_mask_array.shape[0]} sky masks generated for {S} frames; "
            "leaving the remaining frames unmasked"
        )
        # Pad with ones (full non-sky confidence) so the uncovered frames are
        # genuinely left unmasked. Padding with zeros would fail the threshold
        # test below and silently zero out those frames' confidence.
        padded = np.ones((S, H, W), dtype=sky_mask_array.dtype)
        padded[: sky_mask_array.shape[0]] = sky_mask_array
        sky_mask_array = padded
    elif sky_mask_array.shape[0] > S:
        sky_mask_array = sky_mask_array[:S]
    # Binarize: keep only pixels whose non-sky confidence clears the threshold.
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    conf = conf * sky_mask_binary
    print("Sky segmentation applied successfully")
    return conf
def download_skyseg_model(output_path: str = "skyseg.onnx") -> str:
    """
    Download sky segmentation model from HuggingFace.
    Args:
        output_path: Path to save the model
    Returns:
        Path to the downloaded model
    """
    import requests
    url = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"
    print(f"Downloading sky segmentation model from {url}...")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    # Content-Length may be absent; tqdm handles total=0 gracefully.
    total_size = int(response.headers.get('content-length', 0))
    progress = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading")
    with progress, open(output_path, 'wb') as model_file:
        for chunk in response.iter_content(chunk_size=8192):
            model_file.write(chunk)
            progress.update(len(chunk))
    print(f"Model saved to {output_path}")
    return output_path

206
lingbot_map/vis/utils.py Normal file
View File

@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Visualization utility functions for colorization and color bars.
"""
import dataclasses
from typing import Optional, Tuple
import numpy as np
import torch
import cv2
import matplotlib.cm as cm
@dataclasses.dataclass
class CameraState:
    """Camera state for rendering."""
    # Vertical field of view (radians, per np.tan usage in get_K).
    fov: float
    # Width / height aspect ratio.
    aspect: float
    # Camera-to-world transform.
    c2w: np.ndarray

    def get_K(self, img_wh: Tuple[int, int]) -> np.ndarray:
        """Get camera intrinsic matrix from FOV and image size."""
        width, height = img_wh
        # Pinhole focal length from the vertical FOV: f = (H / 2) / tan(fov / 2).
        focal = height / 2.0 / np.tan(self.fov / 2.0)
        return np.array([
            [focal, 0.0, width / 2.0],
            [0.0, focal, height / 2.0],
            [0.0, 0.0, 1.0],
        ])
def get_vertical_colorbar(
    h: int,
    vmin: float,
    vmax: float,
    cmap_name: str = "jet",
    label: Optional[str] = None,
    cbar_precision: int = 2
) -> np.ndarray:
    """
    Create a vertical colorbar image.
    Args:
        h: Height in pixels
        vmin: Minimum value
        vmax: Maximum value
        cmap_name: Colormap name
        label: Optional label for the colorbar
        cbar_precision: Decimal precision for tick labels
    Returns:
        Colorbar image as numpy array (H, W, 3), float values in [0, 1]
    """
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import matplotlib as mpl
    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)
    ax = fig.add_subplot(111)
    try:
        cmap = cm.get_cmap(cmap_name)
    except AttributeError:
        # cm.get_cmap was removed in matplotlib 3.9; use the colormap registry.
        cmap = mpl.colormaps[cmap_name]
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )
    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        # Strip the trailing ".0" that np.round leaves at precision 0.
        tick_label = [x[:-2] for x in tick_label]
    cb1.set_ticklabels(tick_label)
    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)
    canvas.draw()
    # Rasterize the figure and keep only the RGB channels as floats in [0, 1].
    s, (width, height) = canvas.print_to_buffer()
    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    im = im[:, :, :3].astype(np.float32) / 255.0
    if h != im.shape[0]:
        # Rescale to the requested height, preserving aspect ratio.
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
    return im
def colorize_np(
    x: np.ndarray,
    cmap_name: str = "jet",
    mask: Optional[np.ndarray] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False,
    cbar_precision: int = 2,
) -> np.ndarray:
    """
    Turn a grayscale image into a color image.
    Args:
        x: Input grayscale image [H, W]. NOTE: mutated in place when `mask`
            is given (masked-out pixels are overwritten with vmin).
        cmap_name: Colormap name
        mask: Optional boolean mask image [H, W]
        range: Value range for scaling [min, max], automatic if None
        append_cbar: Whether to append colorbar
        cbar_in_image: Put colorbar inside image
        cbar_precision: Colorbar tick precision
    Returns:
        Colorized image [H, W, 3]
    """
    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        # vmin/vmax from the valid region only; zeros are excluded from vmin.
        vmin = np.min(x[mask][np.nonzero(x[mask])])
        vmax = np.max(x[mask])
        x[np.logical_not(mask)] = vmin
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6
    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)
    try:
        cmap = cm.get_cmap(cmap_name)
    except AttributeError:
        # cm.get_cmap was removed in matplotlib 3.9; use the colormap registry.
        import matplotlib
        cmap = matplotlib.colormaps[cmap_name]
    x_new = cmap(x)[:, :, :3]
    if mask is not None:
        # Blend masked-out pixels toward white.
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)
    if append_cbar:
        # Only render the colorbar (a full matplotlib figure) when requested;
        # previously it was built unconditionally and usually thrown away.
        cbar = get_vertical_colorbar(
            h=x.shape[0],
            vmin=vmin,
            vmax=vmax,
            cmap_name=cmap_name,
            cbar_precision=cbar_precision,
        )
        if cbar_in_image:
            x_new[:, -cbar.shape[1]:, :] = cbar
        else:
            x_new = np.concatenate(
                (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
            )
    return x_new
def colorize(
    x: torch.Tensor,
    cmap_name: str = "jet",
    mask: Optional[torch.Tensor] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False
) -> torch.Tensor:
    """
    Turn a grayscale image into a color image (PyTorch tensor version).
    Args:
        x: Grayscale image tensor [H, W] or [B, H, W]
        cmap_name: Colormap name
        mask: Optional mask tensor [H, W] or [B, H, W]
        range: Value range for scaling
        append_cbar: Whether to append colorbar
        cbar_in_image: Put colorbar inside image
    Returns:
        Colorized tensor
    """
    device = x.device
    x = x.cpu().numpy()
    if mask is not None:
        mask = mask.cpu().numpy() > 0.99
    kernel = np.ones((3, 3), np.uint8)
    # Normalize to a batched (B, H, W) layout.
    if x.ndim == 2:
        x = x[None]
        if mask is not None:
            mask = mask[None]
    out = []
    for i, x_ in enumerate(x):
        mask_ = None
        if mask is not None:
            # Erode each frame's own 2D mask. The previous version eroded the
            # whole (B, H, W) stack in place (accumulating across iterations)
            # and passed the 3D mask to colorize_np, which expects [H, W].
            mask_ = cv2.erode(mask[i].astype(np.uint8), kernel, iterations=1).astype(bool)
        x_ = colorize_np(x_, cmap_name, mask_, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
    # squeeze(0) restores [H, W, 3] output for single-image input.
    out = torch.stack(out).squeeze(0)
    return out

View File

@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Quick visualization wrapper for GCT predictions using Viser.
"""
import time
import threading
from typing import List, Optional
import numpy as np
import viser
import viser.transforms as tf
from tqdm.auto import tqdm
from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
from lingbot_map.vis.sky_segmentation import apply_sky_segmentation
def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: Optional[str] = None,
):
    """
    Visualize predicted 3D points and camera poses with viser.
    This is a simplified wrapper for quick visualization without the full
    PointCloudViewer controls.
    Args:
        pred_dict: Dictionary containing predictions with keys:
            - images: (S, 3, H, W) - Input images
            - world_points: (S, H, W, 3)
            - world_points_conf: (S, H, W)
            - depth: (S, H, W, 1)
            - depth_conf: (S, H, W)
            - extrinsic: (S, 3, 4)
            - intrinsic: (S, 3, 3)
        port: Port number for the viser server
        init_conf_threshold: Initial percentage of low-confidence points to filter out
        use_point_map: Whether to visualize world_points or use depth-based points
        background_mode: Whether to run the server in background thread
        mask_sky: Whether to apply sky segmentation to filter out sky points
        image_folder: Path to the folder containing input images (for sky segmentation)
    Returns:
        viser.ViserServer: The viser server instance
    """
    print(f"Starting viser server on port {port}")
    server = viser.ViserServer(host="0.0.0.0", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
    # Unpack prediction dict
    images = pred_dict["images"]  # (S, 3, H, W)
    world_points_map = pred_dict["world_points"]  # (S, H, W, 3)
    conf_map = pred_dict["world_points_conf"]  # (S, H, W)
    depth_map = pred_dict["depth"]  # (S, H, W, 1)
    depth_conf = pred_dict["depth_conf"]  # (S, H, W)
    extrinsics_cam = pred_dict["extrinsic"]  # (S, 3, 4)
    intrinsics_cam = pred_dict["intrinsic"]  # (S, 3, 3)
    # Compute world points from depth if not using the precomputed point map
    if not use_point_map:
        world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
        conf = depth_conf
    else:
        world_points = world_points_map
        conf = conf_map
    # Apply sky segmentation if enabled
    if mask_sky and image_folder is not None:
        conf = apply_sky_segmentation(conf, image_folder)
    # Convert images from (S, 3, H, W) to (S, H, W, 3)
    colors = images.transpose(0, 2, 3, 1)  # now (S, H, W, 3)
    shape = world_points.shape
    S: int = shape[0]
    H: int = shape[1]
    W: int = shape[2]
    # Flatten
    points = world_points.reshape(-1, 3)
    colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
    conf_flat = conf.reshape(-1)
    # Random sample points if too many
    indices = None
    if points.shape[0] > 6000000:
        print(f"Too many points ({points.shape[0]}), randomly sampling 6M points")
        indices = np.random.choice(points.shape[0], size=6000000, replace=False)
        points = points[indices]
        colors_flat = colors_flat[indices]
        conf_flat = conf_flat[indices]
    # Camera-to-world poses via closed-form SE(3) inverse of the extrinsics.
    cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
    cam_to_world = cam_to_world_mat[:, :3, :]
    # Compute scene center and recenter
    scene_center = np.mean(points, axis=0)
    points_centered = points - scene_center
    # Shift the camera translations by the same offset (mutates cam_to_world).
    cam_to_world[..., -1] -= scene_center
    # Store frame indices for filtering
    frame_indices = (
        np.repeat(np.arange(S), H * W)[indices]
        if indices is not None
        else np.repeat(np.arange(S), H * W)
    )
    # Build the viser GUI
    gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True)
    gui_points_conf = server.gui.add_slider(
        "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold
    )
    gui_frame_selector = server.gui.add_dropdown(
        "Show Points from Frames",
        options=["All"] + [str(i) for i in range(S)],
        initial_value="All"
    )
    # Create the main point cloud
    # NOTE(review): the initial mask uses an absolute floor of 0.1 while
    # update_point_cloud() below uses 1e-5 -- confirm which floor is intended.
    init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
    init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1)
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points_centered[init_conf_mask],
        colors=colors_flat[init_conf_mask],
        point_size=0.0005,
        point_shape="circle",
    )
    # Handles kept so frames/frustums can be removed or toggled later.
    frames: List[viser.FrameHandle] = []
    frustums: List[viser.CameraFrustumHandle] = []
    def visualize_frames(extrinsics, images_: np.ndarray) -> None:
        """Add camera frames and frustums to the scene."""
        # Clear any previously added handles before re-adding.
        for f in frames:
            f.remove()
        frames.clear()
        for fr in frustums:
            fr.remove()
        frustums.clear()
        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            # Clicking a frustum moves every connected client to that camera pose.
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position
        for img_id in tqdm(range(S)):
            cam2world_3x4 = extrinsics[img_id]
            T_world_camera = tf.SE3.from_matrix(cam2world_3x4)
            frame_axis = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame_axis)
            img = images_[img_id]
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            h, w = img.shape[:2]
            # Assume a focal length of 1.1x the image height to derive the FOV.
            fy = 1.1 * h
            fov = 2 * np.arctan2(h / 2, fy)
            frustum_cam = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum",
                fov=fov,
                aspect=w / h,
                scale=0.05,
                image=img,
                line_width=1.0
            )
            frustums.append(frustum_cam)
            attach_callback(frustum_cam, frame_axis)
    def update_point_cloud() -> None:
        """Update point cloud based on current GUI selections."""
        current_percentage = gui_points_conf.value
        threshold_val = np.percentile(conf_flat, current_percentage)
        print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")
        conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)
        if gui_frame_selector.value == "All":
            frame_mask = np.ones_like(conf_mask, dtype=bool)
        else:
            selected_idx = int(gui_frame_selector.value)
            frame_mask = frame_indices == selected_idx
        combined_mask = conf_mask & frame_mask
        point_cloud.points = points_centered[combined_mask]
        point_cloud.colors = colors_flat[combined_mask]
    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()
    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()
    @gui_show_frames.on_update
    def _(_) -> None:
        # Toggle visibility of all camera frames and frustums together.
        for f in frames:
            f.visible = gui_show_frames.value
        for fr in frustums:
            fr.visible = gui_show_frames.value
    # Add camera frames
    import torch
    if torch.is_tensor(cam_to_world):
        cam_to_world_np = cam_to_world.cpu().numpy()
    else:
        cam_to_world_np = cam_to_world
    visualize_frames(cam_to_world_np, images)
    print("Starting viser server...")
    if background_mode:
        # NOTE(review): this daemon thread only sleeps; it does not keep the
        # process alive -- presumably the viser server runs its own threads.
        def server_loop():
            while True:
                time.sleep(0.001)
        thread = threading.Thread(target=server_loop, daemon=True)
        thread.start()
    else:
        # Foreground mode blocks forever; `return server` below is only
        # reached on the background_mode path.
        while True:
            time.sleep(0.01)
    return server

27
pyproject.toml Normal file
View File

@@ -0,0 +1,27 @@
[project]
name = "lingbot-map"
version = "0.1.0"
description = "LingBot-Map: Geometric Context Transformer for Streaming 3D Reconstruction"
requires-python = ">= 3.10"
dependencies = [
"Pillow",
"huggingface_hub",
"einops",
"safetensors",
"opencv-python",
"tqdm",
"scipy",
"torchvision",
]
[project.optional-dependencies]
vis = ["viser>=0.2.23", "trimesh", "matplotlib", "onnxruntime", "requests"]
demo = ["lingbot-map[vis]"]
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["."]
include = ["lingbot_map*"]