first commit
This commit is contained in:
12
.gitignore
vendored
Normal file
12
.gitignore
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
*.so
|
||||
.eggs/
|
||||
demo_render/
|
||||
CLAUDE.md
|
||||
.claude/
|
||||
.agents/
|
||||
399
LICENSE.txt
Normal file
399
LICENSE.txt
Normal file
@@ -0,0 +1,399 @@
|
||||
Attribution-NonCommercial 4.0 International
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
||||
does not provide legal services or legal advice. Distribution of
|
||||
Creative Commons public licenses does not create a lawyer-client or
|
||||
other relationship. Creative Commons makes its licenses and related
|
||||
information available on an "as-is" basis. Creative Commons gives no
|
||||
warranties regarding its licenses, any material licensed under their
|
||||
terms and conditions, or any related information. Creative Commons
|
||||
disclaims all liability for damages resulting from their use to the
|
||||
fullest extent possible.
|
||||
|
||||
Using Creative Commons Public Licenses
|
||||
|
||||
Creative Commons public licenses provide a standard set of terms and
|
||||
conditions that creators and other rights holders may use to share
|
||||
original works of authorship and other material subject to copyright
|
||||
and certain other rights specified in the public license below. The
|
||||
following considerations are for informational purposes only, are not
|
||||
exhaustive, and do not form part of our licenses.
|
||||
|
||||
Considerations for licensors: Our public licenses are
|
||||
intended for use by those authorized to give the public
|
||||
permission to use material in ways otherwise restricted by
|
||||
copyright and certain other rights. Our licenses are
|
||||
irrevocable. Licensors should read and understand the terms
|
||||
and conditions of the license they choose before applying it.
|
||||
Licensors should also secure all rights necessary before
|
||||
applying our licenses so that the public can reuse the
|
||||
material as expected. Licensors should clearly mark any
|
||||
material not subject to the license. This includes other CC-
|
||||
licensed material, or material used under an exception or
|
||||
limitation to copyright. More considerations for licensors:
|
||||
wiki.creativecommons.org/Considerations_for_licensors
|
||||
|
||||
Considerations for the public: By using one of our public
|
||||
licenses, a licensor grants the public permission to use the
|
||||
licensed material under specified terms and conditions. If
|
||||
the licensor's permission is not necessary for any reason--for
|
||||
example, because of any applicable exception or limitation to
|
||||
copyright--then that use is not regulated by the license. Our
|
||||
licenses grant only permissions under copyright and certain
|
||||
other rights that a licensor has authority to grant. Use of
|
||||
the licensed material may still be restricted for other
|
||||
reasons, including because others have copyright or other
|
||||
rights in the material. A licensor may make special requests,
|
||||
such as asking that all changes be marked or described.
|
||||
Although not required by our licenses, you are encouraged to
|
||||
respect those requests where reasonable. More_considerations
|
||||
for the public:
|
||||
wiki.creativecommons.org/Considerations_for_licensees
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons Attribution-NonCommercial 4.0 International Public
|
||||
License
|
||||
|
||||
By exercising the Licensed Rights (defined below), You accept and agree
|
||||
to be bound by the terms and conditions of this Creative Commons
|
||||
Attribution-NonCommercial 4.0 International Public License ("Public
|
||||
License"). To the extent this Public License may be interpreted as a
|
||||
contract, You are granted the Licensed Rights in consideration of Your
|
||||
acceptance of these terms and conditions, and the Licensor grants You
|
||||
such rights in consideration of benefits the Licensor receives from
|
||||
making the Licensed Material available under these terms and
|
||||
conditions.
|
||||
|
||||
Section 1 -- Definitions.
|
||||
|
||||
a. Adapted Material means material subject to Copyright and Similar
|
||||
Rights that is derived from or based upon the Licensed Material
|
||||
and in which the Licensed Material is translated, altered,
|
||||
arranged, transformed, or otherwise modified in a manner requiring
|
||||
permission under the Copyright and Similar Rights held by the
|
||||
Licensor. For purposes of this Public License, where the Licensed
|
||||
Material is a musical work, performance, or sound recording,
|
||||
Adapted Material is always produced where the Licensed Material is
|
||||
synched in timed relation with a moving image.
|
||||
|
||||
b. Adapter's License means the license You apply to Your Copyright
|
||||
and Similar Rights in Your contributions to Adapted Material in
|
||||
accordance with the terms and conditions of this Public License.
|
||||
|
||||
c. Copyright and Similar Rights means copyright and/or similar rights
|
||||
closely related to copyright including, without limitation,
|
||||
performance, broadcast, sound recording, and Sui Generis Database
|
||||
Rights, without regard to how the rights are labeled or
|
||||
categorized. For purposes of this Public License, the rights
|
||||
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
||||
Rights.
|
||||
d. Effective Technological Measures means those measures that, in the
|
||||
absence of proper authority, may not be circumvented under laws
|
||||
fulfilling obligations under Article 11 of the WIPO Copyright
|
||||
Treaty adopted on December 20, 1996, and/or similar international
|
||||
agreements.
|
||||
|
||||
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
||||
any other exception or limitation to Copyright and Similar Rights
|
||||
that applies to Your use of the Licensed Material.
|
||||
|
||||
f. Licensed Material means the artistic or literary work, database,
|
||||
or other material to which the Licensor applied this Public
|
||||
License.
|
||||
|
||||
g. Licensed Rights means the rights granted to You subject to the
|
||||
terms and conditions of this Public License, which are limited to
|
||||
all Copyright and Similar Rights that apply to Your use of the
|
||||
Licensed Material and that the Licensor has authority to license.
|
||||
|
||||
h. Licensor means the individual(s) or entity(ies) granting rights
|
||||
under this Public License.
|
||||
|
||||
i. NonCommercial means not primarily intended for or directed towards
|
||||
commercial advantage or monetary compensation. For purposes of
|
||||
this Public License, the exchange of the Licensed Material for
|
||||
other material subject to Copyright and Similar Rights by digital
|
||||
file-sharing or similar means is NonCommercial provided there is
|
||||
no payment of monetary compensation in connection with the
|
||||
exchange.
|
||||
|
||||
j. Share means to provide material to the public by any means or
|
||||
process that requires permission under the Licensed Rights, such
|
||||
as reproduction, public display, public performance, distribution,
|
||||
dissemination, communication, or importation, and to make material
|
||||
available to the public including in ways that members of the
|
||||
public may access the material from a place and at a time
|
||||
individually chosen by them.
|
||||
|
||||
k. Sui Generis Database Rights means rights other than copyright
|
||||
resulting from Directive 96/9/EC of the European Parliament and of
|
||||
the Council of 11 March 1996 on the legal protection of databases,
|
||||
as amended and/or succeeded, as well as other essentially
|
||||
equivalent rights anywhere in the world.
|
||||
|
||||
l. You means the individual or entity exercising the Licensed Rights
|
||||
under this Public License. Your has a corresponding meaning.
|
||||
|
||||
Section 2 -- Scope.
|
||||
|
||||
a. License grant.
|
||||
|
||||
1. Subject to the terms and conditions of this Public License,
|
||||
the Licensor hereby grants You a worldwide, royalty-free,
|
||||
non-sublicensable, non-exclusive, irrevocable license to
|
||||
exercise the Licensed Rights in the Licensed Material to:
|
||||
|
||||
a. reproduce and Share the Licensed Material, in whole or
|
||||
in part, for NonCommercial purposes only; and
|
||||
|
||||
b. produce, reproduce, and Share Adapted Material for
|
||||
NonCommercial purposes only.
|
||||
|
||||
2. Exceptions and Limitations. For the avoidance of doubt, where
|
||||
Exceptions and Limitations apply to Your use, this Public
|
||||
License does not apply, and You do not need to comply with
|
||||
its terms and conditions.
|
||||
|
||||
3. Term. The term of this Public License is specified in Section
|
||||
6(a).
|
||||
|
||||
4. Media and formats; technical modifications allowed. The
|
||||
Licensor authorizes You to exercise the Licensed Rights in
|
||||
all media and formats whether now known or hereafter created,
|
||||
and to make technical modifications necessary to do so. The
|
||||
Licensor waives and/or agrees not to assert any right or
|
||||
authority to forbid You from making technical modifications
|
||||
necessary to exercise the Licensed Rights, including
|
||||
technical modifications necessary to circumvent Effective
|
||||
Technological Measures. For purposes of this Public License,
|
||||
simply making modifications authorized by this Section 2(a)
|
||||
(4) never produces Adapted Material.
|
||||
|
||||
5. Downstream recipients.
|
||||
|
||||
a. Offer from the Licensor -- Licensed Material. Every
|
||||
recipient of the Licensed Material automatically
|
||||
receives an offer from the Licensor to exercise the
|
||||
Licensed Rights under the terms and conditions of this
|
||||
Public License.
|
||||
|
||||
b. No downstream restrictions. You may not offer or impose
|
||||
any additional or different terms or conditions on, or
|
||||
apply any Effective Technological Measures to, the
|
||||
Licensed Material if doing so restricts exercise of the
|
||||
Licensed Rights by any recipient of the Licensed
|
||||
Material.
|
||||
|
||||
6. No endorsement. Nothing in this Public License constitutes or
|
||||
may be construed as permission to assert or imply that You
|
||||
are, or that Your use of the Licensed Material is, connected
|
||||
with, or sponsored, endorsed, or granted official status by,
|
||||
the Licensor or others designated to receive attribution as
|
||||
provided in Section 3(a)(1)(A)(i).
|
||||
|
||||
b. Other rights.
|
||||
|
||||
1. Moral rights, such as the right of integrity, are not
|
||||
licensed under this Public License, nor are publicity,
|
||||
privacy, and/or other similar personality rights; however, to
|
||||
the extent possible, the Licensor waives and/or agrees not to
|
||||
assert any such rights held by the Licensor to the limited
|
||||
extent necessary to allow You to exercise the Licensed
|
||||
Rights, but not otherwise.
|
||||
|
||||
2. Patent and trademark rights are not licensed under this
|
||||
Public License.
|
||||
|
||||
3. To the extent possible, the Licensor waives any right to
|
||||
collect royalties from You for the exercise of the Licensed
|
||||
Rights, whether directly or through a collecting society
|
||||
under any voluntary or waivable statutory or compulsory
|
||||
licensing scheme. In all other cases the Licensor expressly
|
||||
reserves any right to collect such royalties, including when
|
||||
the Licensed Material is used other than for NonCommercial
|
||||
purposes.
|
||||
|
||||
Section 3 -- License Conditions.
|
||||
|
||||
Your exercise of the Licensed Rights is expressly made subject to the
|
||||
following conditions.
|
||||
|
||||
a. Attribution.
|
||||
|
||||
1. If You Share the Licensed Material (including in modified
|
||||
form), You must:
|
||||
|
||||
a. retain the following if it is supplied by the Licensor
|
||||
with the Licensed Material:
|
||||
|
||||
i. identification of the creator(s) of the Licensed
|
||||
Material and any others designated to receive
|
||||
attribution, in any reasonable manner requested by
|
||||
the Licensor (including by pseudonym if
|
||||
designated);
|
||||
|
||||
ii. a copyright notice;
|
||||
|
||||
iii. a notice that refers to this Public License;
|
||||
|
||||
iv. a notice that refers to the disclaimer of
|
||||
warranties;
|
||||
|
||||
v. a URI or hyperlink to the Licensed Material to the
|
||||
extent reasonably practicable;
|
||||
|
||||
b. indicate if You modified the Licensed Material and
|
||||
retain an indication of any previous modifications; and
|
||||
|
||||
c. indicate the Licensed Material is licensed under this
|
||||
Public License, and include the text of, or the URI or
|
||||
hyperlink to, this Public License.
|
||||
|
||||
2. You may satisfy the conditions in Section 3(a)(1) in any
|
||||
reasonable manner based on the medium, means, and context in
|
||||
which You Share the Licensed Material. For example, it may be
|
||||
reasonable to satisfy the conditions by providing a URI or
|
||||
hyperlink to a resource that includes the required
|
||||
information.
|
||||
|
||||
3. If requested by the Licensor, You must remove any of the
|
||||
information required by Section 3(a)(1)(A) to the extent
|
||||
reasonably practicable.
|
||||
|
||||
4. If You Share Adapted Material You produce, the Adapter's
|
||||
License You apply must not prevent recipients of the Adapted
|
||||
Material from complying with this Public License.
|
||||
|
||||
Section 4 -- Sui Generis Database Rights.
|
||||
|
||||
Where the Licensed Rights include Sui Generis Database Rights that
|
||||
apply to Your use of the Licensed Material:
|
||||
|
||||
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
||||
to extract, reuse, reproduce, and Share all or a substantial
|
||||
portion of the contents of the database for NonCommercial purposes
|
||||
only;
|
||||
|
||||
b. if You include all or a substantial portion of the database
|
||||
contents in a database in which You have Sui Generis Database
|
||||
Rights, then the database in which You have Sui Generis Database
|
||||
Rights (but not its individual contents) is Adapted Material; and
|
||||
|
||||
c. You must comply with the conditions in Section 3(a) if You Share
|
||||
all or a substantial portion of the contents of the database.
|
||||
|
||||
For the avoidance of doubt, this Section 4 supplements and does not
|
||||
replace Your obligations under this Public License where the Licensed
|
||||
Rights include other Copyright and Similar Rights.
|
||||
|
||||
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
||||
|
||||
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
||||
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
||||
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
||||
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
||||
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
||||
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
||||
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
||||
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
||||
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
||||
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
||||
|
||||
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
||||
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
||||
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
||||
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
||||
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
||||
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
||||
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
||||
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
||||
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
||||
|
||||
c. The disclaimer of warranties and limitation of liability provided
|
||||
above shall be interpreted in a manner that, to the extent
|
||||
possible, most closely approximates an absolute disclaimer and
|
||||
waiver of all liability.
|
||||
|
||||
Section 6 -- Term and Termination.
|
||||
|
||||
a. This Public License applies for the term of the Copyright and
|
||||
Similar Rights licensed here. However, if You fail to comply with
|
||||
this Public License, then Your rights under this Public License
|
||||
terminate automatically.
|
||||
|
||||
b. Where Your right to use the Licensed Material has terminated under
|
||||
Section 6(a), it reinstates:
|
||||
|
||||
1. automatically as of the date the violation is cured, provided
|
||||
it is cured within 30 days of Your discovery of the
|
||||
violation; or
|
||||
|
||||
2. upon express reinstatement by the Licensor.
|
||||
|
||||
For the avoidance of doubt, this Section 6(b) does not affect any
|
||||
right the Licensor may have to seek remedies for Your violations
|
||||
of this Public License.
|
||||
|
||||
c. For the avoidance of doubt, the Licensor may also offer the
|
||||
Licensed Material under separate terms or conditions or stop
|
||||
distributing the Licensed Material at any time; however, doing so
|
||||
will not terminate this Public License.
|
||||
|
||||
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
||||
License.
|
||||
|
||||
Section 7 -- Other Terms and Conditions.
|
||||
|
||||
a. The Licensor shall not be bound by any additional or different
|
||||
terms or conditions communicated by You unless expressly agreed.
|
||||
|
||||
b. Any arrangements, understandings, or agreements regarding the
|
||||
Licensed Material not stated herein are separate from and
|
||||
independent of the terms and conditions of this Public License.
|
||||
|
||||
Section 8 -- Interpretation.
|
||||
|
||||
a. For the avoidance of doubt, this Public License does not, and
|
||||
shall not be interpreted to, reduce, limit, restrict, or impose
|
||||
conditions on any use of the Licensed Material that could lawfully
|
||||
be made without permission under this Public License.
|
||||
|
||||
b. To the extent possible, if any provision of this Public License is
|
||||
deemed unenforceable, it shall be automatically reformed to the
|
||||
minimum extent necessary to make it enforceable. If the provision
|
||||
cannot be reformed, it shall be severed from this Public License
|
||||
without affecting the enforceability of the remaining terms and
|
||||
conditions.
|
||||
|
||||
c. No term or condition of this Public License will be waived and no
|
||||
failure to comply consented to unless expressly agreed to by the
|
||||
Licensor.
|
||||
|
||||
d. Nothing in this Public License constitutes or may be interpreted
|
||||
as a limitation upon, or waiver of, any privileges and immunities
|
||||
that apply to the Licensor or You, including from the legal
|
||||
processes of any jurisdiction or authority.
|
||||
|
||||
=======================================================================
|
||||
|
||||
Creative Commons is not a party to its public
|
||||
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
||||
its public licenses to material it publishes and in those instances
|
||||
will be considered the “Licensor.” The text of the Creative Commons
|
||||
public licenses is dedicated to the public domain under the CC0 Public
|
||||
Domain Dedication. Except for the limited purpose of indicating that
|
||||
material is shared under a Creative Commons public license or as
|
||||
otherwise permitted by the Creative Commons policies published at
|
||||
creativecommons.org/policies, Creative Commons does not authorize the
|
||||
use of the trademark "Creative Commons" or any other trademark or logo
|
||||
of Creative Commons without its prior written consent including,
|
||||
without limitation, in connection with any unauthorized modifications
|
||||
to any of its public licenses or any other arrangements,
|
||||
understandings, or agreements concerning use of licensed material. For
|
||||
the avoidance of doubt, this paragraph does not form part of the
|
||||
public licenses.
|
||||
|
||||
Creative Commons may be contacted at creativecommons.org.
|
||||
141
README.md
Normal file
141
README.md
Normal file
@@ -0,0 +1,141 @@
|
||||
<h1 align="center">LingBot-Map: Geometric Context Transformer for Streaming 3D Reconstruction</h1>
|
||||
|
||||
<p align="center">
|
||||
<a href="lingbot-map_paper.pdf"><img src="https://img.shields.io/static/v1?label=Paper&message=PDF&color=red&logo=arxiv"></a>
|
||||
<a href="https://technology.robbyant.com/lingbot-map"><img src="https://img.shields.io/badge/Project-Website-blue"></a>
|
||||
<a href="https://huggingface.co/robbyant/lingbot-map"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Model&message=HuggingFace&color=orange"></a>
|
||||
<a href="LICENSE.txt"><img src="https://img.shields.io/badge/License-CC--BY--NC--4.0-green"></a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/teaser.png" width="100%">
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<video src="https://gw.alipayobjects.com/v/huamei_vaouhm/afts/video/q0sdTr9Mm6IAAAAAmyAAAAgADglFAQJr" width="100%" autoplay loop muted playsinline></video>
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
# Quick Start
|
||||
|
||||
## Installation
|
||||
|
||||
**1. Create conda environment**
|
||||
|
||||
```bash
|
||||
conda create -n lingbot-map python=3.10 -y
|
||||
conda activate lingbot-map
|
||||
```
|
||||
|
||||
**2. Install PyTorch (CUDA 12.8)**
|
||||
|
||||
```bash
|
||||
pip install torch==2.9.1 torchvision==0.24.1 --index-url https://download.pytorch.org/whl/cu128
|
||||
```
|
||||
|
||||
> For other CUDA versions, see [PyTorch Get Started](https://pytorch.org/get-started/locally/).
|
||||
|
||||
**3. Install lingbot-map**
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
**4. Install FlashInfer (recommended)**
|
||||
|
||||
FlashInfer provides paged KV cache attention for efficient streaming inference:
|
||||
|
||||
```bash
|
||||
# CUDA 12.8 + PyTorch 2.9
|
||||
pip install flashinfer-python -i https://flashinfer.ai/whl/cu128/torch2.9/
|
||||
```
|
||||
|
||||
> For other CUDA/PyTorch combinations, see [FlashInfer installation](https://docs.flashinfer.ai/installation.html).
|
||||
> If FlashInfer is not installed, the model falls back to SDPA (PyTorch native attention) via `--use_sdpa`.
|
||||
|
||||
**5. Visualization dependencies (optional)**
|
||||
|
||||
```bash
|
||||
pip install -e ".[vis]"
|
||||
```
|
||||
|
||||
# Demo
|
||||
|
||||
## Streaming Inference from Images
|
||||
|
||||
```bash
|
||||
python demo.py --model_path /path/to/checkpoint.pt \
|
||||
--image_folder /path/to/images/
|
||||
```
|
||||
|
||||
## Streaming Inference from Video
|
||||
|
||||
```bash
|
||||
python demo.py --model_path /path/to/checkpoint.pt \
|
||||
--video_path video.mp4 --fps 10
|
||||
```
|
||||
|
||||
## Streaming with Keyframe Interval
|
||||
|
||||
Use `--keyframe_interval` to reduce KV cache memory by only keeping every N-th frame as a keyframe. Non-keyframe frames still produce predictions but are not stored in the cache. This is useful for long sequences
that exceed 320 frames.
|
||||
|
||||
```bash
|
||||
python demo.py --model_path /path/to/checkpoint.pt \
|
||||
--image_folder /path/to/images/ --keyframe_interval 6
|
||||
```
|
||||
|
||||
## Windowed Inference (for long sequences, >3000 frames)
|
||||
```bash
|
||||
python demo.py --model_path /path/to/checkpoint.pt \
|
||||
--video_path video.mp4 --fps 10 \
|
||||
--mode windowed --window_size 64
|
||||
```
|
||||
|
||||
|
||||
## With Sky Masking
|
||||
|
||||
```bash
|
||||
python demo.py --model_path /path/to/checkpoint.pt \
|
||||
--image_folder /path/to/images/ --mask_sky
|
||||
```
|
||||
|
||||
## Without FlashInfer (SDPA fallback)
|
||||
|
||||
```bash
|
||||
python demo.py --model_path /path/to/checkpoint.pt \
|
||||
--image_folder /path/to/images/ --use_sdpa
|
||||
```
|
||||
|
||||
# Model Download
|
||||
|
||||
| Model Name | Huggingface Repository | Description |
|
||||
| :--- | :--- | :--- |
|
||||
| lingbot-map | [robbyant/lingbot-map](https://huggingface.co/robbyant/lingbot-map) | Base model checkpoint (4.63 GB) |
|
||||
|
||||
|
||||
# License
|
||||
|
||||
This project is released under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license. See [LICENSE](LICENSE.txt) for details.
|
||||
|
||||
# Citation
|
||||
|
||||
```bibtex
|
||||
@article{lingbot-map2026,
|
||||
title={},
|
||||
author={},
|
||||
journal={arXiv preprint arXiv:},
|
||||
year={2026}
|
||||
}
|
||||
```
|
||||
|
||||
# Acknowledgments
|
||||
|
||||
This work builds upon several excellent open-source projects:
|
||||
|
||||
- [VGGT](https://github.com/facebookresearch/vggt)
|
||||
- [DINOv2](https://github.com/facebookresearch/dinov2)
|
||||
- [Flashinfer](https://github.com/flashinfer-ai/flashinfer)
|
||||
|
||||
---
|
||||
BIN
assets/teaser.png
Normal file
BIN
assets/teaser.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.9 MiB |
346
demo.py
Normal file
346
demo.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""LingBot-MAP demo: streaming 3D reconstruction from images or video.
|
||||
|
||||
Usage:
|
||||
# Streaming inference (frame-by-frame with KV cache)
|
||||
python examples/demo.py --model_path /path/to/checkpoint.pt \
|
||||
--image_folder /path/to/images/
|
||||
|
||||
# Streaming inference with keyframe KV caching
|
||||
python examples/demo.py --model_path /path/to/checkpoint.pt \
|
||||
--image_folder /path/to/images/ --mode streaming --keyframe_interval 6
|
||||
|
||||
# Windowed inference (for very long sequences, >500 frames)
|
||||
python examples/demo.py --model_path /path/to/checkpoint.pt \
|
||||
--video_path video.mp4 --fps 10 --mode windowed --window_size 64
|
||||
|
||||
# From video with custom FPS sampling
|
||||
python examples/demo.py --model_path /path/to/checkpoint.pt \
|
||||
--video_path video.mp4 --fps 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
|
||||
from lingbot_map.utils.geometry import closed_form_inverse_se3_general
|
||||
from lingbot_map.utils.load_fn import load_and_preprocess_images
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image loading
|
||||
# =============================================================================
|
||||
|
||||
def load_images(image_folder=None, video_path=None, fps=10, image_ext=".jpg,.png",
                first_k=None, stride=1, image_size=518, patch_size=14, num_workers=8):
    """Load frames from an image folder or a video and preprocess into a tensor.

    Exactly one of *image_folder* / *video_path* must be given. Video input is
    sampled at roughly *fps* frames per second and the sampled frames are saved
    as JPEGs next to the video before preprocessing.

    Args:
        image_folder: Directory containing frames (used when video_path is None).
        video_path: Path to a video file; takes precedence over image_folder.
        fps: Target sampling rate for video frames.
        image_ext: Comma-separated extensions to match in image_folder.
        first_k: Keep only the first k frames (after striding), if positive.
        stride: Keep every stride-th frame.
        image_size: Target image size passed to the preprocessor.
        patch_size: Patch size passed to the preprocessor.
        num_workers: Accepted for interface compatibility; not used here.

    Returns:
        (images, paths): preprocessed image tensor and the list of frame paths.

    Raises:
        ValueError: If neither image_folder nor video_path is provided.
        FileNotFoundError: If no frames are found after filtering.
    """
    if video_path is not None:
        paths = _extract_video_frames(video_path, fps)
    elif image_folder is not None:
        paths = _collect_image_paths(image_folder, image_ext)
    else:
        raise ValueError("Provide either image_folder or video_path.")

    if stride > 1:
        paths = paths[::stride]
    if first_k is not None and first_k > 0:
        paths = paths[:first_k]
    if not paths:
        raise FileNotFoundError("No input frames found; check the input path and extensions.")

    print(f"Loading {len(paths)} images...")
    images = load_and_preprocess_images(
        paths,
        mode="crop",
        image_size=image_size,
        patch_size=patch_size,
    )
    h, w = images.shape[-2:]
    print(f"Preprocessed images to {w}x{h} using canonical crop mode")
    return images, paths


def _extract_video_frames(video_path, fps):
    """Sample a video at ~fps and write the frames as JPEGs; return their paths."""
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    out_dir = os.path.join(os.path.dirname(video_path), f"{video_name}_frames")
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    saved = []
    try:
        # Fall back to 30 fps when the container does not report a rate.
        src_fps = cap.get(cv2.CAP_PROP_FPS) or 30
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        interval = max(1, round(src_fps / fps))
        idx = 0
        pbar = tqdm(total=total_frames, desc="Extracting frames", unit="frame")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if idx % interval == 0:
                path = os.path.join(out_dir, f"{len(saved):06d}.jpg")
                cv2.imwrite(path, frame)
                saved.append(path)
            idx += 1
            pbar.update(1)
        pbar.close()
    finally:
        # Release the capture even if decoding/writing raised.
        cap.release()
    print(f"Extracted {len(saved)} frames from video ({total_frames} total, interval={interval})")
    return saved


def _collect_image_paths(image_folder, image_ext):
    """Glob image files matching the given comma-separated extensions, sorted."""
    paths = []
    for ext in image_ext.split(","):
        paths.extend(glob.glob(os.path.join(image_folder, f"*{ext}")))
    return sorted(paths)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model loading
|
||||
# =============================================================================
|
||||
|
||||
def load_model(args, device):
    """Build a GCTStream model, optionally restore a checkpoint, and return it on *device* in eval mode.

    The windowed variant is imported when ``args.mode == "windowed"``;
    otherwise the streaming variant is used. Checkpoint loading is
    non-strict, so partial state dicts are tolerated (counts are printed).
    """
    use_windowed = getattr(args, "mode", "streaming") == "windowed"
    if use_windowed:
        from lingbot_map.models.gct_stream_window import GCTStream
    else:
        from lingbot_map.models.gct_stream import GCTStream

    print("Building model...")
    model_kwargs = dict(
        img_size=args.image_size,
        patch_size=args.patch_size,
        enable_3d_rope=args.enable_3d_rope,
        max_frame_num=args.max_frame_num,
        kv_cache_sliding_window=args.kv_cache_sliding_window,
        kv_cache_scale_frames=args.kv_cache_scale_frames,
        kv_cache_cross_frame_special=True,
        kv_cache_include_scale_frames=True,
        use_sdpa=args.use_sdpa,
    )
    model = GCTStream(**model_kwargs)

    if args.model_path:
        print(f"Loading checkpoint: {args.model_path}")
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted sources.
        checkpoint = torch.load(args.model_path, map_location=device, weights_only=False)
        # Checkpoints may wrap the weights under a "model" key or be a bare state dict.
        state_dict = checkpoint.get("model", checkpoint)
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if missing:
            print(f" Missing keys: {len(missing)}")
        if unexpected:
            print(f" Unexpected keys: {len(unexpected)}")
        print(" Checkpoint loaded.")

    return model.to(device).eval()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Post-processing
|
||||
# =============================================================================
|
||||
|
||||
_BATCHED_NDIMS = {
|
||||
"pose_enc": 3,
|
||||
"depth": 5,
|
||||
"depth_conf": 4,
|
||||
"world_points": 5,
|
||||
"world_points_conf": 4,
|
||||
"extrinsic": 4,
|
||||
"intrinsic": 4,
|
||||
"chunk_sim3_scales": 2,
|
||||
"chunk_sim3_poses": 4,
|
||||
"chunk_se3_poses": 4,
|
||||
"images": 5,
|
||||
}
|
||||
|
||||
|
||||
def _squeeze_single_batch(key, value):
|
||||
"""Drop the leading batch dimension for single-sequence demo outputs."""
|
||||
batched_ndim = _BATCHED_NDIMS.get(key)
|
||||
if batched_ndim is None or not hasattr(value, "ndim"):
|
||||
return value
|
||||
if value.ndim == batched_ndim and value.shape[0] == 1:
|
||||
return value[0]
|
||||
return value
|
||||
|
||||
|
||||
def postprocess(predictions, images):
    """Convert pose encoding to extrinsics (c2w) and move to CPU."""
    image_hw = images.shape[-2:]
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], image_hw)

    # Convert w2c to c2w: lift the 3x4 matrices to homogeneous 4x4, invert,
    # then take the top 3x4 rows back.
    homo = torch.zeros((*extrinsic.shape[:-2], 4, 4), device=extrinsic.device, dtype=extrinsic.dtype)
    homo[..., :3, :4] = extrinsic
    homo[..., 3, 3] = 1.0
    homo = closed_form_inverse_se3_general(homo)
    extrinsic = homo[..., :3, :4]

    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic
    # These entries are not needed downstream; drop them if present.
    for stale_key in ("pose_enc_list", "images"):
        predictions.pop(stale_key, None)

    print("Moving results to CPU...")
    for key, value in list(predictions.items()):
        if not isinstance(value, torch.Tensor):
            continue
        cpu_value = value.to("cpu", non_blocking=True)
        predictions[key] = _squeeze_single_batch(key, cpu_value)

    images_cpu = images.to("cpu", non_blocking=True)
    # non_blocking copies must complete before the CPU tensors are consumed.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    return predictions, images_cpu
|
||||
|
||||
|
||||
def prepare_for_visualization(predictions, images=None):
    """Convert predictions to the unbatched NumPy format used by vis code.

    Args:
        predictions: Mapping of prediction keys to tensors/arrays (possibly
            still batched with a leading singleton axis).
        images: Optional image tensor/array; falls back to
            ``predictions["images"]`` when omitted.

    Returns:
        Dict with every tensor detached, moved to CPU, squeezed of its
        singleton batch axis (per ``_BATCHED_NDIMS``), and converted to NumPy.
        Non-tensor values are passed through unchanged.
    """
    vis_predictions = {}
    for k, v in predictions.items():
        if isinstance(v, torch.Tensor):
            v = _squeeze_single_batch(k, v.detach().cpu())
            vis_predictions[k] = v.numpy()
        elif isinstance(v, np.ndarray):
            vis_predictions[k] = _squeeze_single_batch(k, v)
        else:
            vis_predictions[k] = v

    if images is None:
        images = predictions.get("images")

    if isinstance(images, torch.Tensor):
        images = images.detach().cpu()

    if isinstance(images, np.ndarray):
        images = _squeeze_single_batch("images", images)
    elif isinstance(images, torch.Tensor):
        # Squeeze then convert, leaving only NumPy (or None) past this point.
        # (A trailing `isinstance(images, torch.Tensor)` re-check existed here
        # before; it was unreachable because this branch already calls
        # .numpy(), so it has been removed.)
        images = _squeeze_single_batch("images", images).numpy()

    if images is not None:
        vis_predictions["images"] = images

    return vis_predictions
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main
|
||||
# =============================================================================
|
||||
|
||||
def main():
    """CLI entry point: load frames, run inference, and serve the 3D viewer.

    Parses command-line arguments, loads input images (from an image folder
    or a video), builds the GCTStream model, runs streaming or windowed
    inference, converts pose encodings to camera-to-world extrinsics, and
    launches the interactive point-cloud viewer when `viser` is available.
    """
    parser = argparse.ArgumentParser(description="LingBot-MAP: Streaming 3D Reconstruction Demo")

    # Input
    parser.add_argument("--image_folder", type=str, default=None)
    parser.add_argument("--video_path", type=str, default=None)
    parser.add_argument("--fps", type=int, default=10)
    parser.add_argument("--first_k", type=int, default=None)
    parser.add_argument("--stride", type=int, default=1)

    # Model
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--image_size", type=int, default=518)
    parser.add_argument("--patch_size", type=int, default=14)

    # Inference mode
    parser.add_argument("--mode", type=str, default="streaming", choices=["streaming", "windowed"],
                        help="streaming: frame-by-frame with KV cache; windowed: overlapping windows for long sequences")

    # Streaming options
    parser.add_argument("--enable_3d_rope", action="store_true", default=True)
    parser.add_argument("--max_frame_num", type=int, default=1024)
    parser.add_argument("--num_scale_frames", type=int, default=8)
    parser.add_argument(
        "--keyframe_interval",
        type=int,
        default=1,
        help="Streaming only. Every N-th frame after scale frames is kept as a keyframe. 1 = every frame.",
    )
    parser.add_argument("--kv_cache_sliding_window", type=int, default=64)
    parser.add_argument("--kv_cache_scale_frames", type=int, default=8)
    parser.add_argument("--use_sdpa", action="store_true", default=False,
                        help="Use SDPA backend (no flashinfer needed). Default: FlashInfer")

    # Windowed options
    parser.add_argument("--window_size", type=int, default=64, help="Frames per window (windowed mode)")
    parser.add_argument("--overlap_size", type=int, default=16, help="Overlap between windows")
    parser.add_argument("--sim3", action="store_true", default=True, help="Use Sim(3) alignment between windows")
    parser.add_argument("--no_sim3", dest="sim3", action="store_false", help="Disable Sim(3), use SE(3) instead")

    # Visualization
    parser.add_argument("--port", type=int, default=8080)
    parser.add_argument("--conf_threshold", type=float, default=1.0)
    parser.add_argument("--downsample_factor", type=int, default=10)
    parser.add_argument("--point_size", type=float, default=0.005)
    parser.add_argument("--mask_sky", action="store_true", help="Apply sky segmentation to filter out sky points")

    args = parser.parse_args()

    # Explicit validation instead of `assert` (asserts vanish under `python -O`).
    if not (args.image_folder or args.video_path):
        parser.error("Provide --image_folder or --video_path")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ── Load images & model ──────────────────────────────────────────────────
    t0 = time.time()
    images, paths = load_images(
        image_folder=args.image_folder, video_path=args.video_path,
        fps=args.fps, first_k=args.first_k, stride=args.stride,
        image_size=args.image_size, patch_size=args.patch_size,
    )
    model = load_model(args, device)
    print(f"Total load time: {time.time() - t0:.1f}s")

    images = images.to(device)
    num_frames = images.shape[0]
    print(f"Input: {num_frames} frames, shape {tuple(images.shape)}")
    print(f"Mode: {args.mode}")

    if args.mode != "streaming" and args.keyframe_interval != 1:
        print("Warning: --keyframe_interval only applies to --mode streaming. Ignoring it for windowed inference.")
        args.keyframe_interval = 1
    elif args.mode == "streaming" and args.keyframe_interval > 1:
        print(
            f"Keyframe streaming enabled: interval={args.keyframe_interval} "
            f"(after the first {args.num_scale_frames} scale frames)."
        )

    # ── Inference ────────────────────────────────────────────────────────────
    # Bug fix: torch.cuda.get_device_capability() raises on CPU-only machines,
    # while this script explicitly falls back to device="cpu" above — so only
    # query the compute capability when running on CUDA.
    if device.type == "cuda":
        # bfloat16 needs compute capability >= 8 (Ampere); older GPUs get fp16.
        dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    else:
        dtype = torch.bfloat16  # CPU autocast supports bfloat16
    print(f"Running {args.mode} inference (dtype={dtype})...")
    t0 = time.time()

    # Autocast on the device we actually run on (was hard-coded to "cuda").
    with torch.no_grad(), torch.amp.autocast(device.type, dtype=dtype):
        if args.mode == "streaming":
            predictions = model.inference_streaming(
                images,
                num_scale_frames=args.num_scale_frames,
                keyframe_interval=args.keyframe_interval,
            )
        else:  # windowed
            predictions = model.inference_windowed(
                images,
                window_size=args.window_size,
                overlap_size=args.overlap_size,
                num_scale_frames=args.num_scale_frames,
                sim3=args.sim3,
                se3=not args.sim3,
            )

    t_infer = time.time() - t0
    print(f"Inference done: {t_infer:.1f}s ({num_frames / t_infer:.1f} FPS)")

    # ── Post-process ─────────────────────────────────────────────────────────
    predictions, images_cpu = postprocess(predictions, images)

    # ── Visualize ────────────────────────────────────────────────────────────
    try:
        from lingbot_map.vis import PointCloudViewer
        viewer = PointCloudViewer(
            pred_dict=prepare_for_visualization(predictions, images_cpu),
            port=args.port,
            init_conf_threshold=args.conf_threshold,
            downsample_factor=args.downsample_factor,
            point_size=args.point_size,
            mask_sky=args.mask_sky,
        )
        print(f"3D viewer at http://localhost:{args.port}")
        viewer.run()
    except ImportError:
        print("viser not installed. Install with: pip install lingbot-map[vis]")
        print(f"Predictions contain keys: {list(predictions.keys())}")
|
||||
|
||||
|
||||
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||
BIN
docs/.DS_Store
vendored
Normal file
BIN
docs/.DS_Store
vendored
Normal file
Binary file not shown.
BIN
lingbot-map_paper.pdf
Normal file
BIN
lingbot-map_paper.pdf
Normal file
Binary file not shown.
BIN
lingbot_map/.DS_Store
vendored
Normal file
BIN
lingbot_map/.DS_Store
vendored
Normal file
Binary file not shown.
0
lingbot_map/__init__.py
Normal file
0
lingbot_map/__init__.py
Normal file
2
lingbot_map/aggregator/__init__.py
Normal file
2
lingbot_map/aggregator/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .stream import AggregatorStream
|
||||
from .base import AggregatorBase
|
||||
608
lingbot_map/aggregator/base.py
Normal file
608
lingbot_map/aggregator/base.py
Normal file
@@ -0,0 +1,608 @@
|
||||
"""
|
||||
AggregatorBase - Base class for all Aggregator implementations.
|
||||
|
||||
Provides shared functionality:
|
||||
- Patch embedding (DINOv2)
|
||||
- Special tokens (camera, register, scale)
|
||||
- Block building
|
||||
- Common forward pass structure
|
||||
|
||||
Subclasses implement mode-specific attention logic.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
from lingbot_map.layers import PatchEmbed
|
||||
from lingbot_map.layers.block import Block
|
||||
from lingbot_map.layers.rope import RotaryPositionEmbedding2D, PositionGetter
|
||||
from lingbot_map.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_RESNET_MEAN = [0.485, 0.456, 0.406]
|
||||
_RESNET_STD = [0.229, 0.224, 0.225]
|
||||
|
||||
|
||||
def slice_expand_and_flatten(token, B, S, first_num_frame=1):
    """
    Helper function to slice, expand and flatten tokens.

    The first `first_num_frame` frames receive the first learned token
    (`token[:, 0]`); every remaining frame receives the second (`token[:, 1]`).

    Args:
        token: Token tensor [1, 2, N, C] where first index is for first frames
        B: Batch size
        S: Sequence length
        first_num_frame: Number of frames to use first token for

    Returns:
        Flattened tokens [B*S, N, C]
    """
    # The original code special-cased first_num_frame > 1, but that branch was
    # byte-for-byte the general formula below evaluated at first_num_frame=1,
    # so a single code path suffices for every first_num_frame >= 1.
    token_first = token[:, :1].expand(B, first_num_frame, -1, -1)      # [B, first_num_frame, N, C]
    token_rest = token[:, 1:].expand(B, S - first_num_frame, -1, -1)   # [B, S-first_num_frame, N, C]
    token_expanded = torch.cat([token_first, token_rest], dim=1)       # [B, S, N, C]

    # Flatten to [B*S, N, C]
    return token_expanded.reshape(B * S, -1, token.shape[-1])
|
||||
|
||||
|
||||
class AggregatorBase(nn.Module, ABC):
    """
    Base class for all Aggregator implementations.

    Handles shared components:
    - Patch embedding (DINOv2 or conv)
    - Special tokens (camera, register, optionally scale)
    - Block creation (frame + global)
    - RoPE (2D rotary position embeddings)
    - Common forward pass scaffolding

    Subclasses must implement:
    - _build_blocks(): create frame_blocks / global_blocks
    - _setup_special_tokens(): create camera/register/scale tokens
    - _prepare_special_tokens(): expand special tokens per batch/frame
    - _process_global_attention(): mode-specific cross-frame attention logic
    """

    def __init__(
        self,
        # Architecture parameters
        img_size=518,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4.0,
        num_register_tokens=4,
        # Block configuration
        block_fn=Block,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
        qk_norm=True,
        init_values=0.01,
        # Patch embedding
        patch_embed="dinov2_vitl14_reg",
        pretrained_path=None,
        # Attention pattern
        # NOTE(review): mutable default list; it is only read here, never
        # mutated, so this is safe — kept for interface compatibility.
        aa_order=["frame", "global"],
        aa_block_size=1,
        # RoPE
        rope_freq=100,
        disable_global_rope=False,
        # Gradient checkpointing
        use_reentrant: bool = False,
        use_gradient_checkpoint: bool = True,
    ):
        super().__init__()

        # Store configuration
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.num_register_tokens = num_register_tokens
        self.aa_order = aa_order
        self.aa_block_size = aa_block_size
        self.disable_global_rope = disable_global_rope
        self.use_reentrant = use_reentrant
        self.use_gradient_checkpoint = use_gradient_checkpoint
        self.pretrained_path = pretrained_path
        self.enable_ulysses_cp = False  # CP disabled

        print("pretrained_path:", self.pretrained_path)

        # Validate depth: each attention group processes aa_block_size blocks.
        if self.depth % self.aa_block_size != 0:
            raise ValueError(f"depth ({depth}) must be divisible by aa_block_size ({aa_block_size})")
        self.aa_block_num = self.depth // self.aa_block_size

        # Build patch embedding (may stash a DINO checkpoint for block init).
        self._build_patch_embed(
            patch_embed=patch_embed,
            img_size=img_size,
            patch_size=patch_size,
            num_register_tokens=num_register_tokens,
            embed_dim=embed_dim,
            pretrained_path=pretrained_path
        )

        # Initialize RoPE
        self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
        self.position_getter = PositionGetter() if self.rope is not None else None

        # Build blocks (frame + global)
        self._build_blocks(
            block_fn=block_fn,
            depth=depth,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            ffn_bias=ffn_bias,
            init_values=init_values,
            qk_norm=qk_norm,
        )

        # Setup special tokens (camera, register, optionally scale)
        self._setup_special_tokens()

        # Register normalization constants (non-persistent: not in state_dict)
        for name, value in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
            self.register_buffer(name, torch.FloatTensor(value).view(1, 1, 3, 1, 1), persistent=False)

        # Initialize from DINO checkpoint if available
        if hasattr(self, '_dino_checkpoint') and self._dino_checkpoint is not None:
            self._init_blocks_from_dino(self._dino_checkpoint)
            del self._dino_checkpoint  # Free memory

    def _build_patch_embed(
        self,
        patch_embed: str,
        img_size: int,
        patch_size: int,
        num_register_tokens: int,
        embed_dim: int,
        pretrained_path: str,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        block_chunks=0,
        init_values=1.0,
    ):
        """
        Build patch embedding layer.

        Supports:
        - "conv": Simple convolutional patch embedding
        - "dinov2_*": DINOv2 ViT variants (vitl14, vitb14, vits14, vitg2)

        Side effect: sets self._dino_checkpoint (dict or None) so __init__
        can optionally initialize the attention blocks from DINO weights.
        """
        if "conv" in patch_embed:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim
            )
            self._dino_checkpoint = None

        else:
            vit_models = {
                "dinov2_vitl14_reg": vit_large,
                "dinov2_vitb14_reg": vit_base,
                "dinov2_vits14_reg": vit_small,
                "dinov2_vitg2_reg": vit_giant2,
            }

            if patch_embed not in vit_models:
                raise NotImplementedError(f"Unknown patch_embed type: {patch_embed}")

            self.patch_embed = vit_models[patch_embed](
                img_size=img_size,
                patch_size=patch_size,
                num_register_tokens=num_register_tokens,
                interpolate_antialias=interpolate_antialias,
                interpolate_offset=interpolate_offset,
                block_chunks=block_chunks,
                init_values=init_values,
            )

            # Load pretrained weights (best-effort: warn but continue on failure)
            try:
                # NOTE(review): torch.load without map_location loads onto the
                # checkpoint's saved device — confirm this is intended.
                ckpt = torch.load(pretrained_path)
                del ckpt['pos_embed']  # position embedding is re-interpolated, drop it
                logger.info("Loading pretrained weights for DINOv2")
                missing, unexpected = self.patch_embed.load_state_dict(ckpt, strict=False)
                logger.info(f"Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

                # Store checkpoint for block initialization
                self._dino_checkpoint = ckpt
            except Exception as e:
                logger.warning(f"Failed to load pretrained weights: {e}")
                self._dino_checkpoint = None

        # Disable gradients for mask token
        if hasattr(self.patch_embed, "mask_token"):
            self.patch_embed.mask_token.requires_grad_(False)

    @abstractmethod
    def _build_blocks(
        self,
        block_fn,
        depth: int,
        embed_dim: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        proj_bias: bool,
        ffn_bias: bool,
        init_values: float,
        qk_norm: bool,
    ):
        """
        Build frame_blocks and global_blocks.

        Subclasses implement mode-specific block creation.

        Must create:
        - self.frame_blocks: nn.ModuleList of frame attention blocks
        - self.global_blocks: nn.ModuleList of global attention blocks
        """
        pass

    @abstractmethod
    def _setup_special_tokens(self):
        """
        Setup camera token, register tokens, and optionally scale token.

        Subclasses implement mode-specific token initialization.

        Must create:
        - self.camera_token
        - self.register_token
        - self.scale_token (optional, for causal mode)
        - self.patch_start_idx
        - self.num_special_tokens
        """
        pass

    @staticmethod
    def _extract_block_state(dino_ckpt: dict, block_idx: int) -> dict:
        """Collect the weights of DINO block `block_idx` with the 'blocks.N.' prefix stripped."""
        prefix = f'blocks.{block_idx}.'
        return {
            key[len(prefix):]: value
            for key, value in dino_ckpt.items()
            if key.startswith(prefix)
        }

    def _load_blocks_from_dino(self, blocks, dino_ckpt: dict, num_dino_blocks: int, label: str):
        """
        Initialize every block in `blocks` from DINO weights, cycling through
        the available DINO blocks when `blocks` is deeper than the checkpoint.

        Args:
            blocks: nn.ModuleList of attention blocks to initialize
            dino_ckpt: Checkpoint dictionary from DINOv2 model
            num_dino_blocks: Number of distinct blocks in the checkpoint
            label: Human-readable name used in the log line ("Frame"/"Global")
        """
        for i, block in enumerate(blocks):
            block_state_dict = self._extract_block_state(dino_ckpt, i % num_dino_blocks)
            if block_state_dict:
                missing, unexpected = block.load_state_dict(block_state_dict, strict=False)
                if i == 0:  # Only log for first block to avoid spam
                    print(f"{label} block 0: Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

    def _init_blocks_from_dino(self, dino_ckpt: dict):
        """
        Initialize frame_blocks and global_blocks from DINOv2 pretrained weights.

        Args:
            dino_ckpt: Checkpoint dictionary from DINOv2 model
        """
        logger.info("Initializing blocks from DINOv2 pretrained weights")

        # Extract block keys
        dino_block_keys = [k for k in dino_ckpt.keys() if k.startswith('blocks.')]
        if not dino_block_keys:
            logger.warning("No 'blocks' found in DINO checkpoint")
            return

        # Get block indices
        block_indices = set()
        for key in dino_block_keys:
            parts = key.split('.')
            if len(parts) > 1 and parts[1].isdigit():
                block_indices.add(int(parts[1]))

        num_dino_blocks = len(block_indices)
        print(f"Found {num_dino_blocks} blocks in DINO checkpoint")

        # The frame and global loops were duplicated verbatim before; they now
        # share _load_blocks_from_dino (log strings preserved).
        self._load_blocks_from_dino(self.frame_blocks, dino_ckpt, num_dino_blocks, "Frame")
        self._load_blocks_from_dino(self.global_blocks, dino_ckpt, num_dino_blocks, "Global")

        logger.info("Successfully initialized blocks from DINOv2 weights")

    def _embed_images(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
    ) -> Tuple[torch.Tensor, int, int, int, int, int]:
        """
        Embed images and prepare for attention processing.

        Handles:
        - Image normalization
        - Patch embedding
        - Special token concatenation

        Args:
            images: Input images [B, S, 3, H, W] in range [0, 1]
            num_frame_for_scale: Number of frames for scale estimation (passed to special tokens)

        Returns:
            (tokens, B, S_local, S_global, P, C):
                tokens: Embedded tokens [B*S, P, C]
                B: Batch size
                S_local: Sequence length (== S_global, no CP slicing)
                S_global: Sequence length
                P: Number of tokens per frame (patches + special tokens)
                C: Embedding dimension
        """
        B, S, C_in, H, W = images.shape

        if C_in != 3:
            raise ValueError(f"Expected 3 input channels, got {C_in}")

        # Normalize images with ImageNet statistics
        images = (images - self._resnet_mean) / self._resnet_std

        # No CP slicing: S_local == S_global
        S_local = S
        S_global = S

        # Reshape for patch embedding [B*S, C, H, W]
        images = images.view(B * S, C_in, H, W)

        # Patch embedding (DINOv2 backbones return a dict)
        patch_tokens = self.patch_embed(images)
        if isinstance(patch_tokens, dict):
            patch_tokens = patch_tokens["x_norm_patchtokens"]

        _, P_patch, C = patch_tokens.shape

        # Prepare special tokens
        special_tokens = self._prepare_special_tokens(
            B, S_local, S_global, C,
            num_frame_for_scale=num_frame_for_scale
        )

        # Concatenate special tokens + patch tokens
        tokens = torch.cat([special_tokens, patch_tokens], dim=1)

        _, P, C = tokens.shape

        return tokens, B, S_local, S_global, P, C

    @abstractmethod
    def _prepare_special_tokens(self, B: int, S_local: int, S_global: int, C: int, **kwargs) -> torch.Tensor:
        """
        Prepare special tokens (camera, register, optionally scale).

        Subclasses implement mode-specific token preparation.

        Args:
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            C: Embedding dimension
            **kwargs: Mode-specific parameters (e.g. num_frame_for_scale for causal mode)

        Returns:
            Special tokens [B*S, N_special, C]
        """
        pass

    def _get_positions(self, B: int, S: int, H: int, W: int, device) -> Optional[torch.Tensor]:
        """
        Get 2D position embeddings for RoPE.

        Args:
            B: Batch size
            S: Sequence length
            H: Image height
            W: Image width
            device: Device to create positions on

        Returns:
            Position tensor [B*S, P, 2] or None if rope is disabled
        """
        if self.rope is None:
            return None

        # Get patch positions on the (H/ps, W/ps) grid
        pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)

        # Add offset for patch tokens; special tokens sit at position 0
        if self.patch_start_idx > 0:
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2, dtype=pos.dtype, device=device)
            pos = torch.cat([pos_special, pos], dim=1)

        return pos

    def _process_frame_attention(
        self,
        tokens: torch.Tensor,
        B: int,
        S: int,
        P: int,
        C: int,
        frame_idx: int,
        pos: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process frame attention blocks.

        Frame attention operates independently per frame (no cross-frame communication).
        Tokens stay in shape [B*S, P, C].

        Args:
            tokens: Input tokens [B*S, P, C]
            B: Batch size
            S: Sequence length
            P: Tokens per frame
            C: Embedding dimension
            frame_idx: Current frame block index
            pos: Position embeddings [B*S, P, 2]

        Returns:
            (tokens, frame_idx, intermediates):
                tokens: Output tokens [B*S, P, C]
                frame_idx: Updated frame block index
                intermediates: List of intermediate outputs [B, S, P, C]
        """
        # Ensure correct shape
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B * S, P, 2)

        intermediates = []

        # Hoisted out of the loop (was re-imported every iteration)
        from torch.utils.checkpoint import checkpoint

        # Process blocks
        for i in range(self.aa_block_size):
            if self.training and self.use_gradient_checkpoint:
                tokens = checkpoint(
                    self.frame_blocks[frame_idx],
                    tokens,
                    pos,
                    False,  # enable_ulysses_cp (always False)
                    use_reentrant=self.use_reentrant
                )
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos, enable_ulysses_cp=False)

            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    @abstractmethod
    def _process_global_attention(
        self,
        tokens: torch.Tensor,
        B: int,
        S_local: int,
        S_global: int,
        P: int,
        C: int,
        global_idx: int,
        pos: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
        """
        Process global (cross-frame) attention blocks.

        Subclasses implement mode-specific attention logic.

        Args:
            tokens: Input tokens
            B: Batch size
            S_local: Local sequence length
            S_global: Global sequence length
            P: Tokens per frame
            C: Embedding dimension
            global_idx: Current global block index
            pos: Position embeddings
            **kwargs: Mode-specific parameters

        Returns:
            (tokens, global_idx, intermediates):
                tokens: Output tokens
                global_idx: Updated global block index
                intermediates: List of intermediate outputs
        """
        pass

    def forward(
        self,
        images: torch.Tensor,
        selected_idx: Optional[List[int]] = None,
        # Mode-specific parameters
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
    ) -> Tuple[List[torch.Tensor], int]:
        """
        Forward pass.

        Args:
            images: Input images [B, S, 3, H, W] in range [0, 1]
            selected_idx: Which block indices to output (None = all)
            num_frame_for_scale: Number of frames for scale estimation (causal mode)
            sliding_window_size: Sliding window size in blocks (causal mode)
            num_frame_per_block: Number of frames per processing block (causal mode)

        Returns:
            (output_list, patch_start_idx):
                output_list: List of block outputs [B, S, P, 2C]
                patch_start_idx: Index where patch tokens start
        """
        B, S_input, _, H, W = images.shape

        # Embed images
        tokens, B, S_local, S_global, P, C = self._embed_images(
            images,
            num_frame_for_scale=num_frame_for_scale,
        )

        # Get position embeddings
        pos_local = self._get_positions(B, S_local, H, W, device=images.device)
        pos_global = self._get_positions(B, S_global, H, W, device=images.device)

        # Alternating attention
        frame_idx = 0
        global_idx = 0
        output_list = []

        # NOTE(review): output collection below assumes aa_order contains both
        # "frame" and "global" (the default); a missing entry would raise a
        # NameError on the intermediates.
        for block_group_idx in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S_local, P, C, frame_idx, pos=pos_local
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S_local, S_global, P, C, global_idx,
                        pos=pos_global,
                        num_frame_for_scale=num_frame_for_scale,
                        sliding_window_size=sliding_window_size,
                        num_frame_per_block=num_frame_per_block,
                        image_height=H,
                        image_width=W,
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            # Collect outputs
            if selected_idx is None or block_group_idx in selected_idx:
                for i in range(len(frame_intermediates)):
                    # Concatenate frame and global intermediates [B, S, P, 2C]
                    concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                    output_list.append(concat_inter)

        return output_list, self.patch_start_idx
|
||||
531
lingbot_map/aggregator/stream.py
Normal file
531
lingbot_map/aggregator/stream.py
Normal file
@@ -0,0 +1,531 @@
|
||||
"""
|
||||
AggregatorStream - Streaming causal aggregator with FlashInfer KV cache.
|
||||
|
||||
Provides:
|
||||
- Temporal causal attention
|
||||
- Sliding window support
|
||||
- Scale token for scale estimation frames
|
||||
- Streaming inference with FlashInfer paged KV cache
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
from lingbot_map.layers.block import Block, FlashInferBlock, SDPABlock
|
||||
from lingbot_map.layers.rope import WanRotaryPosEmbed
|
||||
from lingbot_map.aggregator.base import AggregatorBase, slice_expand_and_flatten
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AggregatorStream(AggregatorBase):
|
||||
"""
|
||||
Streaming causal aggregator with FlashInfer paged KV cache.
|
||||
|
||||
Features:
|
||||
- Temporal causal attention (each frame only attends to past frames)
|
||||
- Sliding window support to limit attention scope
|
||||
- Scale token for scale estimation frames
|
||||
- Streaming inference with FlashInfer KV cache
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# Causal-specific parameters
|
||||
sliding_window_size: int = -1,
|
||||
num_frame_for_scale: int = 1,
|
||||
num_random_frames: int = 0,
|
||||
attend_to_special_tokens: bool = False,
|
||||
attend_to_scale_frames: bool = False,
|
||||
enable_3d_rope: bool = False,
|
||||
max_frame_num: int = 1024,
|
||||
# KV cache parameters
|
||||
kv_cache_sliding_window: int = 64,
|
||||
kv_cache_scale_frames: int = 8,
|
||||
kv_cache_cross_frame_special: bool = True,
|
||||
kv_cache_include_scale_frames: bool = True,
|
||||
kv_cache_camera_only: bool = False,
|
||||
# Base class parameters via **kwargs
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initialize AggregatorStream.
|
||||
|
||||
Args:
|
||||
sliding_window_size: Sliding window size in blocks (-1 for full causal)
|
||||
num_frame_for_scale: Number of scale estimation frames
|
||||
num_random_frames: Number of random frames for long-range dependencies
|
||||
attend_to_special_tokens: Enable cross-frame special token attention
|
||||
attend_to_scale_frames: Include scale frames in attention
|
||||
enable_3d_rope: Enable 3D RoPE for temporal dimension in KV cache
|
||||
max_frame_num: Maximum number of frames for 3D RoPE
|
||||
kv_cache_sliding_window: Sliding window size for KV cache eviction
|
||||
kv_cache_scale_frames: Number of scale frames to keep in KV cache
|
||||
kv_cache_cross_frame_special: Keep special tokens from evicted frames
|
||||
kv_cache_include_scale_frames: Include scale frames in KV cache
|
||||
kv_cache_camera_only: Only keep camera tokens from evicted frames
|
||||
**kwargs: Base class parameters
|
||||
"""
|
||||
self.sliding_window_size = sliding_window_size
|
||||
self.num_frame_for_scale = num_frame_for_scale
|
||||
self.num_random_frames = num_random_frames
|
||||
self.attend_to_special_tokens = attend_to_special_tokens
|
||||
self.attend_to_scale_frames = attend_to_scale_frames
|
||||
self.enable_3d_rope = enable_3d_rope
|
||||
self.max_frame_num = max_frame_num
|
||||
# KV cache parameters
|
||||
self.kv_cache_sliding_window = kv_cache_sliding_window
|
||||
self.kv_cache_scale_frames = kv_cache_scale_frames
|
||||
self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
|
||||
self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
|
||||
self.kv_cache_camera_only = kv_cache_camera_only
|
||||
|
||||
# Pop kwargs that are passed but not needed by base class
|
||||
kwargs.pop('enable_stream_inference', None)
|
||||
use_flashinfer = kwargs.pop('use_flashinfer', True)
|
||||
kwargs.pop('use_flexflash', None)
|
||||
use_sdpa = kwargs.pop('use_sdpa', False)
|
||||
|
||||
# Backend selection: SDPA (no extra deps) or FlashInfer (paged KV cache)
|
||||
self.use_sdpa = use_sdpa
|
||||
self.use_flashinfer = not use_sdpa # FlashInfer is default unless SDPA requested
|
||||
|
||||
# Call parent __init__
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Initialize KV cache
|
||||
self._init_kv_cache()
|
||||
|
||||
# Initialize 3D RoPE if enabled
|
||||
if self.enable_3d_rope:
|
||||
self._init_3d_rope()
|
||||
|
||||
def _build_blocks(
    self,
    block_fn,
    depth: int,
    embed_dim: int,
    num_heads: int,
    mlp_ratio: float,
    qkv_bias: bool,
    proj_bias: bool,
    ffn_bias: bool,
    init_values: float,
    qk_norm: bool,
):
    """Construct the per-frame and global transformer stacks for streaming causal mode.

    Frame blocks are plain ``block_fn`` instances with RoPE attached; global
    blocks are backend-specific (SDPA or FlashInfer) and carry the KV-cache
    configuration so they can evict/keep frames during streaming inference.
    """
    shared_kwargs = {
        "dim": embed_dim,
        "num_heads": num_heads,
        "mlp_ratio": mlp_ratio,
        "qkv_bias": qkv_bias,
        "proj_bias": proj_bias,
        "ffn_bias": ffn_bias,
        "init_values": init_values,
        "qk_norm": qk_norm,
    }

    # Per-frame (local) attention: standard blocks with RoPE.
    self.frame_blocks = nn.ModuleList(
        [block_fn(rope=self.rope, **shared_kwargs) for _ in range(depth)]
    )

    # Global attention: FlashInferBlock by default, SDPABlock as the
    # dependency-free fallback.  RoPE may be disabled for the global path.
    global_cls = SDPABlock if self.use_sdpa else FlashInferBlock
    global_rope = None if self.disable_global_rope else self.rope
    self.global_blocks = nn.ModuleList(
        [
            global_cls(
                rope=global_rope,
                kv_cache_sliding_window=self.kv_cache_sliding_window,
                kv_cache_scale_frames=self.kv_cache_scale_frames,
                kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
                kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
                kv_cache_camera_only=self.kv_cache_camera_only,
                **shared_kwargs,
            )
            for _ in range(depth)
        ]
    )
|
||||
|
||||
def _setup_special_tokens(self):
|
||||
"""Setup camera, register, and scale tokens for causal mode."""
|
||||
# Camera token
|
||||
self.camera_token = nn.Parameter(
|
||||
torch.randn(1, 2, 1, self.embed_dim)
|
||||
)
|
||||
|
||||
# Register tokens
|
||||
if self.num_register_tokens > 0:
|
||||
self.register_token = nn.Parameter(
|
||||
torch.randn(1, 2, self.num_register_tokens, self.embed_dim)
|
||||
)
|
||||
|
||||
# Scale token (causal mode specific)
|
||||
self.scale_token = nn.Parameter(
|
||||
torch.ones(1, 2, 1, self.embed_dim)
|
||||
)
|
||||
|
||||
# Initialize
|
||||
nn.init.normal_(self.camera_token, std=1e-6)
|
||||
if self.num_register_tokens > 0:
|
||||
nn.init.normal_(self.register_token, std=1e-6)
|
||||
nn.init.normal_(self.scale_token, std=1e-6)
|
||||
|
||||
# Token indexing (includes scale token)
|
||||
self.patch_start_idx = 1 + self.num_register_tokens + 1 # camera + register + scale
|
||||
self.num_special_tokens = 1 + self.num_register_tokens + 1
|
||||
|
||||
def _init_kv_cache(self):
    """Reset the streaming KV-cache state for the selected attention backend."""
    self.kv_cache_manager = None   # FlashInfer manager, created lazily on first forward
    self.kv_cache = {}             # dict-of-tensors cache used by the SDPA backend
    self.total_frames_processed = 0
    self._cached_pos3d = None

    if not self.use_sdpa:
        logger.info("FlashInfer KV cache will be lazily initialized on first forward")
        return

    # SDPA path: pre-register empty per-block slots for regular and
    # special-token K/V tensors.  `depth` may not exist yet when this is
    # called early in construction, hence the hasattr guard.
    if hasattr(self, 'depth'):
        for layer in range(self.depth):
            self.kv_cache[f"k_{layer}"] = None
            self.kv_cache[f"v_{layer}"] = None
            self.kv_cache[f"k_{layer}_special"] = None
            self.kv_cache[f"v_{layer}_special"] = None
        logger.info(f"SDPA KV cache initialized with {self.depth} blocks")
|
||||
|
||||
def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None):
    """Return the FlashInfer KV-cache manager, creating it on first call.

    Args:
        device: Device on which cache tensors are allocated.
        dtype: Data type for cache tensors.
        tokens_per_frame: Actual tokens per frame (patches + specials).  When
            None, a square image of ``self.img_size`` is assumed.
    """
    if self.kv_cache_manager is not None:
        return self.kv_cache_manager

    from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager

    head_dim = 64  # fixed head size (ViT-L convention in this codebase)
    if tokens_per_frame is None:
        patches_per_side = self.img_size // self.patch_size
        tokens_per_frame = patches_per_side ** 2 + self.num_special_tokens
    # Cache capacity: scale frames + sliding window + slack for in-flight frames.
    frame_capacity = self.kv_cache_scale_frames + self.kv_cache_sliding_window + 16

    self.kv_cache_manager = FlashInferKVCacheManager(
        num_blocks=self.depth,
        max_num_frames=frame_capacity,
        tokens_per_frame=tokens_per_frame,
        num_heads=self.embed_dim // head_dim,
        head_dim=head_dim,
        dtype=dtype,
        device=device,
        num_special_tokens=self.num_special_tokens,
        scale_frames=self.kv_cache_scale_frames,
        sliding_window=self.kv_cache_sliding_window,
        max_total_frames=self.max_frame_num + 100,
        force_fp32=getattr(self, 'kv_cache_force_fp32', False),
        fa3=getattr(self, 'kv_cache_fa3', False),
    )
    logger.info(
        f"FlashInfer KV cache manager initialized: {self.depth} blocks, "
        f"max_frames={frame_capacity}, tokens_per_frame={tokens_per_frame}"
    )
    return self.kv_cache_manager
|
||||
|
||||
def clean_kv_cache(self):
    """Reset all KV-cache state; call when starting a new sequence."""
    if self.kv_cache_manager is not None:
        self.kv_cache_manager.reset()
    if self.kv_cache:
        # `_skip_append` is a control flag, not a tensor slot — reset it to
        # False rather than None.
        for key in list(self.kv_cache):
            self.kv_cache[key] = False if key == "_skip_append" else None
    self.total_frames_processed = 0
    self._cached_pos3d = None
    logger.info("KV cache cleaned")
|
||||
|
||||
def _init_3d_rope(self):
    """Create the 3D (temporal + spatial) RoPE module used by streaming inference."""
    if not self.enable_3d_rope:
        self.rope3d = None
        return

    # NOTE(review): head count is hard-coded to 16 here — presumably the
    # attention config for this backbone; confirm it matches num_heads.
    head_dim = self.embed_dim // 16

    self.rope3d = WanRotaryPosEmbed(
        attention_head_dim=head_dim,
        patch_size=(1, self.patch_size, self.patch_size),
        max_seq_len=self.max_frame_num,
    )
    logger.info(f"3D RoPE initialized for max {self.max_frame_num} frames, head_dim={head_dim}")
|
||||
|
||||
def _get_3d_positions_streaming(self, num_frames, H, W, device, f_start, f_end):
|
||||
"""
|
||||
Generate 3D RoPE positions for streaming mode with correct global frame indices.
|
||||
|
||||
Args:
|
||||
num_frames: Number of frames in current batch
|
||||
H, W: Image height and width
|
||||
device: Device to create positions on
|
||||
f_start: Global start frame index
|
||||
f_end: Global end frame index
|
||||
|
||||
Returns:
|
||||
pos3d: [1, 1, num_frames * P, head_dim//2] complex tensor
|
||||
"""
|
||||
if self.rope3d is None:
|
||||
return None
|
||||
|
||||
pph = H // self.patch_size
|
||||
ppw = W // self.patch_size
|
||||
|
||||
pos3d = self.rope3d(
|
||||
ppf=num_frames,
|
||||
pph=pph,
|
||||
ppw=ppw,
|
||||
patch_start_idx=self.num_special_tokens,
|
||||
device=device,
|
||||
f_start=f_start,
|
||||
f_end=f_end
|
||||
)
|
||||
return pos3d
|
||||
|
||||
def _prepare_special_tokens(
    self,
    B: int,
    S_local: int,
    S_global: int,
    C: int,
    num_frame_for_scale: Optional[int] = None,
) -> torch.Tensor:
    """
    Prepare camera, register, and scale tokens for the current chunk.

    In streaming mode (a non-empty KV cache), tokens are expanded as if all
    frames seen so far were present, then sliced to the current chunk so the
    per-frame token values match what batch-mode processing would produce.

    Args:
        B: Batch size
        S_local: Local sequence length (unused here; kept for interface parity)
        S_global: Global sequence length (frames in the current chunk)
        C: Embedding dimension
        num_frame_for_scale: Number of frames for scale estimation; falls back
            to ``self.num_frame_for_scale`` when None.

    Returns:
        Special tokens [B*S_global, N_special, C]
    """
    # Get effective num_frame_for_scale
    scale_frames = self.num_frame_for_scale if num_frame_for_scale is None else num_frame_for_scale

    # Check cache state for both backends (FlashInfer paged cache vs SDPA dict cache).
    has_flashinfer_cache = self.kv_cache_manager is not None and self.kv_cache_manager.num_frames > 0
    has_sdpa_cache = self.kv_cache is not None and self.kv_cache.get("k_0") is not None

    # NOTE(review): hard-coded True — the branches below effectively only
    # distinguish "cache present" vs "no cache"; confirm whether a real
    # mode flag was intended here.
    causal_inference = True

    # S_true = total frames seen so far (cached + current chunk).
    if causal_inference and has_flashinfer_cache:
        S_cached = self.kv_cache_manager.num_frames
        S_true = S_cached + S_global
    elif causal_inference and has_sdpa_cache:
        # SDPA cache tensor layout: frame count is assumed to be dim 2 —
        # TODO confirm against the SDPA block's cache format.
        _, _, S_cached, _, _ = self.kv_cache["k_0"].shape
        S_true = S_cached + S_global
    else:
        S_true = S_global

    # Expand tokens based on mode
    if causal_inference and S_true > S_global:
        # Streaming mode: expand with S_true, then slice off the trailing
        # S_global entries corresponding to the current chunk.
        effective_scale_frames = min(scale_frames, S_true)

        camera_token_full = slice_expand_and_flatten(self.camera_token, B, S_true)
        camera_token = camera_token_full[-S_global:, :, :]

        # NOTE(review): self.register_token only exists when
        # num_register_tokens > 0 (see _setup_special_tokens); this path
        # raises AttributeError otherwise — confirm registers are required.
        register_token_full = slice_expand_and_flatten(self.register_token, B, S_true)
        register_token = register_token_full[-S_global:, :, :]
        scale_token_full = slice_expand_and_flatten(
            self.scale_token, B, S_true, first_num_frame=effective_scale_frames
        )
        scale_token = scale_token_full[-S_global:, :, :]
    else:
        # Batch mode or first inference: expand directly
        effective_scale_frames = min(scale_frames, S_global)

        camera_token = slice_expand_and_flatten(self.camera_token, B, S_global)
        register_token = slice_expand_and_flatten(self.register_token, B, S_global)
        scale_token = slice_expand_and_flatten(
            self.scale_token, B, S_global, first_num_frame=effective_scale_frames
        )

    # Concatenate along the token axis: [camera | registers | scale].
    special_tokens = torch.cat([camera_token, register_token, scale_token], dim=1)

    # Verify shape
    expected_shape = (B * S_global, self.num_special_tokens, C)
    assert special_tokens.shape == expected_shape, \
        f"Expected {expected_shape}, got {special_tokens.shape}"

    return special_tokens
|
||||
|
||||
def _process_global_attention(
|
||||
self,
|
||||
tokens: torch.Tensor,
|
||||
B: int,
|
||||
S_local: int,
|
||||
S_global: int,
|
||||
P: int,
|
||||
C: int,
|
||||
global_idx: int,
|
||||
pos: Optional[torch.Tensor] = None,
|
||||
# Mode-specific parameters
|
||||
num_frame_for_scale: Optional[int] = None,
|
||||
sliding_window_size: Optional[int] = None,
|
||||
num_frame_per_block: int = 1,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, int, List[torch.Tensor]]:
|
||||
"""
|
||||
Process causal global attention via FlashInfer streaming path.
|
||||
|
||||
Args:
|
||||
tokens: Input tokens
|
||||
B: Batch size
|
||||
S_local: Local sequence length
|
||||
S_global: Global sequence length
|
||||
P: Tokens per frame
|
||||
C: Embedding dimension
|
||||
global_idx: Current global block index
|
||||
pos: Position embeddings
|
||||
num_frame_for_scale: Number of frames for scale estimation
|
||||
sliding_window_size: Sliding window size in blocks
|
||||
num_frame_per_block: Number of frames per processing block
|
||||
|
||||
Returns:
|
||||
(tokens, global_idx, intermediates)
|
||||
"""
|
||||
# Extract image dimensions from kwargs for 3D RoPE
|
||||
image_height = kwargs.get('image_height', self.img_size)
|
||||
image_width = kwargs.get('image_width', self.img_size)
|
||||
|
||||
return self._process_causal_stream(
|
||||
tokens, B, S_local, S_global, P, C, global_idx, pos,
|
||||
num_frame_per_block, sliding_window_size, num_frame_for_scale,
|
||||
image_height=image_height, image_width=image_width
|
||||
)
|
||||
|
||||
def _process_causal_stream(
    self,
    tokens: torch.Tensor,
    B: int,
    S_local: int,
    S_global: int,
    P: int,
    C: int,
    global_idx: int,
    pos: Optional[torch.Tensor] = None,
    num_frame_per_block: int = 1,
    sliding_window_size: Optional[int] = None,
    num_frame_for_scale: Optional[int] = None,
    image_height: Optional[int] = None,
    image_width: Optional[int] = None,
):
    """
    Causal attention for streaming inference using the KV cache.

    Runs ``self.aa_block_size`` consecutive global blocks starting at
    ``global_idx``, feeding each block either the SDPA dict cache or the
    lazily created FlashInfer manager.

    Args:
        tokens: Input tokens [B*S_local, P, C]
        B: Batch size
        S_local: Local sequence length
        S_global: Global sequence length
        P: Number of patches per frame (includes special tokens)
        C: Channel dimension
        global_idx: Starting block index
        pos: Position embeddings [B*S_global, P, 2]
        num_frame_per_block: Number of frames per block
        sliding_window_size: Sliding window size in blocks (passed through; unused here)
        num_frame_for_scale: Number of scale frames
        image_height: Image height for 3D RoPE calculation
        image_width: Image width for 3D RoPE calculation

    Returns:
        (tokens, global_idx, intermediates): Updated tokens, next block index,
        per-block intermediate outputs each shaped [B, S_local, P, C].
    """
    # Fall back to the configured scale-frame count when not overridden.
    scale_frames = num_frame_for_scale if num_frame_for_scale is not None else self.num_frame_for_scale

    # Reshape tokens: [B*S_local, P, C] -> [B, S_local*P, C]
    if tokens.shape != (B, S_local * P, C):
        tokens = tokens.view(B, S_local, P, C).view(B, S_local * P, C)

    # Calculate number of frames for block mask
    num_frames = S_global
    num_patches = P - self.num_special_tokens

    # Only the first group of global blocks recomputes RoPE positions and
    # advances the frame counter; later groups in the same forward reuse state.
    is_first_block_group = (global_idx < self.aa_block_size)

    if self.enable_3d_rope and self.rope3d is not None:
        if is_first_block_group:
            # Global frame range of this chunk, for temporally correct RoPE.
            f_start = self.total_frames_processed
            f_end = self.total_frames_processed + S_global

            H = image_height if image_height is not None else self.img_size
            W = image_width if image_width is not None else self.img_size
            pos3d = self._get_3d_positions_streaming(
                S_global, H, W, tokens.device, f_start, f_end
            )
            # Cache so subsequent block groups reuse the same positions.
            self._cached_pos3d = pos3d
        else:
            pos3d = self._cached_pos3d
        pos = pos3d
    else:
        # Reshape pos: [B*S_global, P, 2] -> [B, S_global*P, 2]
        if pos is not None and pos.shape != (B, S_global * P, 2):
            pos = pos.view(B, S_global, P, 2).view(B, S_global * P, 2)

    intermediates = []

    # Process blocks with KV cache
    for _ in range(self.aa_block_size):
        num_patches = P - self.num_special_tokens
        if self.use_sdpa:
            # SDPA: dict-based KV cache
            tokens = self.global_blocks[global_idx](
                tokens,
                pos=pos,
                enable_ulysses_cp=False,
                num_patches=num_patches,
                num_special=self.num_special_tokens,
                num_frames=num_frames,
                enable_3d_rope=self.enable_3d_rope,
                kv_cache=self.kv_cache,
                global_idx=global_idx,
                num_frame_per_block=num_frame_per_block,
                num_frame_for_scale=scale_frames,
                num_register_tokens=self.num_register_tokens,
            )
        else:
            # FlashInfer: paged KV cache manager (lazily created on first use)
            manager = self._get_flashinfer_manager(tokens.device, tokens.dtype, tokens_per_frame=P)
            tokens = self.global_blocks[global_idx](
                tokens,
                pos=pos,
                enable_ulysses_cp=False,
                num_patches=num_patches,
                num_special=self.num_special_tokens,
                num_frames=num_frames,
                enable_3d_rope=self.enable_3d_rope,
                kv_cache=manager,
                global_idx=global_idx,
                num_frame_per_block=num_frame_per_block,
                num_frame_for_scale=scale_frames,
                num_register_tokens=self.num_register_tokens,
            )

        global_idx += 1
        intermediates.append(tokens.view(B, S_local, P, C))

    # Update total frames processed counter only on the first block group,
    # and only when the SDPA dict cache isn't flagged to skip appends.
    if is_first_block_group and not (isinstance(self.kv_cache, dict) and self.kv_cache.get("_skip_append", False)):
        self.total_frames_processed += S_global

    return tokens, global_idx, intermediates
|
||||
0
lingbot_map/heads/__init__.py
Normal file
0
lingbot_map/heads/__init__.py
Normal file
454
lingbot_map/heads/camera_head.py
Normal file
454
lingbot_map/heads/camera_head.py
Normal file
@@ -0,0 +1,454 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from lingbot_map.layers import Mlp
|
||||
from lingbot_map.layers.block import Block
|
||||
from lingbot_map.layers.block import CameraBlock
|
||||
from lingbot_map.heads.head_act import activate_pose
|
||||
from lingbot_map.layers.rope import WanRotaryPosEmbed
|
||||
from functools import partial
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
|
||||
class CameraHead(nn.Module):
    """Predicts camera parameters from aggregated tokens by iterative refinement.

    A stack of transformer blocks (the "trunk") is applied to the dedicated
    camera token of each frame, and a pose encoding is refined over several
    iterations using adaLN-style conditioning on the previous estimate.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field-of-view activation: keeps FOV values positive.
        enable_ulysses_cp=False,
    ):
        super().__init__()

        if pose_encoding_type != "absT_quaR_FoV":
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
        # absT (3) + quaternion R (4) + FoV (2).
        self.target_dim = 9

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.enable_ulysses_cp = enable_ulysses_cp

        # Refinement trunk: plain transformer blocks over the camera tokens.
        trunk_blocks = [
            Block(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values)
            for _ in range(trunk_depth)
        ]
        self.trunk = nn.Sequential(*trunk_blocks)

        # Normalizations for the incoming camera token and the trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learned "empty" pose used to bootstrap the first iteration.
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Produces adaLN modulation parameters: shift, scale, and gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer norm without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4, **kwargs) -> list:
        """
        Predict camera encodings from the final aggregated token tensor.

        Args:
            aggregated_tokens_list (list): Token tensors from the network; only
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of refinement steps. Defaults to 4.

        Returns:
            list: Activated camera encodings, one per iteration.
        """
        # The camera token sits at index 0 of every frame's token sequence.
        camera_tokens = aggregated_tokens_list[-1][:, :, 0]
        camera_tokens = self.token_norm(camera_tokens)
        return self.trunk_fn(camera_tokens, num_iterations)

    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
        """
        Run ``num_iterations`` refinement steps over normalized camera tokens.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens [B, S, C]
                (S is expected to be 1).
            num_iterations (int): Number of refinement iterations.

        Returns:
            list: Activated camera encodings from each iteration.
        """
        B, S, _ = pose_tokens.shape
        current_enc = None
        outputs = []

        for _ in range(num_iterations):
            if current_enc is None:
                # First iteration: condition on the learned empty pose.
                conditioning = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach so gradients do not flow through earlier iterations.
                current_enc = current_enc.detach()
                conditioning = self.embed_pose(current_enc)

            # adaLN modulation parameters from the conditioning signal.
            shift, scale, gate = self.poseLN_modulation(conditioning).chunk(3, dim=-1)

            # Gated modulation with a residual connection to the input tokens.
            modulated = pose_tokens + gate * modulate(self.adaln_norm(pose_tokens), shift, scale)

            for blk in self.trunk:
                modulated = blk(modulated, enable_ulysses_cp=self.enable_ulysses_cp)

            # Delta update on the pose encoding.
            delta = self.pose_branch(self.trunk_norm(modulated))
            current_enc = delta if current_enc is None else current_enc + delta

            # Activate translation / quaternion / FoV components.
            outputs.append(
                activate_pose(current_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act)
            )

        return outputs
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Apply adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    Adapted from the DiT reference implementation (facebookresearch/DiT).
    """
    gain = scale + 1
    return shift + x * gain
|
||||
|
||||
|
||||
class CameraCausalHead(nn.Module):
    """
    Causal/streaming variant of CameraHead: predicts camera parameters from
    token representations using iterative refinement.

    It applies a series of CameraBlock layers (the "trunk") to dedicated
    camera tokens, maintaining one KV cache per refinement iteration and a
    running ``frame_idx`` so 3D RoPE positions stay globally consistent
    across streamed chunks.
    """

    def __init__(
        self,
        dim_in: int = 2048,
        trunk_depth: int = 4,
        pose_encoding_type: str = "absT_quaR_FoV",
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        trans_act: str = "linear",
        quat_act: str = "linear",
        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
        num_iterations = 4,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        enable_ulysses_cp: bool = False,
        attn_class: str = "flexflashattn_varlen",
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # 3D RoPE parameters
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        rope_theta: float = 10000.0,
    ):
        super().__init__()

        if pose_encoding_type == "absT_quaR_FoV":
            # absT (3) + quaternion R (4) + FoV (2).
            self.target_dim = 9
        else:
            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")

        self.trans_act = trans_act
        self.quat_act = quat_act
        self.fl_act = fl_act
        self.trunk_depth = trunk_depth
        self.sliding_window_size = sliding_window_size
        self.enable_ulysses_cp = enable_ulysses_cp
        self.num_heads = num_heads

        # 3D RoPE for temporal position encoding
        self.enable_3d_rope = enable_3d_rope
        if enable_3d_rope:
            head_dim = dim_in // num_heads
            # For camera head: each frame has 1 token (frame_seqlen=1),
            # so patch_size is (max_frames, h=1, w=1) for 3D RoPE.
            self.rope3d = WanRotaryPosEmbed(
                attention_head_dim=head_dim,
                patch_size=(max_frame_num, 1, 1),
                theta=rope_theta,
                # Explicit (t, h, w) dimension split; 40 + 44 + 44 = 128,
                # the head_dim for the default dim_in=2048 / num_heads=16.
                # NOTE(review): not auto-calculated despite earlier drafts —
                # must be kept in sync with head_dim if dims change.
                fhw_dim=[40, 44, 44],
            )
        else:
            self.rope3d = None

        # Build the trunk using a sequence of transformer blocks.
        self.trunk = nn.Sequential(
            *[
                CameraBlock(dim=dim_in, num_heads=num_heads, mlp_ratio=mlp_ratio, init_values=init_values, elementwise_attn_output_gate=elementwise_attn_output_gate, sliding_window_size=sliding_window_size, attend_to_scale_frames=attend_to_scale_frames, num_random_frames=num_random_frames, kv_cache_sliding_window=kv_cache_sliding_window, kv_cache_scale_frames=kv_cache_scale_frames, kv_cache_cross_frame_special=kv_cache_cross_frame_special, kv_cache_include_scale_frames=kv_cache_include_scale_frames, kv_cache_camera_only=kv_cache_camera_only)
                for _ in range(trunk_depth)
            ]
        )

        # Normalizations for camera token and trunk output.
        self.token_norm = nn.LayerNorm(dim_in)
        self.trunk_norm = nn.LayerNorm(dim_in)

        # Learnable empty camera pose token (bootstraps the first iteration).
        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
        self.embed_pose = nn.Linear(self.target_dim, dim_in)

        # Module for producing modulation parameters: shift, scale, and a gate.
        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))

        # Adaptive layer normalization without affine parameters.
        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
        self.pose_branch = Mlp(in_features=dim_in, hidden_features=dim_in // 2, out_features=self.target_dim, drop=0)

        self.num_iterations = num_iterations

        # Streaming state: per-iteration KV caches and the global frame cursor.
        self.kv_cache = None
        self.pos_cache = None
        self.frame_idx = 0
        self.cp_size = 1

        ## Get cp size if enable ulysses cp
        if self.enable_ulysses_cp:
            from torchtitan.distributed.sequence_parallel import (
                init_sequence_parallel,
                get_ulysses_sequence_parallel_rank,
                get_ulysses_sequence_parallel_world_size,
            )

            self.cp_size = get_ulysses_sequence_parallel_world_size()

    def clean_kv_cache(self):
        # Drop all cached K/V tensors and rewind the streaming frame cursor.
        del self.kv_cache
        self.kv_cache = None
        self.frame_idx = 0

    def forward(self, aggregated_tokens_list: list, mask=None, num_iterations: int = 4, causal_inference=False, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None, **kwargs) -> list:
        """
        Forward pass to predict camera parameters.

        Args:
            aggregated_tokens_list (list): List of token tensors from the network;
                the last tensor is used for prediction.
            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
            causal_inference (bool): When True, allocate (on first call) and use
                per-iteration KV caches for streaming inference.
            sliding_window_size (int, optional): Override the sliding window size for this forward pass.
                If None, use the default self.sliding_window_size.

        Returns:
            list: A list of predicted camera encodings (post-activation) from each iteration.
        """
        # Use passed sliding_window_size if provided, otherwise use default
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size

        # Use tokens from the last block for camera prediction.
        tokens = aggregated_tokens_list[-1]

        # Extract the camera tokens (index 0 of each frame's token sequence).
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)

        if causal_inference:
            if self.kv_cache is None:
                # One cache dict per refinement iteration, each with K/V slots
                # for every trunk layer.
                # NOTE(review): caches are sized by self.num_iterations while the
                # loop below runs `num_iterations` (the argument) — calling with
                # num_iterations > self.num_iterations would index out of range.
                self.kv_cache = []
                for i in range(self.num_iterations):
                    self.kv_cache.append({"_skip_append": False})
                    for j in range(self.trunk_depth):
                        self.kv_cache[i][f"k_{j}"] = None
                        self.kv_cache[i][f"v_{j}"] = None

        pred_pose_enc_list = self.trunk_fn(pose_tokens, mask, num_iterations, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size)
        return pred_pose_enc_list

    def trunk_fn(self, pose_tokens: torch.Tensor, mask=None, num_iterations: int=4, num_frame_per_block=1, num_frame_for_scale=-1, sliding_window_size=None) -> list:
        """
        Iteratively refine camera pose predictions.

        Args:
            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, S, C].
            num_iterations (int): Number of refinement iterations.
            sliding_window_size (int, optional): Sliding window size to use.

        Returns:
            list: List of activated camera encodings from each iteration.
        """
        B, S, C = pose_tokens.shape
        pred_pose_enc = None
        pred_pose_enc_list = []

        # Check if this is the first call (processing scale frames).
        # Scale frames should use batch mode attention for numerical consistency.
        is_scale_frames = (self.kv_cache is not None and self.frame_idx == 0)

        # Generate 3D RoPE positions if enabled
        pos3d = None
        if self.rope3d is not None:
            # For camera tokens: shape [B, S, C] where each frame has 1 token.
            # Position for frame f is (f, 0, 0) - temporal varies, spatial fixed.

            # In streaming mode with KV cache, use frame_idx to track global position.
            # Otherwise, generate positions from 0.
            if self.kv_cache is not None:
                f_start = self.frame_idx
                f_end = self.frame_idx + S
            else:
                f_start = 0
                f_end = None  # Will use ppf as frame count

            pos3d = self.rope3d(
                ppf=S * self.cp_size,  # Total frames (with context parallelism)
                pph=1,  # height = 1 (camera token)
                ppw=1,  # width = 1 (camera token)
                patch_start_idx=0,  # No special tokens before
                device=pose_tokens.device,
                f_start=f_start,
                f_end=f_end,
            )  # Returns [1, 1, S*cp_size, head_dim//2] complex

        for i in range(num_iterations):
            # Use a learned empty pose for the first iteration.
            if pred_pose_enc is None:
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                # Detach the previous prediction to avoid backprop through time.
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            # Generate modulation parameters and split them into shift, scale, and gate components.
            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)

            # Adaptive layer normalization and modulation (with residual).
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            # Each iteration uses its own KV cache dict (self.kv_cache[i]).
            for idx in range(self.trunk_depth):
                pose_tokens_modulated = self.trunk[idx](pose_tokens_modulated, pos=pos3d, video_mask=mask, num_frames=S*self.cp_size, frame_seqlen=1, kv_cache=self.kv_cache[i] if self.kv_cache is not None else None, global_idx=idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=sliding_window_size, enable_ulysses_cp=self.enable_ulysses_cp, enable_3d_rope=self.enable_3d_rope, is_scale_frames=is_scale_frames)
            # Compute the delta update for the pose encoding.
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            # Apply final activation functions for translation, quaternion, and field-of-view.
            activated_pose = activate_pose(
                pred_pose_enc, trans_act=self.trans_act, quat_act=self.quat_act, fl_act=self.fl_act
            )
            pred_pose_enc_list.append(activated_pose)

        # Update frame_idx for streaming mode (KV cache) so the next chunk's
        # 3D RoPE positions continue from the correct global frame.
        if self.kv_cache is not None:
            self.frame_idx += S

        return pred_pose_enc_list
|
||||
|
||||
|
||||
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Apply an affine (scale-and-shift) modulation to ``x``.

    Computes ``x * (1 + scale) + shift`` elementwise, with normal
    broadcasting between the operands.
    Adapted from the DiT reference implementation:
    https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
    """
    scaled = x * (scale + 1)
    return scaled + shift
|
||||
|
||||
|
||||
|
||||
|
||||
class CameraDecoder(nn.Module):
    """Transformer decoder over per-view token sequences.

    Optionally projects inputs to the decoder width, runs ``depth``
    attention blocks over each view's tokens independently, and maps the
    result to ``out_dim`` with a final linear layer.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        dec_embed_dim=512,
        depth=5,
        dec_num_heads=8,
        mlp_ratio=4,
        rope=None,
        need_project=True,
        use_checkpoint=False,
    ):
        super().__init__()

        # Input projection; identity when the caller already provides dec_embed_dim channels.
        self.projects = nn.Linear(in_dim, dec_embed_dim) if need_project else nn.Identity()
        self.use_checkpoint = use_checkpoint

        def build_block():
            # One pre-norm transformer block of the decoder stack.
            return Block(
                dim=dec_embed_dim,
                num_heads=dec_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_bias=True,
                ffn_bias=True,
                drop_path=0.0,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                ffn_layer=Mlp,
                init_values=None,
                qk_norm=False,
                # attn_class=MemEffAttentionRope,
                rope=rope,
            )

        self.blocks = nn.ModuleList([build_block() for _ in range(depth)])

        self.linear_out = nn.Linear(dec_embed_dim, out_dim)

    def forward(self, hidden, xpos=None):
        """Decode tokens of shape (B, V, P, C) into (B, V, P, out_dim)."""
        hidden = self.projects(hidden)
        B, V, P, C = hidden.shape
        # Fold views into the batch dimension so each view is decoded independently.
        hidden = hidden.view(B * V, P, C)
        for blk in self.blocks:
            if self.use_checkpoint and self.training:
                # Gradient checkpointing trades recompute for activation memory.
                hidden = checkpoint(blk, hidden, pos=xpos, use_reentrant=False)
            else:
                hidden = blk(hidden, pos=xpos)
        return self.linear_out(hidden).view(B, V, P, -1)
|
||||
679
lingbot_map/heads/dpt_head.py
Normal file
679
lingbot_map/heads/dpt_head.py
Normal file
@@ -0,0 +1,679 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# Inspired by https://github.com/DepthAnything/Depth-Anything-V2
|
||||
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .head_act import activate_head
|
||||
from .utils import create_uv_grid, position_grid_to_embed
|
||||
|
||||
|
||||
class DPTHead(nn.Module):
    """
    DPT Head for dense prediction tasks.

    This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
    (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
    backbone and produces dense predictions by fusing multi-scale features.

    Args:
        dim_in (int): Input dimension (channels).
        patch_size (int, optional): Patch size. Default is 14.
        output_dim (int, optional): Number of output channels. Default is 4.
        activation (str, optional): Activation type. Default is "inv_log".
        conf_activation (str, optional): Confidence activation type. Default is "expp1".
        features (int, optional): Feature channels for intermediate representations. Default is 256.
        out_channels (List[int], optional): Output channels for each intermediate layer.
        intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
        pos_embed (bool, optional): Whether to use positional embedding. Default is True.
        feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
        down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "inv_log",
        conf_activation: str = "expp1",
        features: int = 256,
        # NOTE(review): mutable list defaults are shared across instances; they are
        # only read here, so this is harmless, but a tuple default would be safer.
        out_channels: List[int] = [256, 512, 1024, 1024],
        intermediate_layer_idx: List[int] = [0, 1, 2, 3],
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
    ) -> None:
        super(DPTHead, self).__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio
        self.intermediate_layer_idx = intermediate_layer_idx

        # Normalizes token features before they are reshaped into 2D maps.
        self.norm = nn.LayerNorm(dim_in)

        # Projection layers for each output channel from tokens (1x1 convs, one per tapped layer).
        self.projects = nn.ModuleList(
            [nn.Conv2d(in_channels=dim_in, out_channels=oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
        )

        # Resize layers for upsampling feature maps: 4x up, 2x up, identity, 2x down —
        # building the multi-scale pyramid the fusion blocks expect.
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )

        self.scratch = _make_scratch(out_channels, features, expand=False)

        # Attach additional modules to scratch.
        self.scratch.stem_transpose = None
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        # The deepest fusion block receives no lateral input, hence no residual branch.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        head_features_1 = features
        head_features_2 = 32

        if feature_only:
            # Feature mode: keep the channel count; no prediction head is built.
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
        else:
            self.scratch.output_conv1 = nn.Conv2d(
                head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
            )
            conv2_in_channels = head_features_1 // 2

            # Final prediction head producing `output_dim` channels.
            self.scratch.output_conv2 = nn.Sequential(
                nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
            )

    def forward(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_chunk_size: int = 8,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the DPT head, supports processing by chunking frames.
        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
            patch_start_idx (int): Starting index for patch tokens in the token sequence.
                Used to separate patch tokens from other tokens (e.g., camera or register tokens).
            frames_chunk_size (int, optional): Number of frames to process in each chunk.
                If None or larger than S, all frames are processed at once. Default: 8.

        Returns:
            Tensor or Tuple[Tensor, Tensor]:
                - If feature_only=True: Feature maps with shape [B, S, C, H', W']
                - Otherwise: Tuple of (predictions, confidence). `activate_head` emits
                  channels-last maps, so predictions have shape [B, S, H, W, output_dim - 1]
                  and confidence has shape [B, S, H, W].
        """
        B, _, _, H, W = images.shape

        # S (number of frames) comes from the tokens, not from `images`.
        S = aggregated_tokens_list[0].shape[1]

        # If frames_chunk_size is not specified or greater than S, process all frames at once
        if frames_chunk_size is None or frames_chunk_size >= S:
            return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)

        # Otherwise, process frames in chunks to manage memory usage
        assert frames_chunk_size > 0

        # Process frames in batches
        all_preds = []
        all_conf = []

        for frames_start_idx in range(0, S, frames_chunk_size):
            frames_end_idx = min(frames_start_idx + frames_chunk_size, S)

            # Process batch of frames
            if self.feature_only:
                chunk_output = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_output)
            else:
                chunk_preds, chunk_conf = self._forward_impl(
                    aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
                )
                all_preds.append(chunk_preds)
                all_conf.append(chunk_conf)

        # Concatenate results along the sequence dimension
        if self.feature_only:
            return torch.cat(all_preds, dim=1)
        else:
            return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)

    def _forward_impl(
        self,
        aggregated_tokens_list: List[torch.Tensor],
        images: torch.Tensor,
        patch_start_idx: int,
        frames_start_idx: int = None,
        frames_end_idx: int = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Implementation of the forward pass through the DPT head.

        This method processes a specific chunk of frames from the sequence.

        Args:
            aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
            images (Tensor): Input images with shape [B, S, 3, H, W].
            patch_start_idx (int): Starting index for patch tokens.
            frames_start_idx (int, optional): Starting index for frames to process.
            frames_end_idx (int, optional): Ending index for frames to process.

        Returns:
            Tensor or Tuple[Tensor, Tensor]: Feature maps, or channels-last
            (predictions, confidence) — see `forward` for exact shapes.
        """

        B, _, _, H, W = images.shape

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        out = []
        dpt_idx = 0

        for layer_idx in self.intermediate_layer_idx:
            # Drop the non-patch prefix (e.g. camera/register tokens).
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]

            # Restrict to the requested frame chunk, if any.
            if frames_start_idx is not None and frames_end_idx is not None:
                x = x[:, frames_start_idx:frames_end_idx]

            B, S = x.shape[0], x.shape[1]

            # Merge batch and frame dims: (B*S, P, C).
            x = x.reshape(B * S, -1, x.shape[-1])

            x = self.norm(x)

            # Tokens -> 2D feature map: (B*S, C, patch_h, patch_w).
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[dpt_idx](x)
            if self.pos_embed:
                x = self._apply_pos_embed(x, W, H)
            x = self.resize_layers[dpt_idx](x)

            out.append(x)
            dpt_idx += 1

        # Fuse features from multiple layers.
        out = self.scratch_forward(out)
        # Interpolate fused output to match target image resolution.
        out = custom_interpolate(
            out,
            (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
            mode="bilinear",
            align_corners=True,
        )

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out.view(B, S, *out.shape[1:])

        out = self.scratch.output_conv2(out)
        preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)

        # Unfold the merged (B*S) dim back into batch and frame dims.
        preds = preds.view(B, S, *preds.shape[1:])
        conf = conf.view(B, S, *conf.shape[1:])
        return preds, conf

    def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """
        Apply positional embedding to tensor x.

        Builds a sinusoidal embedding over a normalized UV grid matching x's
        spatial size, scales it by `ratio`, and adds it channel-wise.
        """
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]
        pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        # Small ratio keeps the positional signal weak relative to features.
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
        """
        Forward pass through the fusion blocks.

        Fuses top-down: the deepest map is refined first and progressively
        merged with shallower (higher-resolution) maps.

        Args:
            features (List[Tensor]): List of feature maps from different layers.

        Returns:
            Tensor: Fused feature map.
        """
        layer_1, layer_2, layer_3, layer_4 = features

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        # Free intermediates eagerly to reduce peak memory during fusion.
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out
|
||||
|
||||
|
||||
################################################################################
|
||||
# Modules
|
||||
################################################################################
|
||||
|
||||
|
||||
def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
    """Build a FeatureFusionBlock with the DPT decoder defaults.

    Args:
        features: channel count of the fused maps.
        size: optional fixed output size for the block's upsampling step.
        has_residual: whether the block fuses a lateral (skip) input.
        groups: conv groups forwarded to the block.
    """
    options = dict(
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )
    return FeatureFusionBlock(features, nn.ReLU(inplace=True), **options)
|
||||
|
||||
|
||||
def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
    """Create the `scratch` container of per-level 3x3 projection convs.

    With ``expand=True`` the per-level output widths grow as
    ``out_shape * (1, 2, 4, 8)``; otherwise every level projects to
    ``out_shape`` channels. A fourth layer is created only when at least
    four input widths are supplied; at most four levels are ever used.
    """
    scratch = nn.Module()
    multipliers = (1, 2, 4, 8) if expand else (1, 1, 1, 1)
    for level, channels in enumerate(in_shape[:4]):
        conv = nn.Conv2d(
            channels,
            out_shape * multipliers[level],
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            groups=groups,
        )
        setattr(scratch, f"layer{level + 1}_rn", conv)
    return scratch
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
    """Pre-activation residual block: two 3x3 convs plus a skip connection."""

    def __init__(self, features, activation, bn, groups=1):
        """Init.

        Args:
            features (int): number of features (input channels == output channels)
            activation: activation module applied before each conv
            bn: stored for API compatibility; norm layers start disabled (None)
            groups (int): conv groups
        """
        super().__init__()

        self.bn = bn
        self.groups = groups
        conv_opts = dict(kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
        self.conv1 = nn.Conv2d(features, features, **conv_opts)
        self.conv2 = nn.Conv2d(features, features, **conv_opts)

        # Norms are disabled by default; kept as attributes so callers may install them.
        self.norm1 = None
        self.norm2 = None

        self.activation = activation
        # Quantization-friendly elementwise add.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Return ``x + conv2(act(conv1(act(x))))``, applying norms when set.

        Args:
            x (tensor): input

        Returns:
            tensor: output with the same shape as ``x``
        """
        y = self.conv1(self.activation(x))
        if self.norm1 is not None:
            y = self.norm1(y)

        y = self.conv2(self.activation(y))
        if self.norm2 is not None:
            y = self.norm2(y)

        return self.skip_add.add(y, x)
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
    """Fuse an upsampled decoder path with an optional lateral (skip) feature map."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
    ):
        """Init.

        Args:
            features (int): number of features
            activation: activation module for the residual conv units
            has_residual: when False the block takes a single input (no lateral fusion)
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups
        self.expand = expand

        out_features = features // 2 if self.expand == True else features

        self.out_conv = nn.Conv2d(
            features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
        )

        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)

        # Quantization-friendly elementwise add.
        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Fuse inputs, upsample, and project.

        ``xs[0]`` is the incoming decoder path; ``xs[1]`` (when residual fusion
        is enabled) is the lateral feature map.

        Returns:
            tensor: output
        """
        fused = xs[0]

        if self.has_residual:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        # Target resolution priority: call-site `size` > constructor size > 2x upsample.
        if size is not None:
            resize_kwargs = {"size": size}
        elif self.size is not None:
            resize_kwargs = {"size": self.size}
        else:
            resize_kwargs = {"scale_factor": 2}

        fused = custom_interpolate(fused, **resize_kwargs, mode="bilinear", align_corners=self.align_corners)
        return self.out_conv(fused)
|
||||
|
||||
|
||||
def custom_interpolate(
    x: torch.Tensor,
    size: Tuple[int, int] = None,
    scale_factor: float = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """Resize that sidesteps the INT_MAX element limit of ``nn.functional.interpolate``.

    When the output would exceed the limit, the batch is split into chunks
    that are interpolated independently and concatenated back together.
    Either ``size`` or ``scale_factor`` must be provided.
    """
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736

    # Elements in the interpolated output: H_out * W_out * B * C.
    input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if input_elements <= INT_MAX:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    # Too large for a single call: split along the batch dimension.
    n_chunks = (input_elements // INT_MAX) + 1
    pieces = [
        nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners)
        for chunk in torch.chunk(x, chunks=n_chunks, dim=0)
    ]
    return torch.cat(pieces, dim=0).contiguous()
|
||||
|
||||
class DPTHead_Update(nn.Module):
    """DPT-style fusion head variant (SLAM flavor).

    Projects four transformer feature maps to a multi-scale pyramid, fuses
    them top-down with `FeatureFusionBlock_slam` refinenets, and produces a
    single-channel non-negative map (final conv + ReLU). With
    ``return_intermediate=True`` the forward pass instead returns the
    pre-head feature map together with the four fusion-path activations.
    """

    def __init__(
        self,
        in_channels,
        features=256,
        use_bn=False,
        # NOTE(review): mutable list default — only read here, so harmless.
        out_channels=[256, 512, 1024, 1024],
        use_clstoken=False
    ):
        super(DPTHead_Update, self).__init__()

        self.use_clstoken = use_clstoken

        # 1x1 projections from the backbone width to per-level channel counts.
        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in out_channels
        ])

        # Per-level rescaling: 4x up, 2x up, identity, 2x down.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])

        if use_clstoken:
            # Mixes the class token back into every patch token before projection.
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))

        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        self.scratch.refinenet1 = _make_fusion_block_slam(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block_slam(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block_slam(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block_slam(features, use_bn)

        head_features_1 = features
        head_features_2 = 32

        self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
        # Single-channel head; trailing ReLU clamps the output to >= 0.
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True),
            nn.Identity(),
        )

    def forward(self, out_features, patch_h, patch_w, return_intermediate=True):
        """Fuse backbone features into a dense map.

        Args:
            out_features: per-level token tensors (each (B, P, C); with
                ``use_clstoken`` each entry is a (tokens, cls_token) pair).
            patch_h, patch_w: token-grid height and width.
            return_intermediate: when True, return (features, path_1..path_4)
                instead of applying the final prediction head.
        """
        out = []
        for i, x in enumerate(out_features):
            if self.use_clstoken:
                x, cls_token = x[0], x[1]
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))

            # Tokens -> 2D map: (B, C, patch_h, patch_w).
            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            x = self.projects[i](x)
            x = self.resize_layers[i](x)

            out.append(x)

        layer_1, layer_2, layer_3, layer_4 = out

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down fusion: deepest level first, merging in shallower levels.
        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
        out = self.scratch.output_conv1(path_1)
        # Upsample to pixel resolution; the patch size of 14 is hard-coded here.
        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        if return_intermediate:
            return out, path_1, path_2, path_3, path_4
        else:
            out = self.scratch.output_conv2(out)
            return out
|
||||
|
||||
def _make_fusion_block_slam(features, use_bn, size=None):
    """Build a FeatureFusionBlock_slam with the SLAM-head defaults
    (non-inplace ReLU, no deconv/expand, align_corners=True)."""
    options = {
        "deconv": False,
        "bn": use_bn,
        "expand": False,
        "align_corners": True,
        "size": size,
    }
    return FeatureFusionBlock_slam(features, nn.ReLU(False), **options)
|
||||
|
||||
|
||||
class FeatureFusionBlock_slam(nn.Module):
    """Feature fusion block used by the SLAM-style head.

    Optionally fuses a second input via a residual conv unit, refines the
    result, upsamples, and projects with a 1x1 conv.
    """

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None
    ):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_slam, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features // 2 if self.expand == True else features

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn)

        # Quantization-friendly elementwise add.
        self.skip_add = nn.quantized.FloatFunctional()

        self.size = size

    def forward(self, *xs, size=None):
        """Fuse one or two feature maps, upsample, and project.

        Returns:
            tensor: output
        """
        fused = xs[0]

        # A second positional input is fused in through the first residual unit.
        if len(xs) == 2:
            fused = self.skip_add.add(fused, self.resConfUnit1(xs[1]))

        fused = self.resConfUnit2(fused)

        # Target resolution priority: call-site `size` > constructor size > 2x upsample.
        if size is not None:
            resize_kwargs = {"size": size}
        elif self.size is not None:
            resize_kwargs = {"size": self.size}
        else:
            resize_kwargs = {"scale_factor": 2}

        fused = nn.functional.interpolate(fused, **resize_kwargs, mode="bilinear", align_corners=self.align_corners)

        return self.out_conv(fused)
|
||||
125
lingbot_map/heads/head_act.py
Normal file
125
lingbot_map/heads/head_act.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    """Apply per-component activations to an encoded camera pose.

    The last dimension is split 3 / 4 / rest: the first 3 channels are the
    translation, the next 4 the quaternion, and the remainder the focal
    length (or FoV).

    Args:
        pred_pose_enc: tensor of encoded pose parameters
        trans_act: activation name for the translation slice
        quat_act: activation name for the quaternion slice
        fl_act: activation name for the focal-length slice

    Returns:
        Tensor of the same shape with each slice activated.
    """
    components = (
        (pred_pose_enc[..., :3], trans_act),   # translation
        (pred_pose_enc[..., 3:7], quat_act),   # quaternion
        (pred_pose_enc[..., 7:], fl_act),      # focal length / fov
    )
    activated = [base_pose_act(part, act) for part, act in components]
    return torch.cat(activated, dim=-1)
|
||||
|
||||
|
||||
def base_pose_act(pose_enc, act_type="linear"):
    """Apply a named element-wise activation to pose parameters.

    Args:
        pose_enc: tensor of encoded pose parameters
        act_type: one of "linear", "inv_log", "exp", "relu"

    Returns:
        Activated tensor.

    Raises:
        ValueError: if ``act_type`` is not recognized.
    """
    # Lambdas defer name lookup so unused activations are never touched.
    dispatch = {
        "linear": lambda t: t,
        "inv_log": lambda t: inverse_log_transform(t),
        "exp": lambda t: torch.exp(t),
        "relu": lambda t: F.relu(t),
    }
    act = dispatch.get(act_type)
    if act is None:
        raise ValueError(f"Unknown act_type: {act_type}")
    return act(pose_enc)
|
||||
|
||||
|
||||
def activate_head(out, activation="norm_exp", conf_activation="expp1"):
    """Convert raw network output into 3D points and per-pixel confidence.

    Args:
        out: network output of shape (B, C, H, W); the first C-1 channels
            encode the point prediction, the last channel the confidence.
        activation: activation applied to the point channels
        conf_activation: activation applied to the confidence channel

    Returns:
        Tuple (points, confidence): points is (B, H, W, C-1), confidence is (B, H, W).

    Raises:
        ValueError: on an unrecognized activation name.
    """
    # Channels-last layout for per-pixel processing.
    fmap = out.permute(0, 2, 3, 1)
    xyz, conf = fmap[..., :-1], fmap[..., -1]

    if activation == "norm_exp":
        # Keep the direction, expm1-stretch the radial length.
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        pts3d = (xyz / d) * torch.expm1(d)
    elif activation == "norm":
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif activation == "xy_inv_log":
        # Depth channel is inv-log activated, then x/y are scaled by that depth.
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        conf_out = conf.exp() + 1
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out
|
||||
|
||||
|
||||
def inverse_log_transform(y):
    """Invert a signed log transform: sign(y) * (exp(|y|) - 1).

    Uses ``expm1`` for numerical accuracy near zero.

    Args:
        y: input tensor

    Returns:
        Transformed tensor of the same shape.
    """
    magnitude = torch.expm1(y.abs())
    return y.sign() * magnitude
|
||||
109
lingbot_map/heads/utils.py
Normal file
109
lingbot_map/heads/utils.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """Turn an (H, W, 2) coordinate grid into (H, W, embed_dim) sinusoidal embeddings.

    Each of the two coordinates gets half of ``embed_dim`` channels via
    ``make_sincos_pos_embed``; the two halves are concatenated channel-wise.

    Args:
        pos_grid: tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: output channel dimension for embeddings
        omega_0: base frequency forwarded to the sin/cos embedding

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings.
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    flat = pos_grid.reshape(-1, grid_dim)  # (H*W, 2)

    half = embed_dim // 2
    # Embed x and y independently, then concatenate along channels.
    emb = torch.cat(
        [
            make_sincos_pos_embed(half, flat[:, 0], omega_0=omega_0),
            make_sincos_pos_embed(half, flat[:, 1], omega_0=omega_0),
        ],
        dim=-1,
    )
    return emb.view(H, W, embed_dim)
|
||||
|
||||
|
||||
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """Generate a 1D sine/cosine positional embedding.

    Args:
        embed_dim: embedding dimension (must be even); half sines, half cosines.
        pos: positions to embed; flattened to shape (M,).
        omega_0: base of the geometric frequency progression.

    Returns:
        Float tensor of shape (M, embed_dim), laid out as [sin | cos].
    """
    assert embed_dim % 2 == 0
    device = pos.device
    # Double precision where available; the MPS backend lacks float64.
    freq_dtype = torch.float32 if device.type == "mps" else torch.double
    omega = torch.arange(embed_dim // 2, dtype=freq_dtype, device=device)
    omega /= embed_dim / 2.0
    omega = 1.0 / omega_0**omega  # (D/2,) geometric frequencies

    angles = torch.einsum("m,d->md", pos.reshape(-1), omega)  # (M, D/2) outer product

    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (M, D)
    return emb.float()
|
||||
|
||||
|
||||
# Inspired by https://github.com/microsoft/moge
|
||||
|
||||
|
||||
def create_uv_grid(
    width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (height, width, 2).

    The grid spans horizontally and vertically according to an aspect ratio,
    ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
    corner is at (x_span, y_span), normalized by the diagonal of the plane.
    The (n - 1) / n shrink factor places samples at pixel centers.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (height, width, 2) tensor of UV coordinates, with
        uv[i, j] == (x_coords[j], y_coords[i]) because of ``indexing="xy"``.
    """
    # Derive aspect ratio if not explicitly provided
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Compute normalized spans so the half-diagonal of the plane has unit length
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Establish the linspace boundaries (pixel-center sampling)
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    # Generate 1D coordinates
    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # "xy" (Cartesian) indexing yields (height, width) grids:
    # rows vary with y, columns vary with x.
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)

    return uv_grid
|
||||
5
lingbot_map/layers/__init__.py
Normal file
5
lingbot_map/layers/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from lingbot_map.layers.mlp import Mlp
|
||||
from lingbot_map.layers.patch_embed import PatchEmbed
|
||||
from lingbot_map.layers.block import Block
|
||||
from lingbot_map.layers.swiglu_ffn import SwiGLUFFN as SwiGLUFFNFused
|
||||
from lingbot_map.layers.attention import Attention as MemEffAttention
|
||||
766
lingbot_map/layers/attention.py
Normal file
766
lingbot_map/layers/attention.py
Normal file
@@ -0,0 +1,766 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
import logging
|
||||
import os
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from lingbot_map.layers.rope import apply_rotary_emb
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
# FlashInfer imports (optional - for paged attention)
|
||||
try:
|
||||
import flashinfer
|
||||
FLASHINFER_AVAILABLE = True
|
||||
except ImportError:
|
||||
FLASHINFER_AVAILABLE = False
|
||||
print("flashinfer not available")
|
||||
|
||||
try:
|
||||
from torchtitan.distributed.sequence_parallel import (
|
||||
gather_seq_scatter_heads,
|
||||
gather_heads_scatter_seq,
|
||||
pad_tensor,
|
||||
slice_input_tensor_scale_grad,
|
||||
gather_outputs,
|
||||
)
|
||||
except ImportError:
|
||||
print("torchtitan not available for ulysses cp")
|
||||
|
||||
def gather_seq_scatter_heads_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_dim: int, head_dim: int):
    """Gather the sequence dim and scatter the head dim for Q, K, V (Ulysses CP).

    NOTE(review): relies on ``gather_seq_scatter_heads`` from torchtitan, whose
    import above is guarded by try/except — calling this without torchtitan
    installed raises NameError.
    """
    return tuple(gather_seq_scatter_heads(t, seq_dim, head_dim) for t in (q, k, v))
|
||||
|
||||
from typing_extensions import List
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Multi-head self-attention with optional QK-norm, RoPE, and Ulysses
    context parallelism (sequence/head re-sharding around the attention op)."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        """
        Args:
            dim: Total embedding dimension; must be divisible by num_heads.
            num_heads: Number of attention heads.
            qkv_bias: Add bias to the fused QKV projection.
            proj_bias: Add bias to the output projection.
            attn_drop: Dropout probability on attention weights.
            proj_drop: Dropout probability after the output projection.
            norm_layer: Normalization applied per-head to q and k when qk_norm.
            qk_norm: Enable per-head normalization of q and k.
            fused_attn: Use F.scaled_dot_product_attention instead of the
                explicit softmax path.
            rope: Optional rotary-embedding callable: rope(tensor, pos).
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False, num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
        """Self-attention over x: [B, N, C] -> [B, N, C].

        NOTE(review): num_patches, num_special, num_frames and enable_3d_rope
        are accepted here (matching subclass signatures) but unused in this
        base forward — confirm they are only consumed by subclasses.
        """
        B, N, C = x.shape
        # [3, B, num_heads, N, head_dim] -> q, k, v
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if enable_ulysses_cp:
            # Re-shard: full sequence on each rank, heads split across ranks.
            q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.0)
        else:
            # Explicit attention: scale q, softmax over keys, dropout, weight v.
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        if enable_ulysses_cp:
            # Undo the re-shard: heads gathered back, sequence re-scattered.
            x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)

        # [B, H, N, D] -> [B, N, H*D]; -1 handles the CP-resharded length.
        x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
||||
|
||||
class CausalAttention(nn.Module):
    """
    Causal self-attention module with KV cache support for streaming inference.
    Used by CasualBlockCamera in camera_head.py.

    Two execution modes:
      * Batch mode (kv_cache is None): full-sequence attention under an
        explicit boolean mask built from block_mask / video_mask, optionally
        restricted to a sliding window of frames.
      * Streaming mode (kv_cache is a dict): current K/V are appended to the
        cache, a sliding-window eviction is applied, and attention runs over
        [special tokens from evicted frames + cached + current]; causality is
        enforced by the cache contents, not by a mask.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
        elementwise_attn_output_gate=False,
        # KV cache eviction parameters (matching build_attn_mask)
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,  # If True, only cache camera token (no scale token)
    ) -> None:
        """
        Args:
            dim: Embedding dimension; must be divisible by num_heads.
            num_heads: Number of attention heads.
            qkv_bias / proj_bias: Bias flags for QKV and output projections.
            attn_drop / proj_drop: Dropout probabilities.
            norm_layer: Per-head norm for q/k when qk_norm is True.
            qk_norm: Enable per-head q/k normalization.
            fused_attn: Use F.scaled_dot_product_attention.
            rope: Optional rotary-embedding callable.
            elementwise_attn_output_gate: Add a sigmoid gate (from a linear
                projection of the input) applied elementwise to the attention
                output.
            kv_cache_sliding_window: Number of most-recent frames kept in cache.
            kv_cache_scale_frames: Number of leading "scale" frames kept
                permanently at the front of the cache.
            kv_cache_cross_frame_special: Keep special tokens of evicted frames
                in a side cache so later frames can still attend to them.
            kv_cache_include_scale_frames: Keep the leading scale frames when
                evicting (otherwise only the sliding window survives).
            kv_cache_camera_only: Preserve only the camera token (not the full
                camera+register+scale range) from evicted frames.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

        # Optional elementwise output gate (sigmoid of a linear projection of x).
        self.gate_proj = nn.Linear(dim, dim, bias=True) if elementwise_attn_output_gate else None

        # Store KV cache eviction parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only

    def forward(self, x: Tensor, block_mask=None, pos=None, pos_kv=None, frame_seqlen=None, video_mask=None, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=1, num_frame_for_scale=-1, enable_3d_rope=False, sliding_window_size=-1, attend_to_scale_frames=False, num_random_frames=0, attend_to_special_tokens=False, num_register_tokens=4, enable_ulysses_cp=False, is_scale_frames=False) -> Tensor:
        """Causal attention over x: [B, N, C] -> [B, N, C].

        NOTE(review): several parameters (pos_kv, current_start, current_end,
        attend_to_scale_frames, num_random_frames, attend_to_special_tokens,
        is_scale_frames) are not referenced in this body — confirm they exist
        only for signature compatibility.
        NOTE(review): on the kv_cache path there is no non-fused fallback, so
        fused_attn=False with a cache would leave x unassigned (NameError).
        """
        B, N, C = x.shape

        # Calculate special token indices
        camera_token_idx = 0
        scale_token_idx = camera_token_idx + num_register_tokens + 1  # camera + register tokens + scale

        # [3, B, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.gate_proj is not None:
            # Gate computed from the pre-attention input: [B, H, N, head_dim].
            gate_score = self.gate_proj(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        if kv_cache is None:
            # ---------- Batch mode: masked full-sequence attention ----------
            q, k = self.q_norm(q), self.k_norm(k)
            if enable_ulysses_cp:
                q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)
                N = q.shape[2]  # Update N after gather
            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif enable_3d_rope and pos is not None:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)

            # Mask construction is pure bookkeeping; no gradients needed.
            with torch.no_grad():
                block_mask = block_mask.squeeze()[:q.shape[2], :k.shape[2]]
                if block_mask.dim() == 2:
                    block_mask = block_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, N, N]
                block_mask = block_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])

                video_mask = video_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) if video_mask is not None else torch.ones_like(block_mask, device=block_mask.device)  # [1, 1, N, N]
                video_mask = video_mask.expand(B, 1, block_mask.shape[-2], block_mask.shape[-1])

                # NOTE(review): allows attention where block_mask is True OR
                # video_mask is False — confirm this mask polarity is intended.
                mask = block_mask | ~video_mask

                # Apply sliding window mask if sliding_window_size > 0
                # sliding_window_size is in units of num_frame_per_block
                if sliding_window_size > 0 and frame_seqlen is not None:
                    # Create sliding window mask: each frame can only attend to frames within the window
                    num_frames = N // frame_seqlen
                    sliding_mask = torch.zeros_like(mask, dtype=torch.bool)

                    for i in range(num_frames):
                        q_start = i * frame_seqlen
                        q_end = (i + 1) * frame_seqlen
                        # Calculate the window start: sliding_window_size is in units of num_frame_per_block
                        # So the actual window size in frames is sliding_window_size * num_frame_per_block
                        window_size_in_frames = sliding_window_size * num_frame_per_block
                        window_start_frame = max(0, i - window_size_in_frames + 1)
                        k_start = window_start_frame * frame_seqlen
                        k_end = (i + 1) * frame_seqlen  # Can attend up to current frame (causal)
                        sliding_mask[:, :, q_start:q_end, k_start:k_end] = True

                    # Combine with existing mask: both masks need to allow attention
                    mask = mask & sliding_mask

                    # If attend_to_scale_frames is True, also allow attention to first num_frame_for_scale frames
                    if num_frame_for_scale > 0:
                        for i in range(num_frames):
                            q_start = i * frame_seqlen
                            q_end = (i + 1) * frame_seqlen
                            # Allow attending to first num_frame_for_scale frames (directly set to True, not depending on block_mask)
                            mask[:, :, q_start:q_end, :num_frame_for_scale * frame_seqlen] = True

                    ## global attention for the first num_frame_for_scale frames
                    if num_frame_for_scale > 0:
                        mask[:, :, :num_frame_for_scale * frame_seqlen, :num_frame_for_scale * frame_seqlen] = True

            if self.fused_attn:
                x = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                    attn_mask=mask
                )
        else:
            # ---------- Streaming mode: dict-based KV cache ----------
            # Apply RoPE to current k before caching
            q, k = self.q_norm(q), self.k_norm(k)

            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif enable_3d_rope and pos is not None:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)

            # Check if we should skip appending to cache (non-keyframe in keyframe mode)
            skip_append = kv_cache.get("_skip_append", False)

            # Split the token axis into (frames, tokens-per-frame) for caching.
            k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
            v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)

            if not skip_append:
                # KEYFRAME: store in cache (original behavior)
                if kv_cache[f"k_{global_idx}"] is None:
                    kv_cache[f"k_{global_idx}"] = k_reshaped
                    kv_cache[f"v_{global_idx}"] = v_reshaped
                else:
                    # Re-derive frames-per-block from the cached tokens-per-frame
                    # (may differ from the argument on later calls).
                    num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
                    k_reshaped = k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
                    v_reshaped = v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
                    kv_cache[f"k_{global_idx}"] = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
                    kv_cache[f"v_{global_idx}"] = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)

                # Apply sliding window eviction BEFORE attention to match causal_3drope behavior
                # This ensures current frame only attends to frames within the sliding window
                self._apply_kv_cache_eviction_causal(kv_cache, global_idx, camera_token_idx, scale_token_idx)

                # Retrieve full k, v from cache (already RoPE-applied, already evicted)
                k = kv_cache[f"k_{global_idx}"].clone()
                v = kv_cache[f"v_{global_idx}"].clone()
            else:
                # NON-KEYFRAME: attend to [cached + current] without storing in cache
                if kv_cache[f"k_{global_idx}"] is not None:
                    k = torch.cat((kv_cache[f"k_{global_idx}"], k_reshaped), dim=2)
                    v = torch.cat((kv_cache[f"v_{global_idx}"], v_reshaped), dim=2)
                else:
                    k = k_reshaped
                    v = v_reshaped
            # Flatten (frames, tokens-per-frame) back into one token axis.
            a, b, c, d, e = k.shape

            k = k.reshape(a, b, c*d, e)
            v = v.reshape(a, b, c*d, e)

            # Prepend special tokens (camera + scale) from evicted frames if they exist
            if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
                special_k = kv_cache[f"k_{global_idx}_special"]  # [B, H, num_evicted_frames, 2, D]
                special_v = kv_cache[f"v_{global_idx}_special"]
                sa, sb, sc, sd, se = special_k.shape
                special_k = special_k.reshape(sa, sb, sc * sd, se)  # [B, H, num_evicted*2, D]
                special_v = special_v.reshape(sa, sb, sc * sd, se)

                # Prepend special tokens (older frames first)
                k = torch.cat([special_k, k], dim=2)
                v = torch.cat([special_v, v], dim=2)

            # Note: k from cache is already RoPE-applied, no need to apply again

            if self.fused_attn:
                # Use mask-based SDPA to ensure same kernel as batch mode
                # The causal constraint is enforced by KV cache contents, not by mask
                mask = torch.ones(B, 1, q.shape[2], k.shape[2], dtype=torch.bool, device=q.device)
                x = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                    attn_mask=mask,
                )

        if self.gate_proj is not None:
            # NOTE(review): gate_score has the pre-CP-gather sequence length —
            # confirm shapes line up when enable_ulysses_cp is True.
            x = x * torch.sigmoid(gate_score)
        if enable_ulysses_cp:
            x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)
        # Use actual dimensions from attention output, not original input C
        # x shape: [B, H, seq_len, head_dim] -> [B, seq_len, H*head_dim]
        x = x.transpose(1, 2).reshape(B, -1, self.num_heads * self.head_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _apply_kv_cache_eviction_causal(self, kv_cache, global_idx, camera_token_idx, scale_token_idx):
        """
        Apply sliding window eviction to KV cache BEFORE attention.

        This ensures current frame only attends to frames within the sliding window,
        matching the behavior of causal_3drope's attention mask.

        Mutates kv_cache in place:
          * drops frames between the leading scale frames and the sliding
            window of most-recent frames;
          * optionally preserves the special tokens (camera / camera+register+
            scale) of the dropped frames in a side cache keyed
            "k_{global_idx}_special" / "v_{global_idx}_special".
        """
        sliding_window_frames = self.kv_cache_sliding_window
        scale_frames = self.kv_cache_scale_frames

        # shape[3] is tokens-per-frame; a single-token "frame" is skipped
        # (nothing beyond the special token itself to evict).
        if kv_cache[f"k_{global_idx}"].shape[3] > 1:
            num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]

            if num_cached_frames > sliding_window_frames + scale_frames:
                # Frames [scale_frames, num_cached - window) fall out of the window.
                evict_start = scale_frames
                evict_end = num_cached_frames - sliding_window_frames

                if evict_end > evict_start:
                    evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
                    evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]

                    if self.kv_cache_cross_frame_special:
                        if self.kv_cache_camera_only:
                            # Only keep camera token
                            new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
                        else:
                            # Keep ALL special tokens (camera + register + scale) to match attention_mask behavior
                            # Special tokens are in range [camera_token_idx, scale_token_idx+1)
                            new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
                            new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()

                        if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
                            kv_cache[f"k_{global_idx}_special"] = new_special_k
                            kv_cache[f"v_{global_idx}_special"] = new_special_v
                        else:
                            kv_cache[f"k_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
                            kv_cache[f"v_{global_idx}_special"] = torch.cat(
                                [kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)

                    if self.kv_cache_include_scale_frames:
                        # Keep [leading scale frames] + [most recent window frames].
                        kv_cache[f"k_{global_idx}"] = torch.cat([
                            kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
                            kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                        ], dim=2)
                        kv_cache[f"v_{global_idx}"] = torch.cat([
                            kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
                            kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                        ], dim=2)
                    else:
                        kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
                        kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
||||
|
||||
class FlashInferAttention(Attention):
    """
    FlashInfer variant of the GCT attention layer.
    Uses FlashInferKVCacheManager for paged KV cache storage and
    FlashInfer attention kernels (BatchPrefillWithPagedKVCacheWrapper).
    Supports the same optimized token layout and KV cache streaming inference.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,
        rope=None,
        # KV cache eviction parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        """Same construction as Attention plus KV-cache eviction settings.

        Raises:
            RuntimeError: If the flashinfer package could not be imported.
        """
        if not FLASHINFER_AVAILABLE:
            raise RuntimeError("FlashInfer is not available. Please install flashinfer.")

        super().__init__(
            dim=dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        # Store KV cache eviction parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only

    def prepare_qkv(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
        """Fused pre-attention ops for single-frame streaming (Phase 2).

        Computes q/k/v from x, applies q_norm/k_norm/RoPE, and converts to
        [tpf, H, D] format ready for append_frame + compute_attention.

        Extracted as a method so torch.compile can capture all pre-attn ops as one
        CUDA graph (qkv linear -> reshape -> unbind -> q_norm -> k_norm -> RoPE x2 ->
        squeeze/permute/contiguous x3).
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # Each: [B, num_heads, N, head_dim]
        q, k = self.q_norm(q), self.k_norm(k)
        if self.rope is not None and not enable_3d_rope:
            q = self.rope(q, pos)
            k = self.rope(k, pos)
        elif self.rope is not None:  # enable_3d_rope=True
            q = apply_rotary_emb(q, pos)
            k = apply_rotary_emb(k, pos)
        # Convert to [tpf, H, D] format for FlashInfer (B=1 in streaming mode)
        q_nhd = q.squeeze(0).permute(1, 0, 2).contiguous()
        k_nhd = k.squeeze(0).permute(1, 0, 2).contiguous()
        v_nhd = v.squeeze(0).permute(1, 0, 2).contiguous()
        return q_nhd, k_nhd, v_nhd

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
                # KV cache parameters (kv_cache is a FlashInferKVCacheManager or None)
                kv_cache=None, global_idx=0, num_frame_per_block=1,
                num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
        """
        Forward pass with FlashInfer paged KV cache and attention.

        Args:
            x: Input tensor [B, N, C]
            kv_cache: FlashInferKVCacheManager instance or None (batch mode)
            global_idx: Block index for per-block cache access

        NOTE(review): batch mode always uses SDPA regardless of self.fused_attn;
        num_frame_per_block and num_frame_for_scale are unused here — confirm
        they exist only for signature compatibility with CausalAttention.
        """
        from lingbot_map.layers.flashinfer_cache import FlashInferKVCacheManager

        B, N, C = x.shape

        # Detect if using optimized layout
        using_optimized_layout = (num_patches is not None and num_special is not None
                                  and num_frames is not None)

        # ========== Batch Mode (no KV cache manager) ==========
        if not isinstance(kv_cache, FlashInferKVCacheManager):
            # [3, B, num_heads, N, head_dim]
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # Each: [B, num_heads, N, head_dim]
            q, k = self.q_norm(q), self.k_norm(k)

            if enable_ulysses_cp:
                if using_optimized_layout:
                    # Optimized layout: patch tokens first, special tokens after;
                    # re-shard the two segments independently so the boundary is
                    # preserved across CP ranks.
                    boundary = num_frames * num_patches
                    q_patch, k_patch, v_patch = q[:, :, :boundary, :], k[:, :, :boundary, :], v[:, :, :boundary, :]
                    q_special, k_special, v_special = q[:, :, boundary:, :], k[:, :, boundary:, :], v[:, :, boundary:, :]
                    q_patch, k_patch, v_patch = gather_seq_scatter_heads_qkv(
                        q_patch, k_patch, v_patch, seq_dim=2, head_dim=1
                    )
                    q_special, k_special, v_special = gather_seq_scatter_heads_qkv(
                        q_special, k_special, v_special, seq_dim=2, head_dim=1
                    )
                    q = torch.cat([q_patch, q_special], dim=2)
                    k = torch.cat([k_patch, k_special], dim=2)
                    v = torch.cat([v_patch, v_special], dim=2)
                else:
                    q, k, v = gather_seq_scatter_heads_qkv(q, k, v, seq_dim=2, head_dim=1)

            if self.rope is not None and not enable_3d_rope:
                q = self.rope(q, pos)
                k = self.rope(k, pos)
            elif self.rope is not None and enable_3d_rope:
                q = apply_rotary_emb(q, pos)
                k = apply_rotary_emb(k, pos)

            # Batch mode: use SDPA for numerical consistency with SDPA variant
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
            )

            if enable_ulysses_cp:
                if using_optimized_layout:
                    # Recover the patch/special boundary in the CP-gathered
                    # sequence before scattering back.
                    seq_global = x.shape[2]
                    seq_local = num_frames * (num_patches + num_special)
                    cp_size = seq_global // seq_local
                    boundary_global = num_frames * cp_size * num_patches
                    x_patch = x[:, :, :boundary_global, :]
                    x_special = x[:, :, boundary_global:, :]
                    x_patch = gather_heads_scatter_seq(x_patch, seq_dim=2, head_dim=1)
                    x_special = gather_heads_scatter_seq(x_special, seq_dim=2, head_dim=1)
                    x = torch.cat([x_patch, x_special], dim=2)
                else:
                    x = gather_heads_scatter_seq(x, seq_dim=2, head_dim=1)

            x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)

        # ========== Streaming Mode (with FlashInferKVCacheManager) ==========
        else:
            manager = kv_cache  # FlashInferKVCacheManager

            # Phase 1 (scale frames): num_frames > 1 — multi-frame batch
            # Phase 2 (streaming): num_frames == 1 — single frame
            is_multi_frame = (num_frames is not None and num_frames > 1)

            if is_multi_frame:
                # Phase 1: compute full self-attention via SDPA (all frames attend to each other),
                # then append each frame's K/V to the paged cache one at a time.
                qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
                q, k, v = qkv.unbind(0)
                q, k = self.q_norm(q), self.k_norm(k)

                # Apply RoPE before caching (RoPE baked into K before append)
                if self.rope is not None and not enable_3d_rope:
                    q = self.rope(q, pos)
                    k = self.rope(k, pos)
                elif self.rope is not None and enable_3d_rope:
                    q = apply_rotary_emb(q, pos)
                    k = apply_rotary_emb(k, pos)

                x = F.scaled_dot_product_attention(
                    q, k, v,
                    dropout_p=self.attn_drop.p if self.training else 0.0,
                )
                x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)

                # Append each frame's K/V to the paged cache individually.
                tpf = manager.tokens_per_frame
                k_all = k.squeeze(0).permute(1, 0, 2)  # [num_frames*tpf, H, D]
                v_all = v.squeeze(0).permute(1, 0, 2)
                for f_idx in range(num_frames):
                    s = f_idx * tpf
                    manager.append_frame(global_idx, k_all[s:s+tpf].contiguous(), v_all[s:s+tpf].contiguous())
                manager.evict_frames(
                    block_idx=global_idx,
                    scale_frames=self.kv_cache_scale_frames,
                    sliding_window=self.kv_cache_sliding_window,
                    cross_frame_special=self.kv_cache_cross_frame_special,
                    include_scale_frames=self.kv_cache_include_scale_frames,
                    camera_only=self.kv_cache_camera_only,
                    num_register_tokens=num_register_tokens,
                )
            else:
                # Phase 2: single-frame streaming via FlashInfer paged attention.
                q_nhd, k_nhd, v_nhd = self.prepare_qkv(x, pos=pos, enable_3d_rope=enable_3d_rope)

                # 1. Append to paged cache
                manager.append_frame(global_idx, k_nhd, v_nhd)

                # 2. Apply sliding window eviction
                manager.evict_frames(
                    block_idx=global_idx,
                    scale_frames=self.kv_cache_scale_frames,
                    sliding_window=self.kv_cache_sliding_window,
                    cross_frame_special=self.kv_cache_cross_frame_special,
                    include_scale_frames=self.kv_cache_include_scale_frames,
                    camera_only=self.kv_cache_camera_only,
                    num_register_tokens=num_register_tokens,
                )

                # 3. Compute attention via FlashInfer BatchPrefillWithPagedKVCacheWrapper
                x = manager.compute_attention(global_idx, q_nhd)

                # Convert back: [tpf, H, D] -> [B, tpf, C].
                x = x.reshape(B, q_nhd.shape[0], self.num_heads * self.head_dim)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
||||
|
||||
class SDPAAttention(Attention):
|
||||
"""
|
||||
SDPA variant for streaming inference.
|
||||
Uses F.scaled_dot_product_attention with dict-based KV cache.
|
||||
No FlashInfer dependency required — works on any CUDA GPU.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        """Construct the SDPA streaming-attention variant.

        Same construction as the base Attention layer, plus KV-cache eviction
        settings (mirroring CausalAttention / FlashInferAttention):

        Args:
            kv_cache_sliding_window: Number of most-recent frames retained in
                the cache.
            kv_cache_scale_frames: Number of leading "scale" frames kept at the
                front of the cache.
            kv_cache_cross_frame_special: Preserve special tokens of evicted
                frames in a side cache.
            kv_cache_include_scale_frames: Keep the leading scale frames when
                evicting.
            kv_cache_camera_only: Preserve only the camera token (not the full
                special-token range) from evicted frames.
        """
        super().__init__(
            dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
            attn_drop=attn_drop, proj_drop=proj_drop, norm_layer=norm_layer,
            qk_norm=qk_norm, fused_attn=fused_attn, rope=rope,
        )
        # Store KV cache eviction parameters (used by forward's streaming path).
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
||||
|
||||
def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
|
||||
num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
|
||||
kv_cache=None, global_idx=0, num_frame_per_block=1,
|
||||
num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
|
||||
B, N, C = x.shape
|
||||
using_optimized_layout = (num_patches is not None and num_special is not None
|
||||
and num_frames is not None)
|
||||
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
# ========== Batch Mode (no KV cache) ==========
|
||||
if kv_cache is None:
|
||||
if self.rope is not None and not enable_3d_rope:
|
||||
q = self.rope(q, pos)
|
||||
k = self.rope(k, pos)
|
||||
elif self.rope is not None and enable_3d_rope:
|
||||
q = apply_rotary_emb(q, pos)
|
||||
k = apply_rotary_emb(k, pos)
|
||||
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
)
|
||||
x = x.transpose(1, 2).reshape(B, N, self.num_heads * self.head_dim)
|
||||
|
||||
# ========== Streaming Mode (with KV cache dict) ==========
|
||||
else:
|
||||
if self.rope is not None and not enable_3d_rope:
|
||||
q = self.rope(q, pos)
|
||||
k = self.rope(k, pos)
|
||||
elif self.rope is not None and enable_3d_rope:
|
||||
q = apply_rotary_emb(q, pos)
|
||||
k = apply_rotary_emb(k, pos)
|
||||
|
||||
camera_token_idx = 0
|
||||
scale_token_idx = camera_token_idx + num_register_tokens + 1
|
||||
|
||||
if kv_cache[f"k_{global_idx}"] is None:
|
||||
kv_cache[f"k_{global_idx}"] = k.view(B, self.num_heads, num_frame_per_block,
|
||||
N // num_frame_per_block, self.head_dim)
|
||||
kv_cache[f"v_{global_idx}"] = v.view(B, self.num_heads, num_frame_per_block,
|
||||
N // num_frame_per_block, self.head_dim)
|
||||
else:
|
||||
num_frame_per_block = k.shape[2] // kv_cache[f"k_{global_idx}"].shape[3]
|
||||
kv_cache[f"k_{global_idx}"] = torch.cat((
|
||||
kv_cache[f"k_{global_idx}"],
|
||||
k.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
||||
), dim=2)
|
||||
kv_cache[f"v_{global_idx}"] = torch.cat((
|
||||
kv_cache[f"v_{global_idx}"],
|
||||
v.view(B, self.num_heads, num_frame_per_block, N // num_frame_per_block, self.head_dim)
|
||||
), dim=2)
|
||||
|
||||
self._apply_kv_cache_eviction(
|
||||
kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens
|
||||
)
|
||||
|
||||
k_cached = kv_cache[f"k_{global_idx}"].clone()
|
||||
v_cached = kv_cache[f"v_{global_idx}"].clone()
|
||||
a, b, c, d, e = k_cached.shape
|
||||
k_full = k_cached.reshape(a, b, c * d, e)
|
||||
v_full = v_cached.reshape(a, b, c * d, e)
|
||||
|
||||
if f"k_{global_idx}_special" in kv_cache and kv_cache[f"k_{global_idx}_special"] is not None:
|
||||
special_k = kv_cache[f"k_{global_idx}_special"]
|
||||
special_v = kv_cache[f"v_{global_idx}_special"]
|
||||
sa, sb, sc, sd, se = special_k.shape
|
||||
k_full = torch.cat([special_k.reshape(sa, sb, sc * sd, se), k_full], dim=2)
|
||||
v_full = torch.cat([special_v.reshape(sa, sb, sc * sd, se), v_full], dim=2)
|
||||
|
||||
q_seq_len = q.shape[2]
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k_full, v_full,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
)
|
||||
x = x.transpose(1, 2).reshape(B, q_seq_len, self.num_heads * self.head_dim)
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
def _apply_kv_cache_eviction(self, kv_cache, global_idx, camera_token_idx, scale_token_idx, num_register_tokens):
|
||||
"""Apply sliding window eviction to KV cache."""
|
||||
sliding_window_frames = self.kv_cache_sliding_window
|
||||
scale_frames = self.kv_cache_scale_frames
|
||||
|
||||
if kv_cache[f"k_{global_idx}"].shape[3] > 1:
|
||||
num_cached_frames = kv_cache[f"k_{global_idx}"].shape[2]
|
||||
if num_cached_frames > sliding_window_frames + scale_frames:
|
||||
evict_start = scale_frames
|
||||
evict_end = num_cached_frames - sliding_window_frames
|
||||
if evict_end > evict_start:
|
||||
evicted_k = kv_cache[f"k_{global_idx}"][:, :, evict_start:evict_end, :, :]
|
||||
evicted_v = kv_cache[f"v_{global_idx}"][:, :, evict_start:evict_end, :, :]
|
||||
|
||||
if self.kv_cache_cross_frame_special:
|
||||
if self.kv_cache_camera_only:
|
||||
new_special_k = evicted_k[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
|
||||
new_special_v = evicted_v[:, :, :, camera_token_idx:camera_token_idx+1, :].clone()
|
||||
else:
|
||||
new_special_k = evicted_k[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
|
||||
new_special_v = evicted_v[:, :, :, camera_token_idx:scale_token_idx+1, :].clone()
|
||||
|
||||
if f"k_{global_idx}_special" not in kv_cache or kv_cache[f"k_{global_idx}_special"] is None:
|
||||
kv_cache[f"k_{global_idx}_special"] = new_special_k
|
||||
kv_cache[f"v_{global_idx}_special"] = new_special_v
|
||||
else:
|
||||
kv_cache[f"k_{global_idx}_special"] = torch.cat(
|
||||
[kv_cache[f"k_{global_idx}_special"], new_special_k], dim=2)
|
||||
kv_cache[f"v_{global_idx}_special"] = torch.cat(
|
||||
[kv_cache[f"v_{global_idx}_special"], new_special_v], dim=2)
|
||||
|
||||
if self.kv_cache_include_scale_frames:
|
||||
kv_cache[f"k_{global_idx}"] = torch.cat([
|
||||
kv_cache[f"k_{global_idx}"][:, :, :scale_frames, :, :],
|
||||
kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
||||
], dim=2)
|
||||
kv_cache[f"v_{global_idx}"] = torch.cat([
|
||||
kv_cache[f"v_{global_idx}"][:, :, :scale_frames, :, :],
|
||||
kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
||||
], dim=2)
|
||||
else:
|
||||
kv_cache[f"k_{global_idx}"] = kv_cache[f"k_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
||||
kv_cache[f"v_{global_idx}"] = kv_cache[f"v_{global_idx}"][:, :, -sliding_window_frames:, :, :]
|
||||
514
lingbot_map/layers/block.py
Normal file
514
lingbot_map/layers/block.py
Normal file
@@ -0,0 +1,514 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, List, Any, Tuple, Dict
|
||||
import warnings
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
from .attention import Attention, CausalAttention, FlashInferAttention, SDPAAttention
|
||||
from functools import lru_cache, partial
|
||||
from torch.nn.attention.flex_attention import BlockMask, create_mask
|
||||
from .drop_path import DropPath
|
||||
from .layer_scale import LayerScale
|
||||
from .mlp import Mlp
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Standard pre-norm Transformer block: attention + MLP with residuals.

    Supports LayerScale (``init_values``), stochastic depth (``drop_path``),
    and — for drop rates > 0.1 during training — a batched stochastic-depth
    path that only runs the residual branch on a surviving subset of samples.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        """Build the block.

        Args:
            dim: Embedding dimension.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden dim of the FFN as a multiple of ``dim``.
            qkv_bias / proj_bias / ffn_bias: Bias flags for the sub-modules.
            drop: Dropout for projection and FFN.
            attn_drop: Attention dropout.
            init_values: LayerScale init value; falsy disables LayerScale.
            drop_path: Stochastic-depth rate.
            act_layer / norm_layer: Activation and norm constructors.
            attn_class: Attention implementation to instantiate.
            ffn_layer: FFN implementation to instantiate.
            qk_norm: Whether attention normalizes Q and K per head.
            fused_attn: Use fused scaled_dot_product_attention in attention.
            rope: Optional rotary position embedding module.
        """
        super().__init__()

        self.norm1 = norm_layer(dim)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            fused_attn=fused_attn,
            rope=rope,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False) -> Tensor:
        """Run attention and FFN residual branches; see class docstring for the
        three training/inference stochastic-depth regimes."""
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
                                      num_patches=num_patches, num_special=num_special, num_frames=num_frames,
                                      enable_3d_rope=enable_3d_rope))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x, pos=pos, residual_func=attn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
            x = drop_add_residual_stochastic_depth(
                x, residual_func=ffn_residual_func, sample_drop_ratio=self.sample_drop_ratio
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            # Fixed former FIXME: the FFN branch now uses its own drop_path2
            # module (previously drop_path1 was reused; both are built
            # identically, so sampling behavior is unchanged).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth(
    x: Tensor, residual_func: Callable[[Tensor], Tensor], sample_drop_ratio: float = 0.0, pos=None
) -> Tensor:
    """Batched stochastic depth: run the residual branch on a random subset
    of the batch and add it back, scaled so the expectation matches running
    the branch on every sample.

    Args:
        x: Input of shape [B, N, D].
        residual_func: Residual branch; called as ``residual_func(subset)`` or,
            when ``pos`` is given, ``residual_func(subset, pos=subset_pos)``.
        sample_drop_ratio: Fraction of the batch to skip (at least one sample
            is always kept).
        pos: Optional per-sample positions, subset-indexed alongside ``x``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    batch, _, _ = x.shape

    # Pick a random surviving subset of the batch (never empty).
    keep_count = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept = torch.randperm(batch, device=x.device)[:keep_count]
    subset = x[kept]

    # Evaluate the residual branch only on the survivors.
    if pos is None:
        delta = residual_func(subset)
    else:
        delta = residual_func(subset, pos=pos[kept])

    # Scatter-add the scaled residual back into the full batch.
    scale = batch / keep_count
    flat = torch.index_add(
        x.flatten(1), 0, kept, delta.flatten(1).to(dtype=x.dtype), alpha=scale
    )
    return flat.view_as(x)
|
||||
|
||||
|
||||
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Sample a random batch subset for stochastic depth.

    Returns:
        (brange, residual_scale_factor): indices of the surviving samples
        (a random permutation prefix, never empty) and the factor that keeps
        the residual's expectation unchanged.
    """
    batch = x.shape[0]
    keep_count = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept = torch.randperm(batch, device=x.device)[:keep_count]
    return kept, batch / keep_count
|
||||
|
||||
|
||||
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add a (subset) residual back into x with stochastic-depth scaling.

    NOTE(review): without ``scaling_vector`` the result stays flattened as
    [B, N*D] (no view_as), unlike drop_add_residual_stochastic_depth — confirm
    callers expect this. ``scaled_index_add`` is not imported in this module's
    visible imports — verify it is provided elsewhere before using the
    scaling_vector branch.
    """
    if scaling_vector is None:
        # Plain scatter-add on the flattened batch rows.
        return torch.index_add(
            x.flatten(1), 0, brange, residual.flatten(1).to(dtype=x.dtype),
            alpha=residual_scale_factor,
        )
    # Fused LayerScale + scatter-add path.
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
|
||||
|
||||
|
||||
class FlashInferBlock(nn.Module):
    """
    FlashInfer variant of causal block for GCT.
    Uses FlashInferAttention (FlashInfer paged KV cache + attention kernels).
    Supports optimized token layout and KV cache streaming inference.

    In streaming mode (kv_cache is a FlashInfer cache manager and at most one
    frame is processed), the attention branch is unrolled inline so that
    norm1 + QKV prep can be captured by torch.compile as one unit while the
    FlashInfer kernel calls stay eager.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        """Build the block; kv_cache_* args are forwarded to FlashInferAttention
        and drive per-step eviction in streaming mode."""
        super().__init__()

        # Attention branch: pre-norm + FlashInfer attention (+ optional LayerScale).
        self.norm1 = norm_layer(dim)
        self.attn = FlashInferAttention(
            dim=dim,
            num_heads=num_heads,
            qk_norm=qk_norm,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # FFN branch: pre-norm + MLP (+ optional LayerScale).
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def attn_pre(self, x: Tensor, pos=None, enable_3d_rope: bool = False) -> tuple:
        """Phase 2 streaming only: norm1 + prepare_qkv fused as one compilable unit.

        Extracted as a named method so torch.compile can capture norm1 + qkv-linear +
        reshape + q_norm + k_norm + RoPE + format as a single CUDA graph.

        Returns:
            (q_nhd, k_nhd, v_nhd) each [tokens_per_frame, num_heads, head_dim],
            ready for manager.append_frame + manager.compute_attention.
        """
        return self.attn.prepare_qkv(self.norm1(x), pos=pos, enable_3d_rope=enable_3d_rope)

    def forward(
        self,
        x: Tensor,
        pos=None,
        enable_ulysses_cp=False,
        num_patches=None,
        num_special=None,
        num_frames=None,
        enable_3d_rope=False,
        kv_cache=None,
        global_idx=0,
        num_frame_per_block=1,
        num_frame_for_scale=-1,
        num_register_tokens=4,
    ) -> Tensor:
        """Run one block step.

        Streaming (kv_cache set and num_frames is None or <= 1): inline
        append-frame -> evict -> paged attention -> projection, with the
        compile-friendly pieces isolated. Otherwise delegates to self.attn as
        a normal pre-norm residual branch.
        """
        # Phase 2 (streaming): single-frame FlashInfer paged attention.
        # Handle inline so attn_pre (norm1+prepare_qkv) can be compiled as one CUDA graph.
        is_streaming = (kv_cache is not None and (num_frames is None or num_frames <= 1))
        if is_streaming:
            manager = kv_cache
            # Compiled: norm1 + qkv linear + reshape + q_norm + k_norm + RoPE + format
            q_nhd, k_nhd, v_nhd = self.attn_pre(x, pos=pos, enable_3d_rope=enable_3d_rope)
            # Eager: write frame K/V to paged cache
            manager.append_frame(global_idx, k_nhd, v_nhd)
            # CPU-only: update eviction state (deque ops, no GPU kernel)
            manager.evict_frames(
                block_idx=global_idx,
                scale_frames=self.attn.kv_cache_scale_frames,
                sliding_window=self.attn.kv_cache_sliding_window,
                cross_frame_special=self.attn.kv_cache_cross_frame_special,
                include_scale_frames=self.attn.kv_cache_include_scale_frames,
                camera_only=self.attn.kv_cache_camera_only,
                num_register_tokens=num_register_tokens,
            )
            # Eager: FlashInfer BatchPrefillWithPagedKVCacheWrapper
            attn_x = manager.compute_attention(global_idx, q_nhd)
            # [tpf, H, D] -> [B, tpf, C] (B=1 in streaming, contiguous from FlashInfer output)
            attn_x = attn_x.reshape(x.shape[0], q_nhd.shape[0],
                                    self.attn.num_heads * self.attn.head_dim)
            # Compiled: output projection
            attn_x = self.attn.proj(attn_x)
            x = x + self.ls1(attn_x)
        else:
            # Phase 1 (multi-frame scale pass) or non-streaming training path
            x = x + self.ls1(self.attn(
                self.norm1(x),
                pos=pos,
                enable_ulysses_cp=enable_ulysses_cp,
                num_patches=num_patches,
                num_special=num_special,
                num_frames=num_frames,
                enable_3d_rope=enable_3d_rope,
                kv_cache=kv_cache,
                global_idx=global_idx,
                num_frame_per_block=num_frame_per_block,
                num_frame_for_scale=num_frame_for_scale,
                num_register_tokens=num_register_tokens,
            ))
        x = self.ffn_residual(x)
        return x

    def ffn_residual(self, x: Tensor) -> Tensor:
        """FFN residual branch: norm2 -> mlp -> ls2, WITH residual add fused in.

        Includes the residual add (x + ...) so torch.compile captures the entire
        ffn branch as one CUDA graph.
        """
        return x + self.ls2(self.mlp(self.norm2(x)))
|
||||
|
||||
|
||||
class CameraBlock(nn.Module):
    """Pre-norm Transformer block built around CausalAttention, with a
    block-wise causal flex-attention mask for frame-sequenced tokens.

    Two attention paths:
      - full_attention=True: bidirectional attention (camera head), no mask.
      - otherwise: a block-wise causal mask is built per call and forwarded,
        together with KV-cache/streaming arguments, to CausalAttention.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
        elementwise_attn_output_gate: bool = False,
        sliding_window_size: int = -1,
        attend_to_scale_frames: bool = False,
        num_random_frames: int = 0,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        """Build the block.

        NOTE(review): ``attn_class`` and ``fused_attn`` are accepted but unused
        — the attention is hard-wired to CausalAttention; confirm whether they
        are kept only for signature compatibility with Block.
        """
        super().__init__()

        self.norm1 = norm_layer(dim)
        # Attention is always CausalAttention here; kv_cache_* configure its
        # streaming eviction policy.
        self.attn = CausalAttention(dim=dim, num_heads=num_heads,
                                    qk_norm=qk_norm, qkv_bias=qkv_bias,
                                    rope=rope, elementwise_attn_output_gate=elementwise_attn_output_gate,
                                    kv_cache_sliding_window=kv_cache_sliding_window,
                                    kv_cache_scale_frames=kv_cache_scale_frames,
                                    kv_cache_cross_frame_special=kv_cache_cross_frame_special,
                                    kv_cache_include_scale_frames=kv_cache_include_scale_frames,
                                    kv_cache_camera_only=kv_cache_camera_only)

        # Default attention-window knobs; forward() can override the window size.
        self.sliding_window_size = sliding_window_size
        self.attend_to_scale_frames = attend_to_scale_frames
        self.num_random_frames = num_random_frames

        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path
        # NOTE(review): mask cache dict — not read or written in this class's
        # visible code; confirm whether callers use it or it is dead state.
        self.masks = {}

    @torch.no_grad()
    def _prepare_blockwise_causal_attn_mask(self,
                                            device: torch.device | str, num_frames: int = 21,
                                            frame_seqlen: int = 1560, num_frame_per_block=1
                                            ) -> BlockMask:
        """
        we will divide the token sequence into the following format
        [1 latent frame] [1 latent frame] ... [1 latent frame]
        We use flexattention to construct the attention mask

        Each query token may attend to every token up to (and including) the
        end of its own frame chunk; padding tokens attend only to themselves
        via the (q_idx == kv_idx) clause.
        """
        total_length = num_frames * frame_seqlen

        # we do right padding to get to a multiple of 128
        padded_length = math.ceil(total_length / 128) * 128 - total_length

        # ends[i] = index one past the last token of token i's frame chunk
        # (0 for padding tokens, which fall back to the identity clause).
        ends = torch.zeros(total_length + padded_length,
                           device=device, dtype=torch.long)

        # Block-wise causal mask will attend to all elements that are before the end of the current chunk
        frame_indices = torch.arange(
            start=0,
            end=total_length,
            step=frame_seqlen * num_frame_per_block,
            device=device
        )

        for tmp in frame_indices:
            ends[tmp:tmp + frame_seqlen * num_frame_per_block] = tmp + \
                frame_seqlen * num_frame_per_block

        def attention_mask(b, h, q_idx, kv_idx):
            return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
        # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask

        block_mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + padded_length,
                                 KV_LEN=total_length + padded_length, device=device)

        return block_mask

    def forward(self, x: Tensor, pos=None, video_mask=None, num_frames=0, frame_seqlen=0, kv_cache=None, current_start=0, current_end=0, global_idx=0, num_frame_per_block=8, num_frame_for_scale=-1, sliding_window_size=None, enable_ulysses_cp=False, full_attention=False, enable_3d_rope=False, is_scale_frames=False) -> Tensor:
        """Run the block; most arguments are passed straight through to
        CausalAttention. full_attention selects the maskless fast path."""
        # Use passed sliding_window_size if provided, otherwise use self.sliding_window_size
        effective_sliding_window_size = sliding_window_size if sliding_window_size is not None else self.sliding_window_size

        # Fast path for full attention (camera head) - skip mask computation
        if full_attention:
            def attn_residual_func(x: Tensor, pos=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), pos=pos, full_attention=True, enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope))

            def ffn_residual_func(x: Tensor) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            if self.training and self.sample_drop_ratio > 0.0:
                x = x + self.drop_path1(attn_residual_func(x, pos=pos))
                x = x + self.drop_path1(ffn_residual_func(x))
            else:
                x = x + attn_residual_func(x, pos=pos)
                x = x + ffn_residual_func(x)
            return x

        # Block-wise causal path: build the mask for this sequence layout.
        # NOTE(review): the mask is rebuilt on every call even though
        # self.masks exists — confirm whether caching was intended.
        mask_block = self._prepare_blockwise_causal_attn_mask(
            device=x.device, num_frames=num_frames, frame_seqlen=frame_seqlen, num_frame_per_block=num_frame_per_block)


        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos, block_mask=mask_block, frame_seqlen=frame_seqlen, video_mask=video_mask, current_start=current_start, current_end=current_end, kv_cache=kv_cache, global_idx=global_idx, num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale, sliding_window_size=effective_sliding_window_size, attend_to_scale_frames=self.attend_to_scale_frames, num_random_frames=self.num_random_frames,
                                      enable_ulysses_cp=enable_ulysses_cp, enable_3d_rope=enable_3d_rope, is_scale_frames=is_scale_frames))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x
|
||||
|
||||
|
||||
class SDPABlock(nn.Module):
    """
    SDPA variant for streaming inference. Uses F.scaled_dot_product_attention
    with dict-based KV cache. No FlashInfer dependency required.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
    ) -> None:
        super().__init__()
        use_layerscale = bool(init_values)

        # Attention branch: pre-norm + streaming SDPA attention.
        self.norm1 = norm_layer(dim)
        self.attn = SDPAAttention(
            dim=dim,
            num_heads=num_heads,
            qk_norm=qk_norm,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            rope=rope,
            kv_cache_sliding_window=kv_cache_sliding_window,
            kv_cache_scale_frames=kv_cache_scale_frames,
            kv_cache_cross_frame_special=kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=kv_cache_include_scale_frames,
            kv_cache_camera_only=kv_cache_camera_only,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if use_layerscale else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # FFN branch: pre-norm + MLP.
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(in_features=dim, hidden_features=hidden_dim,
                             act_layer=act_layer, drop=drop, bias=ffn_bias)
        self.ls2 = LayerScale(dim, init_values=init_values) if use_layerscale else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, enable_ulysses_cp=False,
                num_patches=None, num_special=None, num_frames=None, enable_3d_rope=False,
                kv_cache=None, global_idx=0, num_frame_per_block=1,
                num_frame_for_scale=-1, num_register_tokens=4) -> Tensor:
        """Attention + FFN residual branches; KV-cache args flow to SDPAAttention."""

        def run_attn(inp):
            # norm1 -> streaming attention -> optional LayerScale
            return self.ls1(self.attn(
                self.norm1(inp), pos=pos, enable_ulysses_cp=enable_ulysses_cp,
                num_patches=num_patches, num_special=num_special, num_frames=num_frames,
                enable_3d_rope=enable_3d_rope, kv_cache=kv_cache, global_idx=global_idx,
                num_frame_per_block=num_frame_per_block, num_frame_for_scale=num_frame_for_scale,
                num_register_tokens=num_register_tokens,
            ))

        def run_ffn(inp):
            # norm2 -> MLP -> optional LayerScale
            return self.ls2(self.mlp(self.norm2(inp)))

        if self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(run_attn(x))
            x = x + self.drop_path1(run_ffn(x))
        else:
            x = x + run_attn(x)
            x = x + run_ffn(x)
        return x
|
||||
34
lingbot_map/layers/drop_path.py
Normal file
34
lingbot_map/layers/drop_path.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
||||
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Per-sample stochastic depth: zero whole samples with probability
    ``drop_prob`` and rescale survivors by 1/keep_prob. Identity outside
    training or when ``drop_prob`` is 0."""
    if not training or drop_prob == 0.0:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale so the expected activation magnitude is unchanged.
        mask.div_(keep_prob)
    return x * mask
|
||||
|
||||
|
||||
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of zeroing a sample's residual branch during training.
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional form; active only in training mode.
        return drop_path(x, self.drop_prob, self.training)
|
||||
582
lingbot_map/layers/flashinfer_cache.py
Normal file
582
lingbot_map/layers/flashinfer_cache.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""
|
||||
FlashInfer KV Cache Manager — Two-Stream Paged Design.
|
||||
|
||||
Two logical streams sharing one physical page pool per layer:
|
||||
|
||||
Patch stream (recyclable):
|
||||
- page_size = patches_per_frame (256 for 224×224; 972 for 504×378)
|
||||
- Exactly 1 patch page per frame
|
||||
- Scale frames → scale_patch_pages (never evicted, maxlen=scale_frames)
|
||||
- Recent frames → live_window_patch_pages (evicted when > sliding_window)
|
||||
|
||||
Special stream (append-only, never recycled):
|
||||
- num_special_tokens (6) special tokens per frame
|
||||
- Packed continuously: one special page holds floor(page_size/6) frames
|
||||
e.g. page_size=256 → 42 frames per special page, 4 slots wasted
|
||||
- Specials written for EVERY frame (including scale + window), not just evicted ones.
|
||||
|
||||
Physical layout per block:
|
||||
kv_caches[block_idx]: [max_num_pages, 2, page_size, H, D]
|
||||
Pages 0 .. max_patch_pages-1 : patch page pool (recyclable)
|
||||
Pages max_patch_pages .. max_pages-1: special page pool (append-only)
|
||||
dim 1: 0=K 1=V
|
||||
|
||||
Attention computation:
|
||||
visible = scale_patch_pages + live_window_patch_pages + all_special_pages
|
||||
Special pages placed LAST → paged_kv_last_page_len naturally describes
|
||||
the partial special-tail without a custom mask.
|
||||
|
||||
plan() is called ONCE per frame step (when block_idx == 0).
|
||||
run() is called per layer, reusing the same plan. All layers at the
|
||||
same frame step have identical page structures (same page IDs in same
|
||||
positions), so reusing the plan across layers is correct.
|
||||
|
||||
Public API is drop-in compatible with the previous FlashInferKVCacheManager:
|
||||
append_frame(block_idx, k, v)
|
||||
evict_frames(block_idx, scale_frames, sliding_window, ...)
|
||||
compute_attention(block_idx, q) -> out
|
||||
reset()
|
||||
"""
|
||||
|
||||
import collections
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
try:
|
||||
import flashinfer
|
||||
FLASHINFER_AVAILABLE = True
|
||||
except ImportError:
|
||||
FLASHINFER_AVAILABLE = False
|
||||
|
||||
|
||||
class FlashInferKVCacheManager:
    """
    Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only).

    Visible page order is always scale → window → special, so only the final
    (special) page may be partially full; ``paged_kv_last_page_len`` then
    describes the partial tail without a custom attention mask.

    plan() is called once per frame step (when block_idx == 0); run() is called
    per layer reusing the same plan — all layers at the same frame step have
    identical page structures, so plan reuse is correct.

    Args:
        num_blocks: Number of Transformer blocks (one cache per block).
        max_num_frames: Maximum frames held in the KV window at once
            (scale_frames + sliding_window + headroom). Accepted for API
            compatibility; pool sizing is actually derived from scale_frames,
            sliding_window and max_total_frames.
        tokens_per_frame: Total tokens per frame = patches + specials (e.g. 262).
        num_heads: Number of KV heads (= QO heads; MHA assumed).
        head_dim: Head dimension (64 for ViT-L).
        dtype: Storage dtype (bfloat16 / float16); float32 is downgraded to
            bfloat16 unless force_fp32 is set.
        device: CUDA device.
        num_special_tokens: Special tokens per frame: camera + register×N + scale (6).
        scale_frames: Number of always-resident scale frames (8).
        sliding_window: Sliding window size (64).
        max_total_frames: Upper bound on total frames ever processed; used to
            pre-allocate the special page pool (default 2048).
        force_fp32: Bypass the FlashInfer kernel (fp16/bf16 only) and run fp32
            SDPA over gathered pages; storage is kept fp32 in this mode.
        fa3: Use the FA3 (SM90) backend, which requires power-of-2 page sizes.
    """

    def __init__(
        self,
        num_blocks: int,
        max_num_frames: int,
        tokens_per_frame: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        num_special_tokens: int = 6,
        scale_frames: int = 8,
        sliding_window: int = 64,
        max_total_frames: int = 2048,
        force_fp32: bool = False,
        fa3: bool = False,
    ):
        if not FLASHINFER_AVAILABLE:
            raise RuntimeError("FlashInfer is not available. Please install flashinfer.")

        self.num_blocks = num_blocks
        self.num_special_tokens = num_special_tokens  # 6
        self.patches_per_frame = tokens_per_frame - num_special_tokens  # 256 / 999 / ...
        # Use exact page_size = patches_per_frame to eliminate zero-padded slots.
        # FA2 (backend="fa2") supports non-power-of-2 page sizes.
        # FA3 (sm90) requires power-of-2 page sizes; round up when fa3=True.
        p = self.patches_per_frame
        if fa3:
            # Round up to next power-of-2 for the FA3 SM90 kernel requirement.
            # e.g. 999 → 1024 (25 zero-padded slots per patch page)
            self.page_size = 1 << (p - 1).bit_length()
        else:
            self.page_size = p  # exact: no zero padding in patch pages
        self.scale_frames = scale_frames  # 8
        self.sliding_window = sliding_window  # 64
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.tokens_per_frame = tokens_per_frame

        assert self.patches_per_frame > 0, (
            f"tokens_per_frame={tokens_per_frame} <= num_special_tokens={num_special_tokens}"
        )
        assert self.page_size > 0

        # force_fp32: bypass FlashInfer FA2 kernel (which only supports fp16/bf16) and
        # instead gather paged K/V into a dense tensor and use F.scaled_dot_product_attention
        # in fp32 for accuracy comparison. Storage dtype is also kept as fp32 in this mode.
        self.force_fp32 = force_fp32
        if force_fp32:
            self.dtype = torch.float32
        else:
            if dtype == torch.float32:
                dtype = torch.bfloat16
            self.dtype = dtype
        self.device = device

        # ── Page pool sizing ─────────────────────────────────────────────────
        # Patch: scale + window + 16 headroom (pages recycled → fixed count)
        max_patch_pages = scale_frames + sliding_window + 16  # e.g. 88
        # Special: enough for max_total_frames × 6 tokens, plus 16 headroom
        max_special_pages = (
            math.ceil(max_total_frames * num_special_tokens / self.page_size) + 16
        )
        self.max_patch_pages = max_patch_pages
        self.max_num_pages = max_patch_pages + max_special_pages

        # ── Physical paged KV caches ─────────────────────────────────────────
        # Shape per block: [max_num_pages, 2, page_size, H, D] (NHD, K=dim0, V=dim1)
        # BUG FIX: allocate with self.dtype rather than the local `dtype` argument.
        # Under force_fp32 the two diverge (self.dtype=fp32, local dtype=bf16/fp16),
        # and writes cast to self.dtype were silently downcast back to the cache's
        # half-precision storage, defeating the fp32 accuracy-comparison mode.
        self.kv_caches: List[Tensor] = [
            torch.zeros(
                self.max_num_pages, 2, self.page_size, num_heads, head_dim,
                dtype=self.dtype, device=device,
            )
            for _ in range(num_blocks)
        ]

        # ── Per-block state ──────────────────────────────────────────────────
        # Patch pages (IDs 0 .. max_patch_pages-1)
        self.scale_patch_pages: List[collections.deque] = [
            collections.deque() for _ in range(num_blocks)
        ]
        self.live_window_patch_pages: List[collections.deque] = [
            collections.deque() for _ in range(num_blocks)
        ]
        self.free_patch_pages: List[List[int]] = [
            list(range(max_patch_pages)) for _ in range(num_blocks)
        ]

        # Special pages (IDs max_patch_pages .. max_num_pages-1)
        self.all_special_pages: List[List[int]] = [[] for _ in range(num_blocks)]
        self.free_special_pages: List[List[int]] = [
            list(range(max_patch_pages, self.max_num_pages)) for _ in range(num_blocks)
        ]
        self.special_token_count: List[int] = [0] * num_blocks

        # Frame counter per block (determines scale vs window routing)
        self.frame_count: List[int] = [0] * num_blocks

        # ── FlashInfer wrapper ───────────────────────────────────────────────
        # plan() is called once per frame step (block_idx == 0).
        # run() is called per layer, reusing the same aux structures.
        # backend: "fa2" (default) or "fa3" (SM90/H100, requires power-of-2 page_size).
        # FA2 supports non-power-of-2 page sizes and avoids a FA3 NaN bug seen in
        # FlashInfer 0.2.5 at 518×378 resolution.
        _fi_backend = "fa3" if fa3 else "fa2"
        self.workspace_buffer = torch.zeros(
            128 * 1024 * 1024, dtype=torch.uint8, device=device
        )
        self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer,
            kv_layout="NHD",
            backend=_fi_backend,
        )

        # plan() inputs (indices/indptr built fresh each step; qo_indptr is fixed
        # because every step queries exactly one frame of tokens_per_frame tokens)
        self._qo_indptr = torch.tensor(
            [0, tokens_per_frame], dtype=torch.int32, device=device
        )

    # =========================================================================
    # Public API (drop-in compatible with previous FlashInferKVCacheManager)
    # =========================================================================

    def append_frame(self, block_idx: int, k: Tensor, v: Tensor) -> None:
        """
        Append one frame's K/V tensors to the two-stream cache.

        Token layout must be: [camera, reg0, ..., regN, scale, patch0, ..., patchP-1]
        i.e. specials come first (matching stream.py's patch_start_idx convention).

        Args:
            block_idx: Block/layer index (0 … num_blocks-1).
            k: [tokens_per_frame, H, D] NHD layout.
            v: [tokens_per_frame, H, D] NHD layout.
        """
        n = self.num_special_tokens  # 6
        sp_k = k[:n].to(self.dtype)  # [6, H, D]
        patch_k = k[n:].to(self.dtype)  # [256, H, D]
        sp_v = v[:n].to(self.dtype)
        patch_v = v[n:].to(self.dtype)

        assert patch_k.shape[0] == self.patches_per_frame, (
            f"block {block_idx}: expected {self.patches_per_frame} patch tokens, "
            f"got {patch_k.shape[0]} (tokens_per_frame={k.shape[0]})"
        )

        self._write_patch_page(block_idx, patch_k, patch_v)
        self._write_special_tokens(block_idx, sp_k, sp_v)
        self.frame_count[block_idx] += 1

    def evict_frames(
        self,
        block_idx: int,
        scale_frames: int,
        sliding_window: int,
        cross_frame_special: bool = True,
        include_scale_frames: bool = True,
        camera_only: bool = False,
        num_register_tokens: int = 4,
    ) -> None:
        """
        Evict old window patch pages (recycle to free list).

        Special pages are NEVER evicted.
        Scale pages are NEVER evicted.
        Only live_window_patch_pages beyond `sliding_window` are recycled.

        Note: all parameters other than block_idx and sliding_window are
        accepted for signature compatibility with the previous manager and
        are unused here.
        """
        while len(self.live_window_patch_pages[block_idx]) > sliding_window:
            old_page = self.live_window_patch_pages[block_idx].popleft()
            self.free_patch_pages[block_idx].append(old_page)

    def _gather_kv(self, block_idx: int):
        """
        Gather all visible K and V tokens from the paged cache into dense tensors.

        Used by force_fp32 mode to bypass the FlashInfer FA2 kernel (which only
        supports fp16/bf16) and instead run F.scaled_dot_product_attention in fp32.

        Returns:
            k_flat: [kv_len, H, D] — all visible K tokens concatenated
            v_flat: [kv_len, H, D] — all visible V tokens concatenated
        """
        visible = self.build_visible_page_table(block_idx)
        last_len = self.compute_last_page_len(block_idx)
        P = self.page_size

        parts_k, parts_v = [], []
        for i, pid in enumerate(visible):
            # Only the final page may be partially full (see build_visible_page_table).
            n = last_len if (i == len(visible) - 1) else P
            parts_k.append(self.kv_caches[block_idx][pid, 0, :n])  # [n, H, D]
            parts_v.append(self.kv_caches[block_idx][pid, 1, :n])

        k_flat = torch.cat(parts_k, dim=0)  # [kv_len, H, D]
        v_flat = torch.cat(parts_v, dim=0)
        return k_flat, v_flat

    def compute_attention(self, block_idx: int, q: Tensor) -> Tensor:
        """
        Compute cross-frame attention using FlashInfer BatchPrefillWithPagedKVCacheWrapper.

        When self.force_fp32 is True, gathers all visible K/V into dense tensors
        and uses F.scaled_dot_product_attention in fp32 instead of the FA2 kernel.
        This is used for accuracy comparison since FlashInfer FA2 only supports fp16/bf16.

        plan() is called once per frame step (when block_idx == 0).
        All layers at the same step share the same visible page structure,
        so the plan is reused by calling run() with each layer's kv_cache.

        Args:
            block_idx: Block/layer index.
            q: [q_len, H, D] NHD layout (q_len = tokens_per_frame = 262).

        Returns:
            out: [q_len, H, D]
        """
        if self.frame_count[block_idx] == 0:
            # No KV present yet (should not occur in normal usage after append_frame)
            return torch.zeros_like(q)

        if self.force_fp32:
            # ── fp32 gather+SDPA path ─────────────────────────────────────────
            # Gather visible K/V from paged cache and run SDPA in fp32.
            # This bypasses the FlashInfer FA2 kernel (fp16/bf16 only) for accuracy.
            # q_len, H, D → 1, H, q_len, D (SDPA expects BHsD layout)
            import torch.nn.functional as F_nn
            k_flat, v_flat = self._gather_kv(block_idx)
            q_b = q.float().permute(1, 0, 2).unsqueeze(0)       # [1, H, q_len, D]
            k_b = k_flat.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, kv_len, D]
            v_b = v_flat.float().permute(1, 0, 2).unsqueeze(0)  # [1, H, kv_len, D]
            out = F_nn.scaled_dot_product_attention(q_b, k_b, v_b)
            return out.squeeze(0).permute(1, 0, 2).to(q.dtype)  # [q_len, H, D]

        if block_idx == 0:
            # ── Plan once per frame step ──────────────────────────────────────
            # Build visible page table from block 0's state.
            # All blocks have identical page structures, so this plan is valid
            # for all subsequent run() calls (block_idx = 1, 2, ...).
            visible = self.build_visible_page_table(0)
            last_len = self.compute_last_page_len(0)

            assert visible, "visible page table is empty after append_frame"
            assert 1 <= last_len <= self.page_size, (
                f"block 0: last_page_len={last_len} out of [1, {self.page_size}]"
            )

            paged_kv_indices = torch.tensor(visible, dtype=torch.int32, device=self.device)
            paged_kv_indptr = torch.tensor([0, len(visible)], dtype=torch.int32, device=self.device)
            paged_kv_last_page_len = torch.tensor([last_len], dtype=torch.int32, device=self.device)

            self.prefill_wrapper.plan(
                self._qo_indptr,
                paged_kv_indptr,
                paged_kv_indices,
                paged_kv_last_page_len,
                num_qo_heads=self.num_heads,
                num_kv_heads=self.num_heads,
                head_dim_qk=self.head_dim,
                page_size=self.page_size,
                causal=False,               # custom page ordering; no causal mask
                pos_encoding_mode="NONE",   # RoPE applied externally before append
                q_data_type=self.dtype,
            )

        # ── Run attention for this layer ──────────────────────────────────────
        # Cast q to storage dtype (LayerNorm may upcast to float32 under autocast).
        return self.prefill_wrapper.run(
            q=q.to(self.dtype).contiguous(),
            paged_kv_cache=self.kv_caches[block_idx],
        )  # → [q_len, H, D]

    def reset(self) -> None:
        """Reset all per-block state for a new sequence."""
        for i in range(self.num_blocks):
            self.scale_patch_pages[i].clear()
            self.live_window_patch_pages[i].clear()
            self.all_special_pages[i].clear()
            self.free_patch_pages[i] = list(range(self.max_patch_pages))
            self.free_special_pages[i] = list(range(self.max_patch_pages, self.max_num_pages))
            self.special_token_count[i] = 0
            self.frame_count[i] = 0

    # =========================================================================
    # Helper methods
    # =========================================================================

    def build_visible_page_table(self, block_idx: int) -> List[int]:
        """
        Return page IDs in strict order: scale → window → special.

        Placing special pages last means only the final page may be partially
        full, so paged_kv_last_page_len = compute_last_page_len() is sufficient
        without a custom attention mask.
        """
        return (
            list(self.scale_patch_pages[block_idx]) +
            list(self.live_window_patch_pages[block_idx]) +
            list(self.all_special_pages[block_idx])
        )

    def compute_last_page_len(self, block_idx: int) -> int:
        """
        Valid token count in the last page of the visible sequence.

        - No special pages → last page is a patch page.
          Returns patches_per_frame (real tokens written),
          which may be < page_size when page_size was rounded
          up to a power of 2.
        - Special tail partial → special_token_count % page_size.
        - Special tail exactly full → page_size.
        """
        if not self.all_special_pages[block_idx]:
            # Last page is a patch page. We wrote patches_per_frame tokens (0..P-1);
            # positions P..page_size-1 are zero padding. Tell FlashInfer the true
            # valid count so it doesn't read beyond the real tokens.
            return self.patches_per_frame

        tail = self.special_token_count[block_idx] % self.page_size
        return self.page_size if tail == 0 else tail

    # ── Internal write helpers ────────────────────────────────────────────────

    def _write_patch_page(self, block_idx: int, patch_k: Tensor, patch_v: Tensor) -> int:
        """
        Allocate one free patch page and write patches_per_frame patch tokens.

        Direct tensor assignment to kv_caches[block_idx][page_id, 0/1] avoids
        the Python→C++/CUDA dispatch overhead of flashinfer.page.append_paged_kv_cache.
        kv_caches layout: [max_num_pages, 2, page_size, H, D] (NHD, K=0, V=1).

        Routes to scale_patch_pages if still filling the scale quota,
        otherwise to live_window_patch_pages.

        Returns:
            page_id: Physical page index used.
        """
        assert self.free_patch_pages[block_idx], (
            f"block {block_idx}: patch page pool exhausted — "
            f"scale={len(self.scale_patch_pages[block_idx])}, "
            f"window={len(self.live_window_patch_pages[block_idx])}, "
            f"free={len(self.free_patch_pages[block_idx])}"
        )

        page_id = self.free_patch_pages[block_idx].pop()

        # Direct slice write: positions 0..patches_per_frame-1.
        # When page_size == patches_per_frame (e.g. 256 for 224×224) this is a
        # full-page write. When page_size > patches_per_frame (rounded up for
        # FA3, e.g. 1024 for 999), positions patches_per_frame..page_size-1
        # remain zero (kv_caches is zero-initialized).
        P = self.patches_per_frame
        self.kv_caches[block_idx][page_id, 0, :P] = patch_k  # K
        self.kv_caches[block_idx][page_id, 1, :P] = patch_v  # V

        if len(self.scale_patch_pages[block_idx]) < self.scale_frames:
            self.scale_patch_pages[block_idx].append(page_id)
        else:
            self.live_window_patch_pages[block_idx].append(page_id)

        return page_id

    def _write_special_tokens(self, block_idx: int, sp_k: Tensor, sp_v: Tensor) -> None:
        """
        Append num_special_tokens (6) special tokens to the special stream.

        Direct tensor slice assignment to kv_caches[block_idx][tail_page, 0/1,
        tail_offset : tail_offset+write_n] avoids the Python→C++/CUDA dispatch
        overhead of flashinfer.page.append_paged_kv_cache.

        Handles page-boundary crossing: if 6 tokens straddle two pages, performs
        two slice writes (rare — page_size=256 >> 6).
        """
        remaining = self.num_special_tokens  # 6
        written = 0

        while remaining > 0:
            tail_offset = self.special_token_count[block_idx] % self.page_size

            if tail_offset == 0:
                # Current tail page is full (or no page exists) — allocate a new one
                assert self.free_special_pages[block_idx], (
                    f"block {block_idx}: special page pool exhausted at "
                    f"special_token_count={self.special_token_count[block_idx]}. "
                    f"Increase max_total_frames."
                )
                new_page = self.free_special_pages[block_idx].pop()
                self.all_special_pages[block_idx].append(new_page)

            tail_page = self.all_special_pages[block_idx][-1]
            space = self.page_size - tail_offset  # free slots in tail page
            write_n = min(remaining, space)

            # Direct slice write: kv_caches[block_idx][tail_page, 0/1, offset:offset+n]
            # shape: [page_size, H, D]; slice [tail_offset:tail_offset+write_n, :, :]
            end = tail_offset + write_n
            self.kv_caches[block_idx][tail_page, 0, tail_offset:end] = sp_k[written:written + write_n]
            self.kv_caches[block_idx][tail_page, 1, tail_offset:end] = sp_v[written:written + write_n]

            self.special_token_count[block_idx] += write_n
            written += write_n
            remaining -= write_n

    # ── Legacy property (used by stream.py) ──────────────────────────────────

    @property
    def num_frames(self) -> int:
        """Number of frames appended to block 0 (representative)."""
        return self.frame_count[0] if self.frame_count else 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Sanity check
|
||||
# =============================================================================
|
||||
|
||||
def _sanity_check():
    """
    Minimal smoke test.
    Run with: python -c "from lingbot_map.layers.flashinfer_cache import _sanity_check; _sanity_check()"
    """
    if not torch.cuda.is_available():
        print("[sanity_check] CUDA not available — skipping.")
        return
    device = torch.device("cuda")

    tokens_per_frame = 262  # 256 patch + 6 special (224×224)
    num_special = 6
    patches_per_frame = tokens_per_frame - num_special  # 256
    page_size = patches_per_frame  # 256

    cache = FlashInferKVCacheManager(
        num_blocks=2,
        max_num_frames=88,
        tokens_per_frame=tokens_per_frame,
        num_heads=16,
        head_dim=64,
        dtype=torch.bfloat16,
        device=device,
        num_special_tokens=num_special,
        scale_frames=8,
        sliding_window=64,
        max_total_frames=200,
    )

    def rand_tokens():
        # One frame of random K/V/Q tokens: [tokens_per_frame, H=16, D=64].
        return torch.randn(tokens_per_frame, 16, 64, dtype=torch.bfloat16, device=device)

    for block in range(2):
        # Append 100 frames, evicting after each append as in real streaming use.
        for _ in range(100):
            cache.append_frame(block, rand_tokens(), rand_tokens())
            cache.evict_frames(block, scale_frames=8, sliding_window=64)

        # ── Page count checks ───────────────────────────────────────────────
        n_scale = len(cache.scale_patch_pages[block])
        n_window = len(cache.live_window_patch_pages[block])
        n_spec = len(cache.all_special_pages[block])
        sp_count = cache.special_token_count[block]

        assert n_scale == 8, f"block {block}: scale pages = {n_scale}, expected 8"
        assert n_window == 64, f"block {block}: window pages = {n_window}, expected 64"
        # 100 frames × 6 specials = 600 tokens; ceil(600/256) = 3 pages
        expected_spec_pages = math.ceil(100 * num_special / page_size)
        assert n_spec == expected_spec_pages, (
            f"block {block}: special pages = {n_spec}, expected {expected_spec_pages}"
        )
        assert sp_count == 100 * num_special, (
            f"block {block}: special_token_count = {sp_count}, expected {100*num_special}"
        )

        # ── last_page_len ────────────────────────────────────────────────────
        last_len = cache.compute_last_page_len(block)
        tail = sp_count % page_size
        if tail == 0:
            expected_len = page_size
        else:
            expected_len = tail
        assert last_len == expected_len, f"block {block}: last_len={last_len}, expected={expected_len}"

        # ── visible page table order ─────────────────────────────────────────
        visible = cache.build_visible_page_table(block)
        assert len(visible) == n_scale + n_window + n_spec, "visible page count mismatch"
        for pid in visible[:n_scale + n_window]:
            assert pid < cache.max_patch_pages, f"patch page {pid} out of patch range"
        for pid in visible[n_scale + n_window:]:
            assert pid >= cache.max_patch_pages, f"special page {pid} not in special range"

        # ── forward pass: plan() once for block 0, run() for both blocks ─────
        if block == 1:
            # Simulate the actual calling pattern: plan on block 0, run on both
            out0 = cache.compute_attention(0, rand_tokens())  # triggers plan()
            out1 = cache.compute_attention(1, rand_tokens())  # reuses plan, different kv_cache
            assert out0.shape == (tokens_per_frame, 16, 64)
            assert out1.shape == (tokens_per_frame, 16, 64)

        print(f"[block {block}] PASS: scale={n_scale}, window={n_window}, "
              f"special_pages={n_spec}, special_tokens={sp_count}, "
              f"last_page_len={last_len}")

    cache.reset()
    assert cache.frame_count[0] == 0
    print("\n[sanity_check] All assertions passed.")


if __name__ == "__main__":
    _sanity_check()
|
||||
22
lingbot_map/layers/layer_scale.py
Normal file
22
lingbot_map/layers/layer_scale.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
    """Learnable per-channel scaling (timm-style LayerScale).

    Multiplies the last dimension of the input by a learned vector ``gamma``
    initialised to ``init_values``; with ``inplace=True`` the input tensor is
    scaled in place.
    """

    def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
||||
40
lingbot_map/layers/mlp.py
Normal file
40
lingbot_map/layers/mlp.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
||||
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    """Transformer feed-forward block: fc1 → activation → dropout → fc2 → dropout.

    Hidden and output widths default to ``in_features`` when not given.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        # A single Dropout module is shared by both dropout sites.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
||||
85
lingbot_map/layers/patch_embed.py
Normal file
85
lingbot_map/layers/patch_embed.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_2tuple(x):
    """Return *x* as an (h, w) pair: ints are duplicated, 2-tuples pass through."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple)
    assert len(x) == 2
    return x


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)

        self.img_size = image_HW
        self.patch_size = patch_HW
        # Patch grid implied by the nominal image size (floor division).
        self.patches_resolution = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Non-overlapping patchify: kernel == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H' W'
        H, W = x.size(2), x.size(3)
        x = self.norm(x.flatten(2).transpose(1, 2))  # B H'W' C
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
|
||||
474
lingbot_map/layers/rope.py
Normal file
474
lingbot_map/layers/rope.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# Implementation of 2D Rotary Position Embeddings (RoPE).
|
||||
|
||||
# This module provides a clean implementation of 2D Rotary Position Embeddings,
|
||||
# which extends the original RoPE concept to handle 2D spatial positions.
|
||||
|
||||
# Inspired by:
|
||||
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
||||
# https://github.com/naver-ai/rope-vit
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
|
||||
class PositionGetter:
    """Generates and caches 2D spatial positions for patches in a grid.

    Coordinates for a given (height, width) grid are computed once and kept in
    ``position_cache``; subsequent calls only expand the cached grid across the
    batch dimension.
    """

    def __init__(self):
        """Initializes the position generator with an empty cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
        """Generates spatial positions for a batch of patches.

        Args:
            batch_size: Number of samples in the batch.
            height: Height of the grid in patches.
            width: Width of the grid in patches.
            device: Target device for the position tensor.

        Returns:
            Tensor of shape (batch_size, height*width, 2) containing y,x
            coordinates for each grid position, repeated per batch item.
        """
        key = (height, width)
        if key not in self.position_cache:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            # Row-major (y, x) pairs covering the whole grid.
            self.position_cache[key] = torch.cartesian_prod(rows, cols)

        grid = self.position_cache[key]
        # clone() so callers may mutate their copy without touching the cache.
        return grid.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(nn.Module):
    """Rotary position embedding (RoPE) over a 2D grid of token positions.

    Each token's feature vector is split in half: the first half is rotated
    according to the token's vertical (y) coordinate and the second half
    according to its horizontal (x) coordinate, using the standard 1D RoPE
    rotation inside each half.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0
            NOTE(review): stored but not referenced by any computation below —
            confirm whether it should scale ``inv_freq``.

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor intended to scale the computed frequencies.
        frequency_cache: Memoized (cos, sin) tables keyed by
            (dim, seq_len, device, dtype).
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        self.scaling_factor = scaling_factor
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns cached (cos, sin) lookup tables of shape (seq_len, dim).

        Args:
            dim: Feature dimension the rotation acts on (must be even).
            seq_len: Number of positions to precompute (rows of the tables).
            device: Target device for the tables.
            dtype: Data type of the returned tables.

        Returns:
            Tuple of (cosine, sine) tensors, each of shape (seq_len, dim).
        """
        key = (dim, seq_len, device, dtype)
        cached = self.frequency_cache.get(key)
        if cached is None:
            # Inverse-frequency bands: 1 / base^(2i/dim) for even i.
            band_exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**band_exponents)

            # Outer product position x band -> per-position rotation angles.
            token_positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.outer(token_positions, inv_freq)

            # Duplicate the half-width angle table so it spans the full dim,
            # then cache cos/sin in the requested dtype.
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cached = (angles.cos().to(dtype), angles.sin().to(dtype))
            self.frequency_cache[key] = cached
        return cached

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Returns the "rotate-half" of x: swap halves, negating the second.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor of the same shape.
        """
        half = x.shape[-1] // 2
        first, second = x[..., :half], x[..., half:]
        return torch.cat((-second, first), dim=-1)

    def _apply_1d_rope(
        self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one spatial axis.

        Args:
            tokens: Input token features (batch, heads, n_tokens, d).
            positions: Integer position indices (batch, n_tokens).
            cos_comp: Cosine table indexed by position.
            sin_comp: Sine table indexed by position.

        Returns:
            Tokens with the rotary embedding applied.
        """
        # Gather the per-position table rows; the inserted axis broadcasts
        # over the attention heads.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2)
                holding the y and x coordinates of each token.

        Returns:
            Tensor of the same shape as the input with the 2D rotary position
            embeddings applied.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"

        # Each spatial axis rotates half of the feature dimension.
        half_dim = tokens.size(-1) // 2

        # The tables must cover the largest coordinate that will be looked up.
        table_len = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(half_dim, table_len, tokens.device, tokens.dtype)

        # Rotate the halves independently: y-coordinate drives the first half,
        # x-coordinate the second, then stitch them back together.
        y_half, x_half = tokens.chunk(2, dim=-1)
        y_half = self._apply_1d_rope(y_half, positions[..., 0], cos_comp, sin_comp)
        x_half = self._apply_1d_rope(x_half, positions[..., 1], cos_comp, sin_comp)
        return torch.cat((y_half, x_half), dim=-1)
|
||||
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """Precomputes the 1D rotary position embedding (RoPE) frequency tensor.

    RoPE encodes positions with rotations so that dot products depend only on
    relative offsets. For position m and dimension pair i, the rotation angle
    is m * theta^(-2i/d), with theta the base frequency (default 10000).

    Args:
        dim: Feature dimension; must be even (dimensions are rotated in pairs).
        pos: Either an int S (positions 0..S-1 are generated) or an array of
            position indices of shape [S].
        theta: Base frequency controlling the periodicity of the encoding.
        use_real: Return separate real (cos, sin) tensors instead of a single
            complex tensor.
        linear_factor: Linear frequency scaling factor (context extension).
        ntk_factor: NTK-aware scaling factor for longer sequences.
        repeat_interleave_real: With use_real=True, interleave each cos/sin
            value ([c0, c0, c1, c1, ...]) instead of concatenating two copies.
        freqs_dtype: dtype used for the frequency computation.

    Returns:
        Complex form: [S, D/2] complex tensor representing e^(i*m*theta_j).
        Real form: two [S, D] float tensors (cos, sin).
    """
    # Dimensions are processed in pairs, so dim must be even.
    assert dim % 2 == 0

    # Normalize `pos` to a torch tensor of position indices.
    if isinstance(pos, int):
        pos = torch.arange(pos)  # [0, 1, ..., pos-1]
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # [S]

    # NTK-aware scaling (handles sequences longer than seen in training).
    theta = theta * ntk_factor

    # Frequency bands: band_i = 1 / (theta^(2i/d) * linear_factor) for even i.
    band = (
        1.0
        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
        / linear_factor
    )  # [D/2], one frequency per dimension pair

    # Angle matrix via outer product: angles[m, i] = m * band_i.
    angles = torch.outer(pos, band)  # [S, D/2]

    if use_real and repeat_interleave_real:
        # Interleaved layout (flux, hunyuan-dit, cogvideox):
        # [cos_0, cos_0, cos_1, cos_1, ...]
        cos = angles.cos().repeat_interleave(2, dim=1, output_size=angles.shape[1] * 2).float()  # [S, D]
        sin = angles.sin().repeat_interleave(2, dim=1, output_size=angles.shape[1] * 2).float()  # [S, D]
        return cos, sin
    if use_real:
        # Concatenated layout (stable audio, allegro):
        # [cos_0, ..., cos_n, cos_0, ..., cos_n]
        cos = torch.cat([angles.cos(), angles.cos()], dim=-1).float()  # [S, D]
        sin = torch.cat([angles.sin(), angles.sin()], dim=-1).float()  # [S, D]
        return cos, sin
    # Complex layout (lumina): Euler's formula e^(i*angle) = cos + i*sin;
    # torch.polar(r, angle) = r * e^(i*angle) with unit radius here.
    return torch.polar(torch.ones_like(angles), angles)  # complex64: [S, D/2]
|
||||
|
||||
|
||||
class WanRotaryPosEmbed(nn.Module):
    """3D rotary position embedding (3D RoPE) for video / 3D token grids.

    RoPE is applied independently along the temporal (frame), height and
    width axes, and the per-axis frequency tables are concatenated.
    For a 3D position (f, h, w):
      - the frame axis uses dim_f feature dimensions,
      - the height axis uses dim_h feature dimensions,
      - the width axis uses dim_w feature dimensions,
    with dim_f + dim_h + dim_w == attention_head_dim.
    """

    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int = 1024,
        theta: float = 10000.0,
        fhw_dim: Optional[Tuple[int, int, int]] = (20, 22, 22),
    ):
        """Precomputes per-axis complex RoPE tables.

        Args:
            attention_head_dim: Total per-head feature dimension.
            patch_size: Patch size as (patch_f, patch_h, patch_w).
            max_seq_len: Maximum positions precomputed per axis.
            theta: Base RoPE frequency.
            fhw_dim: Explicit (dim_f, dim_h, dim_w) split of the head dim, or
                None for the automatic 1/3-1/3-rest split.
                Note: the default was a mutable list ``[20, 22, 22]``; it is
                now a tuple, matching the annotation (never mutated, so
                behavior is unchanged).
        """
        super().__init__()

        self.attention_head_dim = attention_head_dim  # total per-head feature dim
        self.patch_size = patch_size  # (patch_f, patch_h, patch_w)
        self.max_seq_len = max_seq_len  # max positions precomputed per axis

        # Step 1: split the head dim across the three axes.
        if fhw_dim is not None:
            # Use the explicitly requested split.
            assert attention_head_dim == sum(
                fhw_dim
            ), f"attention_head_dim {attention_head_dim} must match sum(fhw_dim) {sum(fhw_dim)}"
            t_dim, h_dim, w_dim = fhw_dim
        else:
            # Automatic split: h and w each take ~1/3 (rounded to even),
            # t takes the remainder.
            h_dim = w_dim = 2 * (attention_head_dim // 6)
            t_dim = attention_head_dim - h_dim - w_dim

        # Remember the split for forward().
        self.fhw_dim = (t_dim, h_dim, w_dim)

        # Step 2: precompute per-axis complex RoPE frequency tables.
        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            # Complex-form 1D RoPE table for this axis: [max_seq_len, dim // 2]
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        # Concatenate the three axis tables: [max_seq_len, attention_head_dim // 2]
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, ppf, pph, ppw, patch_start_idx, device: torch.device, f_start: int = 0, f_end: Optional[int] = None) -> torch.Tensor:
        """Builds the complex RoPE tensor for a 3D grid of tokens.

        Args:
            ppf (int): Number of frames (patches along time); used when f_end is None.
            pph (int): Number of patches along height per frame.
            ppw (int): Number of patches along width per frame.
            patch_start_idx (int): Number of special tokens per frame, placed
                before that frame's patches.
            device: Target device.
            f_start (int): First frame index (causal mode). Default 0.
            f_end (Optional[int]): One-past-last frame index (causal mode).
                When None, frames 0..ppf-1 are used.

        Returns:
            Complex frequency tensor of shape
            [1, 1, ppf * (patch_start_idx + pph * ppw), head_dim // 2].

        Token layout (frames concatenated, per frame: special tokens first):
            [frame0_special_token_0, ..., frame0_special_token_N,
             frame0_patch_0, ..., frame0_patch_M,
             frame1_special_token_0, ..., frame1_special_token_N,
             frame1_patch_0, ..., frame1_patch_M,
             ...]

        Modes:
            - non-causal: f_end is None; frames 0..ppf-1 from position 0.
            - causal: f_end given; frames [f_start, f_end) and ppf is recomputed.
        """

        # Step 1: move the precomputed table to the target device and split it
        # back into the three per-axis tables.
        self.freqs = self.freqs.to(device)
        # Resolve the per-axis dim split.
        if hasattr(self, 'fhw_dim') and self.fhw_dim is not None:
            t_dim, h_dim, w_dim = self.fhw_dim
        else:
            # Automatic split fallback (mirrors __init__).
            h_dim = w_dim = 2 * (self.attention_head_dim // 6)
            t_dim = self.attention_head_dim - h_dim - w_dim

        # Split sizes are halves because the table is complex-valued.
        freqs = self.freqs.split_with_sizes(
            [
                t_dim // 2,  # temporal axis
                h_dim // 2,  # height axis
                w_dim // 2,  # width axis
            ],
            dim=1,
        )

        # Causal mode: derive ppf and the frame range from [f_start, f_end).
        if f_end is not None:
            ppf = f_end - f_start
            frame_slice = slice(f_start, f_end)
        else:
            # Non-causal mode: frames 0..ppf-1.
            frame_slice = slice(0, ppf)

        # Step 2: handle special tokens (if any).
        ## For other tokens
        if patch_start_idx > 0:
            # 2.1 Special tokens sit on the diagonal (f, i, i), giving each a
            # unique position, e.g. camera: (f, 0, 0), register_0: (f, 1, 1),
            # ..., scale: (f, 5, 5).
            # Shape: (ppf, patch_start_idx, dim)
            freqs_special_f = freqs[0][frame_slice].reshape(ppf, 1, -1).expand(ppf, patch_start_idx, -1)  # frame axis varies per frame
            freqs_special_h = freqs[1][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # height = 0, 1, 2, ...
            freqs_special_w = freqs[2][:patch_start_idx].reshape(1, patch_start_idx, -1).expand(ppf, patch_start_idx, -1)  # width = 0, 1, 2, ...
            freqs_special = torch.cat([freqs_special_f, freqs_special_h, freqs_special_w], dim=-1)  # concat the three axes
            freqs_special = freqs_special.reshape(ppf, patch_start_idx, -1)  # (ppf, patch_start_idx, dim)

            # 2.2 Image patches sit at (f, patch_start_idx + h, patch_start_idx + w):
            # offsetting h and w keeps patch positions disjoint from the
            # special-token positions and treats both axes symmetrically.
            # Shape: (ppf, pph, ppw, dim)
            freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # frame axis
            freqs_h = freqs[1][patch_start_idx : patch_start_idx + pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # height starts at patch_start_idx
            freqs_w = freqs[2][patch_start_idx : patch_start_idx + ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # width starts at patch_start_idx
            freqs_patches = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1)  # concat the three axes
            freqs_patches = freqs_patches.reshape(ppf, pph * ppw, -1)  # flatten the spatial grid

            # Step 3: per frame, order tokens as [special tokens, patches].
            # Concatenate special tokens and patches for each frame along the second dimension
            # Shape: (ppf, patch_start_idx + pph * ppw, dim)
            freqs = torch.cat([freqs_special, freqs_patches], dim=1)  # (ppf, patch_start_idx + pph * ppw, dim)

            # Step 4: flatten frames and add batch/head axes.
            # Flatten to get final shape: (ppf * (patch_start_idx + pph * ppw), dim)
            freqs = freqs.reshape(ppf * (patch_start_idx + pph * ppw), -1)
            freqs = freqs.unsqueeze(0).unsqueeze(0)  # (1, 1, ppf * (patch_start_idx + pph * ppw), dim)
            return freqs

        # No special tokens (patch_start_idx == 0): patches only, located at
        # (f, 0:pph, 0:ppw).
        freqs_f = freqs[0][frame_slice].reshape(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)  # frame axis
        freqs_h = freqs[1][:pph].reshape(1, pph, 1, -1).expand(ppf, pph, ppw, -1)  # height from 0
        freqs_w = freqs[2][:ppw].reshape(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)  # width from 0
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)  # (1, 1, ppf * pph * ppw, dim)
        return freqs
|
||||
|
||||
def apply_rotary_emb(x, freqs):
    """Rotates token features by precomputed complex RoPE factors.

    Each consecutive feature pair (x1, x2) is treated as the complex number
    x1 + i*x2 and multiplied by e^(i*theta):

        (x1 + i*x2) * (cos(theta) + i*sin(theta))
            = (x1*cos - x2*sin) + i*(x1*sin + x2*cos)

    which is exactly the 2D rotation

        [cos(theta) -sin(theta)] [x1]
        [sin(theta)  cos(theta)] [x2]

    so relative position information is preserved.

    Args:
        x: Input features [batch, heads, seq_len, head_dim].
        freqs: Complex rotation factors [1, 1, seq_len, head_dim // 2].

    Returns:
        Rotated features with the same shape and dtype as ``x``.
    """
    # Step 1: pair features as (real, imag), computing in float64 for accuracy:
    # [b, h, seq, d] -> [b, h, seq, d // 2, 2]
    paired = x.to(torch.float64).reshape(x.shape[0], x.shape[1], x.shape[2], -1, 2)

    # Step 2: interpret each pair as one complex number: [b, h, seq, d // 2]
    as_complex = torch.view_as_complex(paired)

    # Step 3: complex multiplication performs the rotation; freqs is already
    # in the e^(i*theta) = cos(theta) + i*sin(theta) form.
    rotated = as_complex * freqs

    # Steps 4-5: back to real pairs, then flatten to the original layout.
    flat = torch.view_as_real(rotated).flatten(3)

    # Step 6: restore the caller's dtype.
    return flat.to(x.dtype)
|
||||
67
lingbot_map/layers/swiglu_ffn.py
Normal file
67
lingbot_map/layers/swiglu_ffn.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
import warnings
|
||||
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward block: ``w3(silu(x W1) * (x W2))``.

    The gate and value projections are fused into a single linear layer
    (``w12``) whose output is split in half along the feature axis.

    Args:
        in_features: Input feature dimension.
        hidden_features: Hidden dimension; defaults to ``in_features``.
        out_features: Output dimension; defaults to ``in_features``.
        act_layer: Unused; kept for signature compatibility.
        drop: Unused; kept for signature compatibility.
        bias: Whether the linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Fused gate+value projection, then the output projection.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        # Split the fused projection into gate and value halves.
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
||||
|
||||
|
||||
# Mirrors the upstream DINOv2 layout: xFormers would be enabled unless
# XFORMERS_DISABLED is set in the environment, but the fused-kernel import
# below is commented out, so the pure-PyTorch fallback is always used and
# XFORMERS_AVAILABLE is always False.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
#     if XFORMERS_ENABLED:
#         from xformers.ops import SwiGLU
#
#         XFORMERS_AVAILABLE = True
#         warnings.warn("xFormers is available (SwiGLU)")
#     else:
#         warnings.warn("xFormers is disabled (SwiGLU)")
#         raise ImportError
# except ImportError:
# Fallback: alias the pure-PyTorch implementation.
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False

# warnings.warn("xFormers is not available (SwiGLU)")
|
||||
|
||||
|
||||
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width follows the fused-kernel convention.

    The requested hidden width is scaled by 2/3 (keeping the parameter count
    comparable to a plain MLP at the same ratio, since SwiGLU uses two input
    projections) and rounded up to the next multiple of 8.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Scale by 2/3 and round up to a multiple of 8.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
|
||||
411
lingbot_map/layers/vision_transformer.py
Normal file
411
lingbot_map/layers/vision_transformer.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
import logging
|
||||
from typing import Sequence, Tuple, Union, Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from torch.nn.init import trunc_normal_
|
||||
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention#, NestedTensorBlock as Block
|
||||
|
||||
# TODO: Check this
|
||||
# We replace NestedTensorBlock with Block
|
||||
from .block import Block
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively applies ``fn(module=..., name=...)`` over a module tree.

    Args:
        fn: Callback invoked with keyword arguments ``module`` and ``name``.
        module: Root module of the traversal.
        name: Dotted-path prefix assigned to the root module.
        depth_first: If True, visit children before their parent (post-order).
        include_root: Whether ``fn`` is applied to ``module`` itself.

    Returns:
        The root module (after any in-place mutation performed by ``fn``).
    """
    if not depth_first and include_root:
        # Pre-order: parent before its children.
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        # Children are always visited (include_root=True below).
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        # Post-order: parent after its children.
        fn(module=module, name=name)
    return module
|
||||
|
||||
|
||||
class BlockChunk(nn.ModuleList):
    """A callable ``nn.ModuleList`` that applies its blocks sequentially."""

    def forward(self, x):
        # Feed the input through every contained block in order.
        for block in self:
            x = block(x)
        return x
|
||||
|
||||
|
||||
class DinoVisionTransformer(nn.Module):
    """DINOv2-style Vision Transformer trunk with optional register tokens,
    optional CLS-token removal, and FSDP-friendly block chunking."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        drop_cls_token=False,
        qk_norm=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
            drop_cls_token: (bool) if True, no CLS token parameter is created and the
                positional embedding covers patch tokens only
            qk_norm: (bool) forwarded to ``block_fn`` for every transformer block
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        # One prepended CLS token unless it is explicitly dropped.
        self.num_tokens = 1 if not drop_cls_token else 0
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset
        self.use_reentrant = False  # hardcoded to False

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.drop_cls_token = drop_cls_token

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if not drop_cls_token else None
        # Learned positional embedding for [CLS (optional)] + patch tokens;
        # register tokens deliberately get no positional embedding.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                qk_norm=qk_norm,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        # Learned token substituted for masked patches in prepare_tokens_with_masks.
        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Initializes positional/CLS/register parameters and all Linear layers."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6) if self.cls_token is not None else None
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Resizes the learned positional embedding to the patch grid of `x`.

        `x` is the token sequence after patch embedding (and CLS concat, when
        present); `w`/`h` are the input image width/height in pixels.
        """
        previous_dtype = x.dtype
        # NOTE(review): npatch subtracts 1 even when drop_cls_token is True,
        # although x then carries no CLS token — confirm the early-return
        # condition is intended for that mode.
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1 if not self.drop_cls_token else self.pos_embed.shape[1]
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        if not self.drop_cls_token:
            class_pos_embed = pos_embed[:, 0]
            patch_pos_embed = pos_embed[:, 1:]
        else:
            patch_pos_embed = pos_embed
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        if not self.drop_cls_token:
            x = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
        else:
            x = patch_pos_embed
        return x.to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patch-embeds `x`, applies mask tokens, prepends the CLS token (when
        present), adds interpolated positional embeddings, then inserts
        register tokens."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            # Replace masked patch embeddings with the learned mask token.
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.cls_token is not None else x
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            # Registers are inserted after the first token and receive no
            # positional embedding.
            # NOTE(review): with drop_cls_token=True, x[:, :1] is a patch
            # token — verify the register placement for that mode.
            x = torch.cat((x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), dim=1)

        return x

    def forward_features_list(self, x_list, masks_list):
        """Runs the trunk on a list of (x, masks) pairs (nested-tensor-style
        blocks accept the list directly)."""
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]

        for blk in self.blocks:
            if self.training:
                # Gradient checkpointing trades recompute for activation memory.
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Returns normalized CLS/register/patch tokens plus pre-norm features."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            if self.training:
                # Gradient checkpointing trades recompute for activation memory.
                x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
            else:
                x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        # Collect outputs of selected blocks when self.blocks is a flat list.
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        # Same as above but for the chunked (FSDP-wrapped) block layout, where
        # each chunk is padded with nn.Identity() to keep global block indices.
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Returns patch features from n intermediate blocks, optionally
        normalized, reshaped to (B, C, H', W'), and paired with class tokens."""
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        # Drop the CLS and register tokens, keeping patch tokens only.
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=True, **kwargs):
        """Returns the full feature dict when is_training, otherwise the
        (identity) head applied to the normalized CLS token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
|
||||
|
||||
|
||||
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Applies truncated-normal init (std=0.02) to Linear weights and zeroes
    their biases; every other module type is left untouched.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
|
||||
|
||||
|
||||
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-Small trunk: embed_dim 384, depth 12, 6 heads (64 dims per head)."""
    return DinoVisionTransformer(
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        patch_size=patch_size,
        num_register_tokens=num_register_tokens,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-B DinoVisionTransformer: embed dim 768, 12 blocks, 12 heads."""
    config = dict(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # Duplicate keys in kwargs raise TypeError, matching direct-call semantics.
    return DinoVisionTransformer(**config, **kwargs)
|
||||
|
||||
|
||||
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """Build a ViT-L DinoVisionTransformer: embed dim 1024, 24 blocks, 16 heads."""
    config = dict(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # Duplicate keys in kwargs raise TypeError, matching direct-call semantics.
    return DinoVisionTransformer(**config, **kwargs)
|
||||
|
||||
|
||||
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    config = dict(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
    )
    # Duplicate keys in kwargs raise TypeError, matching direct-call semantics.
    return DinoVisionTransformer(**config, **kwargs)
|
||||
0
lingbot_map/models/__init__.py
Normal file
0
lingbot_map/models/__init__.py
Normal file
359
lingbot_map/models/gct_base.py
Normal file
359
lingbot_map/models/gct_base.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
GCTBase - Base class for GCT model implementations.
|
||||
|
||||
Provides shared functionality:
|
||||
- Prediction heads (camera, depth, point)
|
||||
- Forward pass structure
|
||||
- Model hub mixin (PyTorchModelHubMixin)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any, List, Union
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
from lingbot_map.heads.dpt_head import DPTHead
|
||||
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
|
||||
from lingbot_map.utils.geometry import closed_form_inverse_se3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GCTBase(nn.Module, PyTorchModelHubMixin, ABC):
    """
    Base class for GCT model implementations.

    Handles shared components:
    - Prediction heads (camera, depth, point)
    - Forward pass structure
    - Input normalization

    Subclasses must implement:
    - _build_aggregator(): Create mode-specific aggregator
    - _build_camera_head(): Create mode-specific camera head
    """

    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Camera head sliding window
        enable_camera_sliding_window: bool = False,
        # 3D RoPE
        enable_3d_rope: bool = False,
        # Context Parallelism (kept for checkpoint compatibility but not used)
        enable_ulysses_cp: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        pred_normalization_detach_scale: bool = False,
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
    ):
        super().__init__()

        # Store configuration. NOTE: these attributes must be assigned before
        # the _build_* calls at the end of __init__, which read them.
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embed = patch_embed
        self.disable_global_rope = disable_global_rope

        # Deliberately hard-coded False regardless of the enable_ulysses_cp
        # argument: context parallelism is disabled in the standalone package;
        # the parameter exists only for checkpoint compatibility.
        self.enable_ulysses_cp = False  # CP disabled in standalone package
        self.enable_normalize = enable_normalize
        self.pred_normalization = pred_normalization
        self.pred_normalization_detach_scale = pred_normalization_detach_scale
        self.use_gradient_checkpoint = use_gradient_checkpoint

        # Head flags
        self.enable_camera = enable_camera
        self.enable_point = enable_point
        self.enable_local_point = enable_local_point
        self.enable_depth = enable_depth
        self.enable_track = enable_track
        self.enable_camera_sliding_window = enable_camera_sliding_window
        self.enable_3d_rope = enable_3d_rope

        # Build aggregator (subclass-specific)
        self.aggregator = self._build_aggregator()

        # Build prediction heads (subclass-specific). Disabled heads are
        # stored as None; the _predict_* methods return {} for None heads.
        self.camera_head = self._build_camera_head() if enable_camera else None
        self.point_head = self._build_point_head() if enable_point else None
        self.local_point_head = self._build_local_point_head() if enable_local_point else None
        self.depth_head = self._build_depth_head() if enable_depth else None

    @abstractmethod
    def _build_aggregator(self) -> nn.Module:
        """Create and return the mode-specific feature aggregator."""
        pass

    @abstractmethod
    def _build_camera_head(self) -> nn.Module:
        """Create and return the mode-specific camera pose head."""
        pass

    def _build_depth_head(self) -> nn.Module:
        """Build the DPT depth head: 2 output channels (depth + confidence)."""
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=2,
            activation="exp",
            conf_activation="expp1"
        )

    def _build_point_head(self) -> nn.Module:
        """Build the DPT world-point head: 4 output channels (xyz + confidence)."""
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1"
        )

    def _build_local_point_head(self) -> nn.Module:
        """Build the DPT camera-frame point head: 4 output channels (xyz + confidence)."""
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1"
        )

    def _normalize_input(self, images: torch.Tensor, query_points=None):
        """Add a batch dimension to unbatched inputs.

        images [S, 3, H, W] -> [1, S, 3, H, W]; query_points [N, 2] -> [1, N, 2].
        Already-batched inputs pass through unchanged.
        """
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)
        return images, query_points

    @abstractmethod
    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        view_graphs: Optional[torch.Tensor] = None,
        causal_graphs: Optional[Union[torch.Tensor, List[np.ndarray]]] = None,
        ordered_video: Optional[torch.Tensor] = None,
        is_cp_sliced: bool = False,
    ) -> tuple:
        """Run the aggregator; return (aggregated_tokens_list, patch_start_idx)."""
        pass

    def _predict_camera(
        self,
        aggregated_tokens_list: list,
        mask: Optional[torch.Tensor] = None,
        causal_inference: bool = False,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Run the camera head on aggregated tokens.

        Returns {} when the camera head is disabled; otherwise a dict with
        "pose_enc" (final iteration) and "pose_enc_list" (all iterations).
        The head is run in fp32 with autocast disabled.
        """
        if self.camera_head is None:
            return {}

        # Cast tokens to fp32 and disable autocast: the pose regression is
        # run at full precision.
        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]

        # Sliding window only applies to the camera head when explicitly enabled;
        # -1 means "no window" to the head.
        camera_sliding_window = sliding_window_size if self.enable_camera_sliding_window else -1

        with torch.amp.autocast('cuda', enabled=False):
            pose_enc_list = self.camera_head(
                aggregated_tokens_list_fp32,
                mask=mask,
                causal_inference=causal_inference,
                num_frame_for_scale=num_frame_for_scale if num_frame_for_scale is not None else -1,
                sliding_window_size=camera_sliding_window,
                num_frame_per_block=num_frame_per_block,
            )

        return {
            "pose_enc": pose_enc_list[-1],
            "pose_enc_list": pose_enc_list,
        }

    def _predict_depth(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Run the depth head in fp32; returns {} when the head is disabled."""
        if self.depth_head is None:
            return {}

        # Full-precision inputs with autocast disabled for the head.
        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        images_fp32 = images.float()

        with torch.amp.autocast('cuda', enabled=False):
            depth, depth_conf = self.depth_head(
                aggregated_tokens_list_fp32,
                images=images_fp32,
                patch_start_idx=patch_start_idx
            )

        return {"depth": depth, "depth_conf": depth_conf}

    def _predict_points(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Run the world-point head in fp32; returns {} when the head is disabled."""
        if self.point_head is None:
            return {}

        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        images_fp32 = images.float()

        with torch.amp.autocast('cuda', enabled=False):
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list_fp32,
                images=images_fp32,
                patch_start_idx=patch_start_idx
            )

        return {"world_points": pts3d, "world_points_conf": pts3d_conf}

    def _predict_local_points(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """Run the camera-frame point head in fp32; returns {} when disabled."""
        if self.local_point_head is None:
            return {}

        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        images_fp32 = images.float()

        with torch.amp.autocast('cuda', enabled=False):
            pts3d, pts3d_conf = self.local_point_head(
                aggregated_tokens_list_fp32,
                images=images_fp32,
                patch_start_idx=patch_start_idx
            )

        return {"cam_points": pts3d, "cam_points_conf": pts3d_conf}

    def _unproject_depth_to_world(
        self,
        depth: torch.Tensor,
        pose_enc: torch.Tensor,
    ) -> torch.Tensor:
        """Unproject per-pixel depth into world coordinates using predicted poses.

        Args:
            depth: [B, S, H, W, 1] depth maps.
            pose_enc: pose encodings decoded into extrinsics/intrinsics by
                pose_encoding_to_extri_intri.

        Returns:
            [B, S, H, W, 3] world-space points.
        """
        B, S, H, W, _ = depth.shape
        device = depth.device
        dtype = depth.dtype

        image_size_hw = (H, W)
        extrinsics, intrinsics = pose_encoding_to_extri_intri(
            pose_enc, image_size_hw=image_size_hw, build_intrinsics=True
        )

        # Promote 3x4 extrinsics to homogeneous 4x4 and invert to get
        # camera-to-world transforms (assumes extrinsics are world-to-camera
        # — consistent with the c2w naming below; confirm against
        # pose_encoding_to_extri_intri's convention).
        extrinsics_flat = extrinsics.view(B * S, 3, 4)
        extrinsics_4x4 = torch.zeros(B * S, 4, 4, device=device, dtype=dtype)
        extrinsics_4x4[:, :3, :] = extrinsics_flat
        extrinsics_4x4[:, 3, 3] = 1.0
        c2w = closed_form_inverse_se3(extrinsics_4x4).view(B, S, 4, 4)

        # Dense pixel grid in homogeneous coordinates [H, W, 3] (x, y, 1).
        y_grid, x_grid = torch.meshgrid(
            torch.arange(H, device=device, dtype=dtype),
            torch.arange(W, device=device, dtype=dtype),
            indexing='ij'
        )
        pixel_coords = torch.stack([x_grid, y_grid, torch.ones_like(x_grid)], dim=-1)

        # Back-project pixels to camera rays, then scale by depth.
        intrinsics_inv = torch.inverse(intrinsics)
        camera_coords = torch.einsum('bsij,hwj->bshwi', intrinsics_inv, pixel_coords)
        camera_points = camera_coords * depth

        # Homogenize and transform camera-frame points to world frame.
        ones = torch.ones_like(camera_points[..., :1])
        camera_points_h = torch.cat([camera_points, ones], dim=-1)
        world_points_h = torch.einsum('bsij,bshwj->bshwi', c2w, camera_points_h)

        return world_points_h[..., :3]

    def forward(
        self,
        images: torch.Tensor,
        query_points: Optional[torch.Tensor] = None,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        mask: Optional[torch.Tensor] = None,
        causal_inference: bool = False,
        ordered_video: Optional[torch.Tensor] = None,
        gather_outputs: bool = True,
        point_masks: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the GCT model.

        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
            query_points: Optional query points [N, 2] or [B, N, 2]

        Returns:
            Dictionary containing predictions:
            - pose_enc: Camera pose encoding [B, S, 9]
            - depth: Depth maps [B, S, H, W, 1]
            - depth_conf: Depth confidence [B, S, H, W]
            - world_points: 3D world coordinates [B, S, H, W, 3]
            - world_points_conf: Point confidence [B, S, H, W]
        """
        images, query_points = self._normalize_input(images, query_points)

        aggregated_tokens_list, patch_start_idx = self._aggregate_features(
            images,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )

        predictions = {}

        # NOTE(review): the camera head receives ordered_video as its mask
        # while the separate `mask` parameter is unused here — confirm this
        # routing is intentional.
        predictions.update(self._predict_camera(
            aggregated_tokens_list,
            mask=ordered_video,
            causal_inference=causal_inference,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
            gather_outputs=gather_outputs,
        ))

        predictions.update(self._predict_depth(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))

        predictions.update(self._predict_points(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))

        predictions.update(self._predict_local_points(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))

        # Keep the (possibly re-batched) input images around for visualization
        # at inference time only.
        if not self.training:
            predictions["images"] = images

        return predictions
|
||||
444
lingbot_map/models/gct_stream.py
Normal file
444
lingbot_map/models/gct_stream.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
GCTStream - Streaming GCT with KV cache for online inference.
|
||||
|
||||
Provides streaming inference functionality:
|
||||
- Temporal causal attention with KV cache
|
||||
- Sliding window support
|
||||
- Efficient frame-by-frame processing
|
||||
- 3D RoPE support for temporal consistency
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Dict, Any, List
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from lingbot_map.heads.camera_head import CameraCausalHead
|
||||
from lingbot_map.models.gct_base import GCTBase
|
||||
from lingbot_map.aggregator.stream import AggregatorStream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GCTStream(GCTBase):
    """
    Streaming GCT model with KV cache for efficient online inference.

    Features:
    - AggregatorStream with KV cache support (FlashInfer backend)
    - CameraCausalHead for pose refinement
    - Sliding window attention for memory efficiency
    - Frame-by-frame streaming inference
    """

    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        pretrained_path: str = '',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        # Stream-specific parameters
        sliding_window_size: int = -1,
        num_frame_for_scale: int = 1,
        num_random_frames: int = 0,
        attend_to_special_tokens: bool = False,
        attend_to_scale_frames: bool = False,
        enable_stream_inference: bool = True,  # Default to True for streaming
        enable_3d_rope: bool = False,
        max_frame_num: int = 1024,
        # Camera head 3D RoPE (separate from aggregator 3D RoPE)
        enable_camera_3d_rope: bool = False,
        camera_rope_theta: float = 10000.0,
        # Scale token configuration (kept for checkpoint compat, ignored)
        use_scale_token: bool = True,
        # KV cache parameters
        kv_cache_sliding_window: int = 64,
        kv_cache_scale_frames: int = 8,
        kv_cache_cross_frame_special: bool = True,
        kv_cache_include_scale_frames: bool = True,
        kv_cache_camera_only: bool = False,
        # Backend selection
        use_sdpa: bool = False,  # If True, use SDPA (no flashinfer needed); default: FlashInfer
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
    ):
        """
        Initialize GCTStream.

        Args:
            img_size: Input image size
            patch_size: Patch size for embedding
            embed_dim: Embedding dimension
            patch_embed: Patch embedding type ("dinov2_vitl14_reg", "conv", etc.)
            pretrained_path: Path to pretrained DINOv2 weights
            disable_global_rope: Disable RoPE in global attention
            enable_camera/point/depth/track: Enable prediction heads
            enable_normalize: Enable normalization
            sliding_window_size: Sliding window size in blocks (-1 for full causal)
            num_frame_for_scale: Number of scale estimation frames
            num_random_frames: Number of random frames for long-range dependencies
            attend_to_special_tokens: Enable cross-frame special token attention
            attend_to_scale_frames: Whether to attend to scale frames
            enable_stream_inference: Enable streaming inference with KV cache
            enable_3d_rope: Enable 3D RoPE for temporal consistency
            max_frame_num: Maximum number of frames for 3D RoPE
            use_scale_token: Kept for checkpoint compatibility, ignored
            kv_cache_sliding_window: Sliding window size for KV cache eviction
            kv_cache_scale_frames: Number of scale frames to keep in KV cache
            kv_cache_cross_frame_special: Keep special tokens from evicted frames
            kv_cache_include_scale_frames: Include scale frames in KV cache
            kv_cache_camera_only: Only keep camera tokens from evicted frames
        """
        # Store stream-specific parameters before calling super().__init__():
        # the base constructor invokes _build_aggregator()/_build_camera_head(),
        # which read these attributes.
        self.pretrained_path = pretrained_path
        self.sliding_window_size = sliding_window_size
        self.num_frame_for_scale = num_frame_for_scale
        self.num_random_frames = num_random_frames
        self.attend_to_special_tokens = attend_to_special_tokens
        self.attend_to_scale_frames = attend_to_scale_frames
        self.enable_stream_inference = enable_stream_inference
        self.enable_3d_rope = enable_3d_rope
        self.max_frame_num = max_frame_num
        # Camera head 3D RoPE settings
        self.enable_camera_3d_rope = enable_camera_3d_rope
        self.camera_rope_theta = camera_rope_theta
        # KV cache parameters
        self.kv_cache_sliding_window = kv_cache_sliding_window
        self.kv_cache_scale_frames = kv_cache_scale_frames
        self.kv_cache_cross_frame_special = kv_cache_cross_frame_special
        self.kv_cache_include_scale_frames = kv_cache_include_scale_frames
        self.kv_cache_camera_only = kv_cache_camera_only
        self.use_sdpa = use_sdpa

        # Call base class __init__ (will call _build_aggregator)
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embed_dim,
            patch_embed=patch_embed,
            disable_global_rope=disable_global_rope,
            enable_camera=enable_camera,
            enable_point=enable_point,
            enable_local_point=enable_local_point,
            enable_depth=enable_depth,
            enable_track=enable_track,
            enable_normalize=enable_normalize,
            pred_normalization=pred_normalization,
            enable_3d_rope=enable_3d_rope,
            use_gradient_checkpoint=use_gradient_checkpoint,
        )

    def _build_aggregator(self) -> nn.Module:
        """
        Build streaming aggregator with KV cache support (FlashInfer backend).

        Returns:
            AggregatorStream module
        """
        return AggregatorStream(
            img_size=self.img_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            patch_embed=self.patch_embed,
            pretrained_path=self.pretrained_path,
            disable_global_rope=self.disable_global_rope,
            sliding_window_size=self.sliding_window_size,
            num_frame_for_scale=self.num_frame_for_scale,
            num_random_frames=self.num_random_frames,
            attend_to_special_tokens=self.attend_to_special_tokens,
            attend_to_scale_frames=self.attend_to_scale_frames,
            enable_stream_inference=self.enable_stream_inference,
            enable_3d_rope=self.enable_3d_rope,
            max_frame_num=self.max_frame_num,
            # Backend: FlashInfer (default) or SDPA (fallback)
            use_flashinfer=not self.use_sdpa,
            use_sdpa=self.use_sdpa,
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            use_gradient_checkpoint=self.use_gradient_checkpoint,
        )

    def _build_camera_head(self) -> nn.Module:
        """
        Build causal camera head for streaming inference.

        Returns:
            CameraCausalHead module
        """
        return CameraCausalHead(
            dim_in=2 * self.embed_dim,
            sliding_window_size=self.sliding_window_size,
            attend_to_scale_frames=self.attend_to_scale_frames,
            # KV cache parameters
            kv_cache_sliding_window=self.kv_cache_sliding_window,
            kv_cache_scale_frames=self.kv_cache_scale_frames,
            kv_cache_cross_frame_special=self.kv_cache_cross_frame_special,
            kv_cache_include_scale_frames=self.kv_cache_include_scale_frames,
            kv_cache_camera_only=self.kv_cache_camera_only,
            # Camera head 3D RoPE parameters
            enable_3d_rope=self.enable_camera_3d_rope,
            max_frame_num=self.max_frame_num,
            rope_theta=self.camera_rope_theta,
        )

    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        **kwargs,
    ) -> tuple:
        """
        Run aggregator to get multi-scale features.

        Args:
            images: Input images [B, S, 3, H, W]
            num_frame_for_scale: Number of frames for scale estimation
            sliding_window_size: Override sliding window size
            num_frame_per_block: Number of frames per block

        Returns:
            (aggregated_tokens_list, patch_start_idx)
        """
        aggregated_tokens_list, patch_start_idx = self.aggregator(
            images,
            # Fixed intermediate-layer indices fed to the DPT heads.
            selected_idx=[4, 11, 17, 23],
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )
        return aggregated_tokens_list, patch_start_idx

    def clean_kv_cache(self):
        """
        Clean KV caches in the aggregator and the camera head.

        Call this method when starting a new video sequence to clear
        cached key-value pairs from previous sequences.
        """
        if hasattr(self.aggregator, 'clean_kv_cache'):
            self.aggregator.clean_kv_cache()
        else:
            logger.warning("Aggregator does not support KV cache cleaning")
        # BUGFIX: this previously checked hasattr(self.camera_head, 'kv_cache')
        # but then called clean_kv_cache(), which could raise AttributeError
        # (a head exposing kv_cache without clean_kv_cache) or silently skip
        # cleaning (clean_kv_cache defined but kv_cache attribute not yet
        # created). Check for the method that is actually invoked, and guard
        # against camera_head being None (enable_camera=False).
        if self.camera_head is not None and hasattr(self.camera_head, 'clean_kv_cache'):
            self.camera_head.clean_kv_cache()
        else:
            logger.warning("Camera head does not support KV cache cleaning")

    def _set_skip_append(self, skip: bool):
        """Set _skip_append flag on all KV caches (aggregator + camera head).

        When skip=True, attention layers will attend to [cached_kv + current_kv]
        but will NOT store the current frame's KV in cache. This is used for
        non-keyframe processing in keyframe-based streaming inference.

        Args:
            skip: If True, subsequent forward passes will not append KV to cache.
        """
        if hasattr(self.aggregator, 'kv_cache') and self.aggregator.kv_cache is not None:
            self.aggregator.kv_cache["_skip_append"] = skip
        # The camera head keeps one cache dict per layer; flag each of them.
        if self.camera_head is not None and hasattr(self.camera_head, 'kv_cache') and self.camera_head.kv_cache is not None:
            for cache_dict in self.camera_head.kv_cache:
                cache_dict["_skip_append"] = skip

    def get_kv_cache_info(self) -> Dict[str, Any]:
        """
        Get information about current KV cache state.

        Returns:
            Dictionary with cache statistics:
            - num_cached_blocks: Number of blocks with cached KV
            - cache_memory_mb: Approximate memory usage in MB
        """
        if not hasattr(self.aggregator, 'kv_cache') or self.aggregator.kv_cache is None:
            return {"num_cached_blocks": 0, "cache_memory_mb": 0.0}

        kv_cache = self.aggregator.kv_cache
        # Count per-block key entries ('k_<idx>'), excluding special-token caches.
        num_cached = sum(1 for k in kv_cache.keys() if k.startswith('k_') and not k.endswith('_special'))

        # Estimate memory usage
        total_elements = 0
        for _, v in kv_cache.items():
            if v is not None and torch.is_tensor(v):
                total_elements += v.numel()

        # Assume bfloat16 (2 bytes per element)
        cache_memory_mb = (total_elements * 2) / (1024 * 1024)

        return {
            "num_cached_blocks": num_cached,
            "cache_memory_mb": round(cache_memory_mb, 2)
        }

    @torch.no_grad()
    def inference_streaming(
        self,
        images: torch.Tensor,
        num_scale_frames: Optional[int] = None,
        keyframe_interval: int = 1,
        output_device: Optional[torch.device] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Streaming inference: process scale frames first, then frame-by-frame.

        This method enables efficient online inference by:
        1. Processing initial scale frames together (bidirectional attention via scale token)
        2. Processing remaining frames one-by-one with KV cache (causal streaming)

        Keyframe mode (keyframe_interval > 1):
        - Every keyframe_interval-th frame (after scale frames) is a keyframe
        - Keyframes: KV is stored in cache (normal behavior)
        - Non-keyframes: KV is NOT stored in cache (attend to cached + own KV, then discard)
        - All frames produce full predictions regardless of keyframe status
        - Reduces KV cache memory growth by ~1/keyframe_interval

        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
            num_scale_frames: Number of initial frames for scale estimation.
                If None, uses self.num_frame_for_scale.
            keyframe_interval: Every N-th frame (after scale frames) is a keyframe
                whose KV persists in cache. 1 = every frame is a
                keyframe (default, same as original behavior).
            output_device: Device to store output predictions on. If None, keeps on
                the same device as the model. Set to torch.device('cpu')
                to offload predictions per-frame and avoid GPU OOM on
                long sequences.

        Returns:
            Dictionary containing predictions for all frames:
            - pose_enc: [B, S, 9]
            - depth: [B, S, H, W, 1]
            - depth_conf: [B, S, H, W]
            - world_points: [B, S, H, W, 3]
            - world_points_conf: [B, S, H, W]
        """
        # Normalize input shape
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        B, S, C, H, W = images.shape

        # Determine number of scale frames
        scale_frames = num_scale_frames if num_scale_frames is not None else self.num_frame_for_scale
        scale_frames = min(scale_frames, S)  # Cap to available frames

        # Helper to move tensor to output device
        def _to_out(t: torch.Tensor) -> torch.Tensor:
            if output_device is not None:
                return t.to(output_device)
            return t

        # Clean KV caches before starting new sequence
        self.clean_kv_cache()

        # Phase 1: Process scale frames together
        # These frames get bidirectional attention among themselves via scale token
        logger.info(f'Processing {scale_frames} scale frames...')
        scale_images = images[:, :scale_frames]
        scale_output = self.forward(
            scale_images,
            num_frame_for_scale=scale_frames,
            num_frame_per_block=scale_frames,  # Process all scale frames as one block
            causal_inference=True,
        )

        # Initialize output lists with scale frame predictions (offload if needed)
        all_pose_enc = [_to_out(scale_output["pose_enc"])]
        all_depth = [_to_out(scale_output["depth"])] if "depth" in scale_output else []
        all_depth_conf = [_to_out(scale_output["depth_conf"])] if "depth_conf" in scale_output else []
        all_world_points = [_to_out(scale_output["world_points"])] if "world_points" in scale_output else []
        all_world_points_conf = [_to_out(scale_output["world_points_conf"])] if "world_points_conf" in scale_output else []
        del scale_output

        # Phase 2: Process remaining frames one-by-one
        pbar = tqdm(
            range(scale_frames, S),
            desc='Streaming inference',
            initial=scale_frames,
            total=S,
        )
        for i in pbar:
            frame_image = images[:, i:i+1]

            # Determine if this frame is a keyframe
            is_keyframe = (keyframe_interval <= 1) or ((i - scale_frames) % keyframe_interval == 0)

            # Non-keyframes attend to the cache but do not extend it.
            if not is_keyframe:
                self._set_skip_append(True)

            frame_output = self.forward(
                frame_image,
                num_frame_for_scale=scale_frames,  # Keep same for scale token logic
                num_frame_per_block=1,  # Single frame per block
                causal_inference=True,
            )

            if not is_keyframe:
                self._set_skip_append(False)

            all_pose_enc.append(_to_out(frame_output["pose_enc"]))
            if "depth" in frame_output:
                all_depth.append(_to_out(frame_output["depth"]))
            if "depth_conf" in frame_output:
                all_depth_conf.append(_to_out(frame_output["depth_conf"]))
            if "world_points" in frame_output:
                all_world_points.append(_to_out(frame_output["world_points"]))
            if "world_points_conf" in frame_output:
                all_world_points_conf.append(_to_out(frame_output["world_points_conf"]))
            del frame_output

        # Free GPU memory before concatenation
        if output_device is not None:
            # Move images to output device, then free GPU copy
            images_out = _to_out(images)
            del images
            # Clean KV cache (no longer needed after inference)
            self.clean_kv_cache()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            images_out = images

        # Concatenate all predictions along sequence dimension, freeing the
        # per-frame lists as we go to keep peak memory down.
        predictions = {
            "pose_enc": torch.cat(all_pose_enc, dim=1),
        }
        del all_pose_enc
        if all_depth:
            predictions["depth"] = torch.cat(all_depth, dim=1)
            del all_depth
        if all_depth_conf:
            predictions["depth_conf"] = torch.cat(all_depth_conf, dim=1)
            del all_depth_conf
        if all_world_points:
            predictions["world_points"] = torch.cat(all_world_points, dim=1)
            del all_world_points
        if all_world_points_conf:
            predictions["world_points_conf"] = torch.cat(all_world_points_conf, dim=1)
            del all_world_points_conf

        # Store images for visualization
        predictions["images"] = images_out

        # Apply prediction normalization if enabled.
        # NOTE(review): _normalize_predictions is not defined in this class or
        # in the visible GCTBase — confirm it is provided elsewhere before
        # enabling pred_normalization.
        if self.pred_normalization:
            predictions = self._normalize_predictions(predictions)

        return predictions
|
||||
0
lingbot_map/utils/__init__.py
Normal file
0
lingbot_map/utils/__init__.py
Normal file
774
lingbot_map/utils/geometry.py
Normal file
774
lingbot_map/utils/geometry.py
Normal file
@@ -0,0 +1,774 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from scipy.spatial.transform import Rotation
|
||||
try:
|
||||
from lietorch import SE3, Sim3
|
||||
except ImportError:
|
||||
SE3 = Sim3 = None
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from lingbot_map.dependency.distortion import apply_distortion, iterative_undistortion, single_undistortion
|
||||
except ImportError:
|
||||
apply_distortion = iterative_undistortion = single_undistortion = None
|
||||
|
||||
|
||||
def unproject_depth_map_to_point_map(
    depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
) -> np.ndarray:
    """
    Unproject a batch of depth maps to 3D world coordinates.

    Args:
        depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
        extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
        intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)

    Returns:
        np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
    """
    # Accept torch tensors for convenience; all math below runs in NumPy.
    if isinstance(depth_map, torch.Tensor):
        depth_map = depth_map.cpu().numpy()
    if isinstance(extrinsics_cam, torch.Tensor):
        extrinsics_cam = extrinsics_cam.cpu().numpy()
    if isinstance(intrinsics_cam, torch.Tensor):
        intrinsics_cam = intrinsics_cam.cpu().numpy()

    world_points_list = []
    for frame_idx in range(depth_map.shape[0]):
        depth = depth_map[frame_idx]
        # BUGFIX: only squeeze a trailing singleton channel. The previous
        # unconditional squeeze(-1) raised ValueError for (S, H, W) input,
        # even though the docstring advertises support for that shape.
        if depth.ndim == 3 and depth.shape[-1] == 1:
            depth = depth.squeeze(-1)
        cur_world_points, _, _ = depth_to_world_coords_points(
            depth, extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
        )
        world_points_list.append(cur_world_points)

    return np.stack(world_points_list, axis=0)
|
||||
|
||||
|
||||
def depth_to_world_coords_points(
    depth_map: np.ndarray,
    extrinsic: np.ndarray,
    intrinsic: np.ndarray,
    eps=1e-8,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Lift a depth map into world-frame 3D points.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        extrinsic (np.ndarray): Camera extrinsic of shape (3, 4), OpenCV
            convention (cam from world).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
        eps: Depths at or below this value are flagged invalid.

    Returns:
        tuple: (world points (H, W, 3), camera points (H, W, 3),
        valid-depth mask (H, W)); all three are None when depth_map is None.
    """
    if depth_map is None:
        return None, None, None

    # Pixels with (near-)zero depth carry no geometry.
    valid_mask = depth_map > eps

    # First lift every pixel into the camera frame.
    cam_points = depth_to_cam_coords_points(depth_map, intrinsic)

    # Invert the world->cam extrinsic to obtain cam->world. The helper is
    # batched, hence the [None] ... [0] round trip.
    cam_to_world = closed_form_inverse_se3(extrinsic[None])[0]
    rot_cw = cam_to_world[:3, :3]
    trans_cw = cam_to_world[:3, 3]

    # x_world = R @ x_cam + t, applied per pixel via a right-multiply.
    world_points = cam_points @ rot_cw.T + trans_cw

    return world_points, cam_points, valid_mask
|
||||
|
||||
|
||||
def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> np.ndarray:
    """
    Convert a depth map to camera-frame 3D coordinates.

    Args:
        depth_map (np.ndarray): Depth map of shape (H, W).
        intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3)
            with zero skew.

    Returns:
        np.ndarray: Camera coordinates of shape (H, W, 3), float32.
    """
    # BUGFIX (docs): this function returns a single array; the annotation
    # and docstring previously claimed a tuple of two arrays.
    H, W = depth_map.shape
    assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
    assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"

    # Intrinsic parameters (focal lengths and principal point)
    fu, fv = intrinsic[0, 0], intrinsic[1, 1]
    cu, cv = intrinsic[0, 2], intrinsic[1, 2]

    # Pixel coordinate grid: u runs along width, v along height.
    u, v = np.meshgrid(np.arange(W), np.arange(H))

    # Unproject through the pinhole model.
    x_cam = (u - cu) * depth_map / fu
    y_cam = (v - cv) * depth_map / fv
    z_cam = depth_map

    # Stack to (H, W, 3) camera coordinates.
    return np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
|
||||
|
||||
|
||||
def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    Uses the closed form [R|t]^-1 = [R^T | -R^T t] instead of a general
    matrix inverse.

    If `R` and `T` are provided, they must correspond to the rotation and
    translation components of `se3`. Otherwise, they will be extracted
    from `se3`.

    Args:
        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        R (optional): Nx3x3 array or tensor of rotation matrices.
        T (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Nx4x4 inverted SE3 matrices with the same type and device as `se3`.
    """
    # Dispatch on container type (NumPy vs torch).
    is_numpy = isinstance(se3, np.ndarray)

    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
        # BUGFIX: the message previously mentioned only (N,4,4) even
        # though (N,3,4) is accepted as well.
        raise ValueError(f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}.")

    # Extract R and T if not provided.
    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    if is_numpy:
        R_transposed = np.transpose(R, (0, 2, 1))
        top_right = -np.matmul(R_transposed, T)  # -R^T t
        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
    else:
        R_transposed = R.transpose(1, 2)  # (N,3,3)
        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
        # Create the identity directly with the right dtype/device instead
        # of allocating on CPU and moving it afterwards.
        inverted_matrix = torch.eye(4, dtype=R.dtype, device=R.device)[None].repeat(len(R), 1, 1)

    inverted_matrix[:, :3, :3] = R_transposed
    inverted_matrix[:, :3, 3:] = top_right

    return inverted_matrix
|
||||
|
||||
def closed_form_inverse_se3_general(se3, R=None, T=None):
    """
    SE3 inverse supporting arbitrary leading batch dimensions.

    Args:
        se3: (..., 4, 4) or (..., 3, 4) tensor of SE3 matrices.
        R (optional): (..., 3, 3) rotation components of `se3`.
        T (optional): (..., 3, 1) translation components of `se3`.

    Returns:
        (..., 4, 4) tensor of inverted SE3 matrices.
    """
    leading = se3.shape[:-2]
    rot = se3[..., :3, :3] if R is None else R
    trans = se3[..., :3, 3:] if T is None else T

    # Closed form: [R|t]^-1 = [R^T | -R^T t]
    rot_t = rot.transpose(-2, -1)
    new_trans = -rot_t @ trans

    # Start from a broadcast identity and fill in the blocks.
    eye = torch.eye(4, 4, dtype=rot.dtype, device=rot.device)
    out = eye.expand(*leading, 4, 4).clone()
    out[..., :3, :3] = rot_t
    out[..., :3, 3:] = new_trans
    return out
|
||||
|
||||
|
||||
# TODO: this code can be further cleaned up
|
||||
|
||||
|
||||
def project_world_points_to_camera_points_batch(world_points, cam_extrinsics):
    """
    Transform world-frame 3D points into each camera frame.

    Args:
        world_points (torch.Tensor): 3D points of shape BxSxHxWx3.
        cam_extrinsics (torch.Tensor): Extrinsic parameters of shape BxSx3x4.

    Returns:
        torch.Tensor: Camera-frame points of shape BxSxHxWx3.
    """
    # TODO: merge this into project_world_points_to_cam

    # Homogenize: append a 1 to every point -> (B, S, H, W, 4).
    homo = torch.cat([world_points, torch.ones_like(world_points[..., :1])], dim=-1)

    # Contract each 3x4 extrinsic with every homogeneous point of its
    # frame; equivalent to the broadcasted matmul formulation.
    return torch.einsum("bsij,bshwj->bshwi", cam_extrinsics, homo)
|
||||
|
||||
|
||||
|
||||
def project_world_points_to_cam(
    world_points,
    cam_extrinsics,
    cam_intrinsics=None,
    distortion_params=None,
    default=0,
    only_points_cam=False,
):
    """
    Transform 3D world points into camera frames and project to pixels.

    Args:
        world_points (torch.Tensor): 3D points of shape Px3.
        cam_extrinsics (torch.Tensor): Extrinsic parameters of shape Bx3x4.
        cam_intrinsics (torch.Tensor): Intrinsic parameters of shape Bx3x3.
        distortion_params (torch.Tensor): Optional radial-distortion
            parameters of shape BxN.
        default: Fill value substituted for NaN pixel coordinates.
        only_points_cam (bool): If True, skip the intrinsic projection and
            return (None, camera-frame points).

    Returns:
        tuple: (pixel coordinates BxNx2, or None; camera-frame points Bx3xN).
    """
    device = world_points.device
    # Force full precision: the projection math is numerically sensitive.
    with torch.autocast(device_type=device.type, enabled=False):
        num_cams = cam_extrinsics.shape[0]

        # Homogenize the points, then broadcast one copy per camera: BxNx4.
        homo = torch.cat(
            [world_points, torch.ones_like(world_points[..., 0:1])], dim=1
        )
        homo = homo.unsqueeze(0).expand(num_cams, -1, -1)

        # Step 1: world -> camera via (B,3,4) @ (B,4,N) -> (B,3,N).
        cam_points = torch.bmm(cam_extrinsics, homo.transpose(-1, -2))

        if only_points_cam:
            return None, cam_points

        # Step 2: camera -> image plane (intrinsics + optional distortion).
        pixels = img_from_cam(cam_intrinsics, cam_points, distortion_params, default=default)

    return pixels, cam_points
|
||||
|
||||
|
||||
|
||||
def img_from_cam(cam_intrinsics, cam_points, distortion_params=None, default=0.0):
    """
    Project camera-frame 3D points to pixel coordinates.

    Args:
        cam_intrinsics (torch.Tensor): Intrinsic camera parameters of shape Bx3x3.
        cam_points (torch.Tensor): 3D points in camera coordinates of shape Bx3xN.
        distortion_params (torch.Tensor, optional): Distortion parameters of
            shape BxN, where N can be 1, 2, or 4.
        default (float, optional): Value substituted for NaN outputs.

    Returns:
        torch.Tensor: 2D points in pixel coordinates of shape BxNx2.
    """
    # Perspective divide -> normalized image-plane coordinates.
    cam_points = cam_points / cam_points[:, 2:3, :]
    xy = cam_points[:, :2, :]

    # Optional lens distortion applied in normalized coordinates.
    if distortion_params is not None:
        dx, dy = apply_distortion(distortion_params, xy[:, 0], xy[:, 1])
        xy = torch.stack([dx, dy], dim=1)

    # Back to homogeneous form, then apply the intrinsics.
    homo = torch.cat((xy, torch.ones_like(xy[:, :1, :])), dim=1)  # Bx3xN
    projected = torch.bmm(cam_intrinsics, homo)[:, :2, :]  # Bx2xN

    # The divide above can produce NaNs (e.g. zero depth); replace them.
    projected = torch.nan_to_num(projected, nan=default)

    return projected.transpose(1, 2)  # BxNx2
|
||||
|
||||
|
||||
|
||||
|
||||
def cam_from_img(pred_tracks, intrinsics, extra_params=None):
    """
    Normalize predicted 2D tracks into camera (normalized image) coordinates.

    Args:
        pred_tracks (torch.Tensor): The predicted tracks tensor of shape
            [batch_size, num_tracks, 2].
        intrinsics (torch.Tensor): The camera intrinsics tensor of shape
            [batch_size, 3, 3].
        extra_params (torch.Tensor, optional): Distortion parameters of shape
            BxN, where N can be 1, 2, or 4.

    Returns:
        torch.Tensor: Normalized tracks tensor of shape
        [batch_size, num_tracks, 2].
    """
    # Avoid a full torch.inverse(intrinsics): with zero skew the inverse
    # mapping reduces to (p - principal_point) / focal_length.
    principal_point = intrinsics[:, [0, 1], [2, 2]].unsqueeze(-2)
    focal_length = intrinsics[:, [0, 1], [0, 1]].unsqueeze(-2)
    tracks_normalized = (pred_tracks - principal_point) / focal_length

    if extra_params is not None:
        # Undo lens distortion; fall back to the cheaper single-step
        # undistortion when the iterative solver fails.
        try:
            tracks_normalized = iterative_undistortion(
                extra_params, tracks_normalized
            )
        except Exception:
            # BUGFIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt / SystemExit.
            tracks_normalized = single_undistortion(
                extra_params, tracks_normalized
            )

    return tracks_normalized
|
||||
|
||||
## Droid SLAM Part

# Points with depth below this threshold (in scene units) are treated as
# invalid by the projection helpers below.
MIN_DEPTH = 0.2

def extract_intrinsics(intrinsics):
    # Split a (..., 4) intrinsics tensor [fx, fy, cx, cy] into four
    # (..., 1, 1) tensors so they broadcast over an (..., H, W) pixel grid.
    return intrinsics[...,None,None,:].unbind(dim=-1)
|
||||
|
||||
def projective_transform(
    poses, depths, intrinsics, ii, jj, jacobian=False, return_depth=False
):
    """map points from ii->jj

    Reprojects every pixel of the source frames `ii` into the target
    frames `jj` using the current poses and per-pixel depth channel.

    Args:
        poses: lietorch group elements, indexed as poses[:, k]
            (exact convention depends on caller — verify).
        depths: per-frame depth/disparity maps; iproj builds the rays from
            them (presumably shape (batch, frame, H, W) — TODO confirm).
        intrinsics: per-frame [fx, fy, cx, cy] (see extract_intrinsics).
        ii, jj: index tensors selecting source/target frame pairs.
        jacobian: if True, also return Jacobians (Ji, Jj, Jz) w.r.t. the
            two poses and the depth.
        return_depth: forwarded to proj; appends a depth channel to the
            reprojected coordinates.

    Returns:
        (coords, valid) or (coords, valid, (Ji, Jj, Jz)) when jacobian.
    """
    # inverse project (pinhole): pixels of frames ii -> homogeneous points
    X0, Jz = iproj(depths[:, ii], intrinsics[:, ii], jacobian=jacobian)

    # relative transform taking frame-ii coordinates into frame jj
    Gij = poses[:, jj] * poses[:, ii].inv()

    # Gij.data[:, ii == jj] = torch.as_tensor(
    #     [-0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device="cuda"
    # )
    # apply the group action to the point cloud
    X1, Ja = actp(Gij, X0, jacobian=jacobian)

    # project (pinhole) into the target frames
    x1, Jp = proj(X1, intrinsics[:, jj], jacobian=jacobian, return_depth=return_depth)

    # exclude points too close to camera in either frame
    valid = ((X1[..., 2] > MIN_DEPTH) & (X0[..., 2] > MIN_DEPTH)).float()
    valid = valid.unsqueeze(-1)

    if jacobian:
        # Ji transforms according to dual adjoint
        Jj = torch.matmul(Jp, Ja)
        Ji = -Gij[:, :, None, None, None].adjT(Jj)

        # chain rule through the group action for the depth Jacobian
        Jz = Gij[:, :, None, None] * Jz
        Jz = torch.matmul(Jp, Jz.unsqueeze(-1))

        return x1, valid, (Ji, Jj, Jz)

    return x1, valid
|
||||
|
||||
|
||||
def induced_flow(poses, disps, intrinsics, ii, jj):
    """optical flow induced by camera motion

    Returns the displacement, per pixel of frames ii, between its original
    location and its reprojection into frames jj, plus the validity mask.
    """
    ht, wd = disps.shape[2:]

    # Base pixel grid: (x, y) coordinates for every location in a frame.
    rows, cols = torch.meshgrid(
        torch.arange(ht, device=disps.device, dtype=torch.float),
        torch.arange(wd, device=disps.device, dtype=torch.float),
        indexing="ij",
    )
    coords0 = torch.stack([cols, rows], dim=-1)

    # Where each pixel of frame ii lands inside frame jj.
    coords1, valid = projective_transform(poses, disps, intrinsics, ii, jj, False)

    return coords1[..., :2] - coords0, valid
|
||||
|
||||
def all_pairs_distance_matrix(poses, beta=2.5):
    """ compute distance matrix between all pairs of poses

    poses: iterable of 7-vectors; the first three entries are treated as
    translation (they get scaled by beta) and the rest as rotation —
    presumably lietorch's (t, q) layout, verify against callers.
    Requires lietorch (SE3 is None when the import failed).
    """
    poses = np.array(poses, dtype=np.float32)
    poses[:,:3] *= beta # scale to balance rot + trans
    poses = SE3(torch.from_numpy(poses))

    # Relative pose for every ordered pair, mapped to the Lie algebra;
    # the norm of that tangent vector is the pairwise distance.
    r = (poses[:,None].inv() * poses[None,:]).log()
    return r.norm(dim=-1).cpu().numpy()
|
||||
|
||||
def pose_matrix_to_quaternion(pose):
    """ convert 4x4 pose matrix to (t, q)

    Returns translation followed by an xyzw quaternion (SciPy convention).
    """
    translation = pose[..., :3, 3]
    quaternion = Rotation.from_matrix(pose[..., :3, :3]).as_quat()
    return np.concatenate([translation, quaternion], axis=-1)
|
||||
|
||||
def compute_distance_matrix_flow(poses, disps, intrinsics):
    """ compute flow magnitude between all pairs of frames

    Returns an (N, N) float32 matrix whose (i, j) entry is the mean
    bidirectional induced-flow magnitude between frames i and j; pairs with
    low covisibility (mean validity < 0.7) are set to inf.
    Requires CUDA and lietorch.
    """
    # Raw numpy inputs are moved to CUDA and wrapped as (inverted) SE3.
    if not isinstance(poses, SE3):
        poses = torch.from_numpy(poses).float().cuda()[None]
        poses = SE3(poses).inv()

        disps = torch.from_numpy(disps).float().cuda()[None]
        intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]

    N = poses.shape[1]

    # All ordered frame pairs (i, j).
    ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
    ii = ii.reshape(-1).cuda()
    jj = jj.reshape(-1).cuda()

    MAX_FLOW = 100.0  # cap so outlier pixels do not dominate the mean
    matrix = np.zeros((N, N), dtype=np.float32)

    # Evaluate the pairs in chunks of s to bound peak GPU memory.
    s = 2048
    for i in range(0, ii.shape[0], s):
        # Flow in both directions so the measure is symmetric.
        flow1, val1 = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
        flow2, val2 = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s])

        flow = torch.stack([flow1, flow2], dim=2)
        val = torch.stack([val1, val2], dim=2)

        # Mean flow magnitude over valid pixels only.
        mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
        mag = mag.view(mag.shape[1], -1)
        val = val.view(val.shape[1], -1)

        mag = (mag * val).mean(-1) / val.mean(-1)
        # Pairs that barely overlap are not comparable -> inf.
        mag[val.mean(-1) < 0.7] = np.inf

        i1 = ii[i:i+s].cpu().numpy()
        j1 = jj[i:i+s].cpu().numpy()
        matrix[i1, j1] = mag.cpu().numpy()

    return matrix
|
||||
|
||||
|
||||
def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4):
    """ compute flow magnitude between all pairs of frames

    Variant that mixes translation-only flow with full flow
    (flow_a + beta * flow_b).

    NOTE(review): the `tonly=True` calls below require an `induced_flow`
    accepting a `tonly` keyword; the `induced_flow` defined in this module
    does not, so these calls would raise TypeError — confirm which
    implementation is intended.
    NOTE(review): `flow2b`/`val2b` are computed with (ii, jj) instead of
    (jj, ii), and `val1` mixes `val1a` with `val2b`; this asymmetry looks
    unintentional — verify against the reference implementation.
    """
    # if not isinstance(poses, SE3):
    #     poses = torch.from_numpy(poses).float().cuda()[None]
    #     poses = SE3(poses).inv()

    #     disps = torch.from_numpy(disps).float().cuda()[None]
    #     intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]

    N = poses.shape[1]

    # All ordered frame pairs (i, j).
    ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
    ii = ii.reshape(-1)
    jj = jj.reshape(-1)

    MAX_FLOW = 128.0  # cap so outlier pixels do not dominate the mean
    matrix = np.zeros((N, N), dtype=np.float32)

    # Evaluate the pairs in chunks of s to bound peak memory.
    s = 2048
    for i in range(0, ii.shape[0], s):
        flow1a, val1a = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True)
        flow1b, val1b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
        flow2a, val2a = induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True)
        flow2b, val2b = induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])

        # Blend translation-only and full flow.
        flow1 = flow1a + beta * flow1b
        val1 = val1a * val2b

        flow2 = flow2a + beta * flow2b
        val2 = val2a * val2b

        flow = torch.stack([flow1, flow2], dim=2)
        val = torch.stack([val1, val2], dim=2)

        # Mean magnitude over valid pixels only.
        mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
        mag = mag.view(mag.shape[1], -1)
        val = val.view(val.shape[1], -1)

        mag = (mag * val).mean(-1) / val.mean(-1)
        # Pairs that barely overlap are not comparable -> inf.
        mag[val.mean(-1) < 0.8] = np.inf

        i1 = ii[i:i+s].cpu().numpy()
        j1 = jj[i:i+s].cpu().numpy()
        matrix[i1, j1] = mag.cpu().numpy()

    return matrix
|
||||
|
||||
def coords_grid(ht, wd, **kwargs):
    """Return an (ht, wd, 2) grid of (x, y) pixel coordinates.

    Extra keyword arguments (e.g. device) are forwarded to torch.arange.
    """
    ys = torch.arange(ht, dtype=torch.float, **kwargs)
    xs = torch.arange(wd, dtype=torch.float, **kwargs)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack([grid_x, grid_y], dim=-1)
|
||||
|
||||
|
||||
def iproj(disps, intrinsics, jacobian=False):
    """pinhole camera inverse projection

    Lifts each pixel to a homogeneous point [X, Y, 1, disp]: the normalized
    ray direction plus the pixel's disparity channel.
    """
    ht, wd = disps.shape[2:]
    fx, fy, cx, cy = extract_intrinsics(intrinsics)

    rows, cols = torch.meshgrid(
        torch.arange(ht, device=disps.device, dtype=torch.float),
        torch.arange(wd, device=disps.device, dtype=torch.float),
        indexing="ij",
    )

    ones = torch.ones_like(disps)
    pts = torch.stack([(cols - cx) / fx, (rows - cy) / fy, ones, disps], dim=-1)

    if not jacobian:
        return pts, None

    # Only the last (disparity) channel depends on the input disparity.
    J = torch.zeros_like(pts)
    J[..., -1] = 1.0
    return pts, J
|
||||
|
||||
|
||||
def proj(Xs, intrinsics, jacobian=False, return_depth=False):
    """pinhole camera projection

    Args:
        Xs: homogeneous points [X, Y, Z, D] stacked on the last axis,
            where D is the disparity channel carried along by iproj.
        intrinsics: per-frame [fx, fy, cx, cy] (see extract_intrinsics).
        jacobian: if True, also return the 2x4 projection Jacobian.
        return_depth: if True, append D/Z to the projected coordinates.

    Returns:
        (coords, jacobian-or-None)
    """
    fx, fy, cx, cy = extract_intrinsics(intrinsics)
    X, Y, Z, D = Xs.unbind(dim=-1)

    # Guard against division by (near-)zero depth; such points are marked
    # invalid by the caller anyway (see projective_transform).
    Z = torch.where(Z < 0.5 * MIN_DEPTH, torch.ones_like(Z), Z)
    d = 1.0 / Z

    # Pinhole projection: pixel = f * (point / depth) + principal point.
    x = fx * (X * d) + cx
    y = fy * (Y * d) + cy
    if return_depth:
        coords = torch.stack([x, y, D * d], dim=-1)
    else:
        coords = torch.stack([x, y], dim=-1)

    if jacobian:
        B, N, H, W = d.shape
        o = torch.zeros_like(d)
        # d(x, y)/d(X, Y, Z, D), flattened row-major into 8 entries; the
        # column w.r.t. D is zero because the 2D coordinates do not depend
        # on the disparity channel.
        proj_jac = torch.stack(
            [
                fx * d,
                o,
                -fx * X * d * d,
                o,
                o,
                fy * d,
                -fy * Y * d * d,
                o,
                # o, o, -D*d*d, d,
            ],
            dim=-1,
        ).view(B, N, H, W, 2, 4)

        return coords, proj_jac

    return coords, None
|
||||
|
||||
|
||||
def actp(Gij, X0, jacobian=False):
    """action on point cloud

    Applies the group element Gij to the homogeneous points X0 and
    optionally returns the Jacobian of the action w.r.t. the group
    parameters: 6 columns for SE3, 7 for Sim3 (extra scale column).
    """
    # lietorch group action, broadcast over the spatial dimensions.
    X1 = Gij[:, :, None, None] * X0

    if jacobian:
        X, Y, Z, d = X1.unbind(dim=-1)
        o = torch.zeros_like(d)
        B, N, H, W = d.shape

        if isinstance(Gij, SE3):
            # Rows of d(X, Y, Z, d)/d(xi) for xi in se(3), flattened into
            # 24 entries; the final (disparity) row is all zeros.
            Ja = torch.stack(
                [
                    d,
                    o,
                    o,
                    o,
                    Z,
                    -Y,
                    o,
                    d,
                    o,
                    -Z,
                    o,
                    X,
                    o,
                    o,
                    d,
                    Y,
                    -X,
                    o,
                    o,
                    o,
                    o,
                    o,
                    o,
                    o,
                ],
                dim=-1,
            ).view(B, N, H, W, 4, 6)

        elif isinstance(Gij, Sim3):
            # Same structure as the SE3 case plus a seventh (scale) column.
            Ja = torch.stack(
                [
                    d,
                    o,
                    o,
                    o,
                    Z,
                    -Y,
                    X,
                    o,
                    d,
                    o,
                    -Z,
                    o,
                    X,
                    Y,
                    o,
                    o,
                    d,
                    Y,
                    -X,
                    o,
                    Z,
                    o,
                    o,
                    o,
                    o,
                    o,
                    o,
                    o,
                ],
                dim=-1,
            ).view(B, N, H, W, 4, 7)

        return X1, Ja

    return X1, None
|
||||
|
||||
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns torch.sqrt(torch.max(0, x))
|
||||
but with a zero subgradient where x is 0.
|
||||
"""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
return ret
|
||||
|
||||
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    # Flatten each 3x3 matrix into its nine entries.
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # |q_r|, |q_i|, |q_j|, |q_k| (up to sign), derived from the traces.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Four candidate quaternions, one keyed on each possible largest
    # component (r, i, j, k).
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # Clamp the divisor away from zero so near-zero components stay finite.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # Per element, select the candidate keyed by the largest |q| component
    # (the numerically best-conditioned choice).
    out = quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))
    return standardize_quaternion(out)
|
||||
|
||||
|
||||
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Normalize quaternions to unit length and flip sign so that the real
    part (first component) is non-negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized unit quaternions as tensor of shape (..., 4).
    """
    unit = F.normalize(quaternions, p=2, dim=-1)
    # q and -q encode the same rotation; pick the non-negative-real one.
    needs_flip = unit[..., 0:1] < 0
    return torch.where(needs_flip, -unit, unit)
|
||||
|
||||
def umeyama(X, Y):
    """
    Estimate the Sim(3) transformation between point sets `X` and `Y`.

    Finds c, R and t such that c * R @ X + t ~ Y (Umeyama's least-squares
    alignment).

    Parameters
    ----------
    X : numpy.array
        (m, n) array: m-dimensional points, one point per column.
    Y : numpy.array
        (m, n) array; Y[:, i] must correspond to X[:, i].

    Returns
    -------
    c : float
        Scale factor.
    R : numpy.array
        (3, 3) shaped rotation matrix.
    t : numpy.array
        (3, 1) shaped translation vector.
    """
    # Centroids and centered point sets.
    centroid_x = X.mean(axis=1, keepdims=True)
    centroid_y = Y.mean(axis=1, keepdims=True)
    Xc = X - centroid_x
    Yc = Y - centroid_y

    # Variance of X and cross-covariance between the two sets.
    var_x = np.square(Xc).sum(axis=0).mean()
    cov_xy = (Yc @ Xc.T) / X.shape[1]

    # Optimal rotation from the SVD, with a reflection fix-up so the
    # result is a proper rotation (det = +1).
    U, D, VH = np.linalg.svd(cov_xy)
    S = np.eye(X.shape[0])
    if np.linalg.det(U) * np.linalg.det(VH) < 0:
        S[-1, -1] = -1

    c = np.trace(np.diag(D) @ S) / var_x
    R = U @ S @ VH
    t = centroid_y - c * R @ centroid_x
    return c, R, t
|
||||
246
lingbot_map/utils/load_fn.py
Normal file
246
lingbot_map/utils/load_fn.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms as TF
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_and_preprocess_images_square(image_path_list, target_size=1024):
    """
    Load and preprocess images by center padding to square and resizing to target size.
    Also returns the position information of original pixels after transformation.

    Args:
        image_path_list (list): List of paths to image files
        target_size (int, optional): Target size for both width and height. Defaults to 1024.

    Returns:
        tuple: (
            torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, target_size, target_size),
            torch.Tensor: Tensor of shape (N, 6) containing [x1, y1, x2, y2, width, height] for each image
        )

    Raises:
        ValueError: If the input list is empty
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    images = []
    original_coords = []  # per-image [x1, y1, x2, y2, width, height]
    to_tensor = TF.ToTensor()

    for image_path in image_path_list:
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background
        if img.mode == "RGBA":
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            img = Image.alpha_composite(background, img)

        # Convert to RGB
        img = img.convert("RGB")

        # Get original dimensions
        width, height = img.size

        # Make the image square by padding the shorter dimension
        max_dim = max(width, height)

        # Calculate padding needed to center the original image
        left = (max_dim - width) // 2
        top = (max_dim - height) // 2

        # Calculate scale factor for resizing
        scale = target_size / max_dim

        # Calculate final coordinates of the original image in target space
        x1 = left * scale
        y1 = top * scale
        x2 = (left + width) * scale
        y2 = (top + height) * scale

        # Store original image coordinates and size
        original_coords.append(np.array([x1, y1, x2, y2, width, height]))

        # Create a new black square image and paste original
        square_img = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
        square_img.paste(img, (left, top))

        # Resize to target size
        square_img = square_img.resize((target_size, target_size), Image.Resampling.BICUBIC)

        # Convert to tensor
        img_tensor = to_tensor(square_img)
        images.append(img_tensor)

    # Stack all images
    images = torch.stack(images)
    original_coords = torch.from_numpy(np.array(original_coords)).float()

    # Add additional dimension if single image to ensure correct shape
    if len(image_path_list) == 1:
        if images.dim() == 3:
            images = images.unsqueeze(0)
            original_coords = original_coords.unsqueeze(0)

    return images, original_coords
|
||||
|
||||
|
||||
def load_and_preprocess_images(image_path_list, fx=None, fy=None, cx=None, cy=None, mode="crop", image_size=512, patch_size=16):
    """
    A quick start function to load and preprocess images for model input.
    This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.

    Args:
        image_path_list (list): List of paths to image files.
        fx, fy, cx, cy (list or None): Optional per-image pinhole intrinsics,
            normalized by image width/height. When provided they are MUTATED
            in place: first scaled to pixel units of the original image, then
            rescaled for the resize, with cx/cy reset to the center of the
            processed tensor. The updated lists are also returned.
        mode (str, optional): Preprocessing mode, either "crop" or "pad".
            - "crop" (default): Sets width to `image_size` px and center crops height if needed.
            - "pad": Preserves all pixels by making the largest dimension `image_size` px
              and padding the smaller dimension to reach a square shape.
        image_size (int): Target size in pixels for the constrained dimension (default 512).
        patch_size (int): Output dimensions are rounded to a multiple of this
            value for patch-based models (default 16).

    Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W).
        When fx is provided, the tuple (images, fx, fy, cx, cy) is returned instead.

    Raises:
        ValueError: If the input list is empty or if mode is invalid

    Notes:
        - Images with different dimensions will be padded with white (value=1.0)
        - A warning is printed when images have different shapes
        - When mode="crop": the function ensures width=`image_size` while maintaining aspect
          ratio, and height is center-cropped if larger than `image_size`
        - When mode="pad": the function ensures the largest dimension is `image_size` while
          maintaining aspect ratio, and the smaller dimension is padded to reach a square
        - Dimensions are adjusted to be divisible by `patch_size` for compatibility with
          model requirements
    """
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # Validate mode
    if mode not in ["crop", "pad"]:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    images = []
    shapes = set()
    to_tensor = TF.ToTensor()
    target_size = image_size

    # First process all images and collect their shapes
    for i, image_path in enumerate(image_path_list):
        # Open image
        img = Image.open(image_path)

        # If there's an alpha channel, blend onto white background:
        if img.mode == "RGBA":
            # Create white background
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            # Alpha composite onto the white background
            img = Image.alpha_composite(background, img)

        # Now convert to "RGB" (this step assigns white for transparent areas)
        img = img.convert("RGB")

        width, height = img.size

        if fx is not None:
            # Convert normalized intrinsics to pixel units of the ORIGINAL image.
            fx[i] = fx[i] * width
            fy[i] = fy[i] * height
            cx[i] = cx[i] * width
            cy[i] = cy[i] * height

        if mode == "pad":
            # Make the largest dimension `target_size` while maintaining aspect ratio
            if width >= height:
                new_width = target_size
                new_height = round(height * (new_width / width) / patch_size) * patch_size  # Make divisible by patch_size
            else:
                new_height = target_size
                new_width = round(width * (new_height / height) / patch_size) * patch_size  # Make divisible by patch_size

        else:  # mode == "crop"
            # Original behavior: fix the width to `target_size`
            new_width = target_size
            # Calculate height maintaining aspect ratio, divisible by patch_size
            new_height = round(height * (new_width / width) / patch_size) * patch_size

        # Resize with new dimensions (width, height)
        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)

        # Center crop height if it's larger than target_size (only in crop mode)
        if mode == "crop" and new_height > target_size:
            start_y = (new_height - target_size) // 2
            img = img[:, start_y : start_y + target_size, :]
        if fx is not None:
            # Rescale focal lengths for the resize; principal point is reset to
            # the center of the (possibly cropped) tensor.
            # NOTE(review): in "pad" mode this runs BEFORE padding is applied, so
            # cx/cy describe the pre-padding center — confirm this is intended.
            fx[i] = fx[i] * new_width / width
            fy[i] = fy[i] * new_height / height

            cx[i] = img.shape[2] / 2
            cy[i] = img.shape[1] / 2

        # For pad mode, pad to make a square of target_size x target_size
        if mode == "pad":
            h_padding = target_size - img.shape[1]
            w_padding = target_size - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                # Pad with white (value=1.0)
                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes
    # In theory our model can also work well with different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
                )
            padded_images.append(img)
        images = padded_images

    images = torch.stack(images)  # concatenate images

    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)
    if fx is not None:
        return images, fx, fy, cx, cy
    return images
|
||||
331
lingbot_map/utils/pose_enc.py
Normal file
331
lingbot_map/utils/pose_enc.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from .rotation import quat_to_mat, mat_to_quat
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import gzip
|
||||
import json
|
||||
import random
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
from lingbot_map.utils.geometry import closed_form_inverse_se3, closed_form_inverse_se3_general
|
||||
|
||||
|
||||
def extri_intri_to_pose_encoding(
    extrinsics, intrinsics, image_size_hw=None, pose_encoding_type="absT_quaR_FoV"  # e.g., (256, 512)
):
    """Convert camera extrinsics and intrinsics to a compact pose encoding.

    This function transforms camera parameters into a unified pose encoding format,
    which can be used for various downstream tasks like pose prediction or representation.

    Args:
        extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
            where B is batch size and S is sequence length.
            In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
            The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
        intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
            Defined in pixels, with format:
            [[fx, 0, cx],
             [0, fy, cy],
             [0, 0, 1]]
            where fx, fy are focal lengths and (cx, cy) is the principal point.
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
            Required for computing field of view values. For example: (256, 512).
        pose_encoding_type (str): Type of pose encoding to use. Currently only
            supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).

    Returns:
        torch.Tensor: Encoded camera pose parameters with shape BxSx9.
        For "absT_quaR_FoV" type, the 9 dimensions are:
        - [:3] = absolute translation vector T (3D)
        - [3:7] = rotation as quaternion quat (4D)
        - [7:] = field of view (2D)

    Raises:
        ValueError: If image_size_hw is None (it is needed to derive the FoV).
        NotImplementedError: If pose_encoding_type is not supported.
    """
    if pose_encoding_type == "absT_quaR_FoV":
        # Fail early with a clear message instead of "cannot unpack None" below.
        if image_size_hw is None:
            raise ValueError("image_size_hw is required for pose_encoding_type 'absT_quaR_FoV'")

        R = extrinsics[:, :, :3, :3]  # BxSx3x3
        T = extrinsics[:, :, :3, 3]  # BxSx3

        quat = mat_to_quat(R)
        # Note the order of h and w here
        H, W = image_size_hw
        # FoV from focal length: fov = 2 * atan((size / 2) / f)
        fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
        fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
        pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
    else:
        raise NotImplementedError(f"Unsupported pose_encoding_type: {pose_encoding_type}")

    return pose_encoding
|
||||
|
||||
|
||||
def pose_encoding_to_extri_intri(
    pose_encoding, image_size_hw=None, pose_encoding_type="absT_quaR_FoV", build_intrinsics=True  # e.g., (256, 512)
):
    """Convert a pose encoding back to camera extrinsics and intrinsics.

    This function performs the inverse operation of extri_intri_to_pose_encoding,
    reconstructing the full camera parameters from the compact encoding.

    Args:
        pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9
            (BxSx7 suffices for "absT_quaR"), where B is batch size and S is sequence length.
            For "absT_quaR_FoV" type, the 9 dimensions are:
            - [:3] = absolute translation vector T (3D)
            - [3:7] = rotation as quaternion quat (4D)
            - [7:] = field of view (2D)
        image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
            Required for reconstructing intrinsics from field of view values.
            For example: (256, 512).
        pose_encoding_type (str): Type of pose encoding used. Supports
            "absT_quaR_FoV" (translation, quaternion, field of view) and
            "absT_quaR" (translation and quaternion only; intrinsics are None).
        build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
            If False, only extrinsics are returned and intrinsics will be None.

    Returns:
        tuple: (extrinsics, intrinsics)
        - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
          In OpenCV coordinate system (x-right, y-down, z-forward), representing camera
          from world transformation. The format is [R|t] where R is a 3x3 rotation
          matrix and t is a 3x1 translation vector.
        - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape
          BxSx3x3, or None when not built. The principal point is assumed to be at
          the center of the image (W/2, H/2).

    Raises:
        NotImplementedError: If pose_encoding_type is not supported.
    """
    intrinsics = None

    if pose_encoding_type == "absT_quaR_FoV":
        T = pose_encoding[..., :3]
        quat = pose_encoding[..., 3:7]
        fov_h = pose_encoding[..., 7]
        fov_w = pose_encoding[..., 8]

        R = quat_to_mat(quat)
        extrinsics = torch.cat([R, T[..., None]], dim=-1)

        if build_intrinsics:
            H, W = image_size_hw
            # Invert fov = 2 * atan((size / 2) / f)  =>  f = (size / 2) / tan(fov / 2)
            fy = (H / 2.0) / torch.tan(fov_h / 2.0)
            fx = (W / 2.0) / torch.tan(fov_w / 2.0)
            # Match the encoding's dtype so float64 encodings round-trip losslessly.
            intrinsics = torch.zeros(
                pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype
            )
            intrinsics[..., 0, 0] = fx
            intrinsics[..., 1, 1] = fy
            intrinsics[..., 0, 2] = W / 2
            intrinsics[..., 1, 2] = H / 2
            intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1
    elif pose_encoding_type == "absT_quaR":
        T = pose_encoding[..., :3]
        quat = pose_encoding[..., 3:7]

        R = quat_to_mat(quat)
        extrinsics = torch.cat([R, T[..., None]], dim=-1)

        intrinsics = None
    else:
        # Previously an unknown type fell through to `return extrinsics` with
        # `extrinsics` unbound, raising a confusing NameError.
        raise NotImplementedError(f"Unsupported pose_encoding_type: {pose_encoding_type}")

    return extrinsics, intrinsics
|
||||
|
||||
def convert_pt3d_RT_to_opencv(Rot, Trans):
    """
    Convert Point3D extrinsic matrices to OpenCV convention.

    Args:
        Rot: 3D rotation matrix in Point3D format
        Trans: 3D translation vector in Point3D format

    Returns:
        extri_opencv: 3x4 extrinsic matrix in OpenCV format
    """
    # Copy the inputs so the caller's data is never mutated.
    rotation = np.array(Rot)
    translation = np.array(Trans)

    # Flip the x/y axes to move from the PyTorch3D to the OpenCV frame.
    translation[:2] *= -1
    rotation[:, :2] *= -1

    # PyTorch3D stores the rotation transposed relative to OpenCV.
    rotation_cv = rotation.transpose(1, 0)

    return np.hstack((rotation_cv, translation[:, None]))
|
||||
|
||||
|
||||
def build_pair_index(N, B=1):
    """
    Build indices for all possible pairs of frames.

    Args:
        N: Number of frames
        B: Batch size

    Returns:
        i1, i2: Indices for all possible pairs
    """
    # All unordered pairs (a, b) with a < b within one batch element.
    pair_a, pair_b = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)

    # Shift the per-batch indices into the flattened (B * N) index space.
    batch_offsets = torch.arange(B)[:, None] * N
    i1 = (pair_a[None] + batch_offsets).reshape(-1)
    i2 = (pair_b[None] + batch_offsets).reshape(-1)
    return i1, i2
|
||||
|
||||
|
||||
def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15):
    """
    Calculate rotation angle error between ground truth and predicted rotations.

    Args:
        rot_gt: Ground truth rotation matrices
        rot_pred: Predicted rotation matrices
        batch_size: Batch size for reshaping the result
        eps: Small value to avoid numerical issues

    Returns:
        Rotation angle error in degrees
    """
    # Compare in quaternion space: |<q_pred, q_gt>| == 1 iff the rotations agree.
    quat_pred = mat_to_quat(rot_pred)
    quat_gt = mat_to_quat(rot_gt)

    dot_sq = (quat_pred * quat_gt).sum(dim=1) ** 2
    residual = (1 - dot_sq).clamp(min=eps)
    angle_rad = torch.arccos(1 - 2 * residual)

    angle_deg = angle_rad * 180 / np.pi
    if batch_size is not None:
        angle_deg = angle_deg.reshape(batch_size, -1)

    return angle_deg
|
||||
|
||||
|
||||
def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True):
    """
    Calculate translation angle error between ground truth and predicted translations.

    Args:
        tvec_gt: Ground truth translation vectors
        tvec_pred: Predicted translation vectors
        batch_size: Batch size for reshaping the result
        ambiguity: Whether to handle direction ambiguity

    Returns:
        Translation angle error in degrees
    """
    angles_deg = compare_translation_by_angle(tvec_gt, tvec_pred) * 180.0 / np.pi

    if ambiguity:
        # A direction and its opposite count as the same translation axis.
        angles_deg = torch.min(angles_deg, (180 - angles_deg).abs())

    if batch_size is not None:
        angles_deg = angles_deg.reshape(batch_size, -1)

    return angles_deg
|
||||
|
||||
|
||||
def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6):
    """
    Normalize the translation vectors and compute the angle between them.

    Args:
        t_gt: Ground truth translation vectors
        t: Predicted translation vectors
        eps: Small value to avoid division by zero
        default_err: Default error value for invalid cases

    Returns:
        Angular error between translation vectors in radians
    """
    # Reduce both sets of vectors to (near-)unit length.
    unit_pred = t / (torch.norm(t, dim=1, keepdim=True) + eps)
    unit_gt = t_gt / (torch.norm(t_gt, dim=1, keepdim=True) + eps)

    cos_sq = torch.sum(unit_pred * unit_gt, dim=1) ** 2
    residual = torch.clamp_min(1.0 - cos_sq, eps)
    angle = torch.acos(torch.sqrt(1 - residual))

    # Any numerically invalid entry falls back to a large sentinel error.
    invalid = torch.isnan(angle) | torch.isinf(angle)
    angle[invalid] = default_err
    return angle
|
||||
|
||||
|
||||
def calculate_auc_np(r_error, t_error, max_threshold=30):
    """
    Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy.

    Args:
        r_error: numpy array representing R error values (Degree)
        t_error: numpy array representing T error values (Degree)
        max_threshold: Maximum threshold value for binning the histogram

    Returns:
        AUC value and the normalized histogram
    """
    # Each pair is scored by its worse (larger) error of the two.
    worst_errors = np.maximum(r_error, t_error)

    # One-degree-wide bins from 0 to max_threshold.
    bin_edges = np.arange(max_threshold + 1)
    counts, _ = np.histogram(worst_errors, bins=bin_edges)

    pair_count = float(len(worst_errors))
    normalized_histogram = counts.astype(float) / pair_count

    # Mean of the cumulative curve == area under the recall-vs-threshold curve.
    return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
|
||||
|
||||
|
||||
def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames):
    """
    Compute rotation and translation errors between predicted and ground truth poses.
    This function assumes the input poses are world-to-camera (w2c) transformations.

    Args:
        pred_se3: Predicted SE(3) transformations (w2c), shape (N, 4, 4)
        gt_se3: Ground truth SE(3) transformations (w2c), shape (N, 4, 4)
        num_frames: Number of frames (N)

    Returns:
        Rotation and translation angle errors in degrees
    """
    idx_a, idx_b = build_pair_index(num_frames)

    # Relative transform for every frame pair: T_a * inv(T_b).
    rel_gt = gt_se3[idx_a].bmm(closed_form_inverse_se3(gt_se3[idx_b]))
    rel_pred = pred_se3[idx_a].bmm(closed_form_inverse_se3(pred_se3[idx_b]))

    rot_err_deg = rotation_angle(rel_gt[:, :3, :3], rel_pred[:, :3, :3])
    trans_err_deg = translation_angle(rel_gt[:, :3, 3], rel_pred[:, :3, 3])

    return rot_err_deg, trans_err_deg
|
||||
|
||||
|
||||
def colmap_to_opencv_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    # Work on a copy: the caller's matrix must stay untouched.
    shifted = K.copy()
    # Shift the principal point by half a pixel in both axes.
    shifted[..., 0, 2] -= 0.5
    shifted[..., 1, 2] -= 0.5
    return shifted
|
||||
|
||||
def read_camera_parameters(filename):
    """Read an MVS-style camera file and return (intrinsics, extrinsics).

    The expected layout is: a header line, the 4x4 extrinsic matrix on
    lines [1, 5), then (after a separator) the 3x3 intrinsic matrix on
    lines [7, 10).
    """
    with open(filename) as f:
        lines = [raw.rstrip() for raw in f.readlines()]

    # extrinsics: line [1,5), 4x4 matrix
    extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
    # intrinsics: line [7-10), 3x3 matrix
    intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))

    return intrinsics, extrinsics
|
||||
132
lingbot_map/utils/rotation.py
Normal file
132
lingbot_map/utils/rotation.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Modified from PyTorch3D, https://github.com/facebookresearch/pytorch3d
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert scalar-last (XYZW / ijkr) quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    x, y, z, w = torch.unbind(quaternions, -1)
    # 2 / |q|^2 — normalizing here means non-unit quaternions still yield
    # proper rotation matrices.
    s = 2.0 / (quaternions * quaternions).sum(-1)

    entries = (
        1 - s * (y * y + z * z),
        s * (x * y - z * w),
        s * (x * z + y * w),
        s * (x * y + z * w),
        1 - s * (x * x + z * z),
        s * (y * z - x * w),
        s * (x * z - y * w),
        s * (y * z + x * w),
        1 - s * (x * x + y * y),
    )
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
|
||||
|
||||
|
||||
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Numerically stable variant (from PyTorch3D): four candidate quaternions
    are formed, one per component, and the best-conditioned candidate — the
    one whose denominator is largest — is selected per element.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    # Unpack the nine matrix entries as separate batched scalars.
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape(batch_dim + (9,)), dim=-1)

    # |r|, |i|, |j|, |k| up to scale, via the trace-based formulas.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [1.0 + m00 + m11 + m22, 1.0 + m00 - m11 - m22, 1.0 - m00 + m11 - m22, 1.0 - m00 - m11 + m22], dim=-1
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # Row n is the quaternion scaled by component n (r, i, j, k).
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4,))

    # Convert from rijk to ijkr
    out = out[..., [1, 2, 3, 0]]

    # Canonicalize the sign so the scalar part is non-negative.
    out = standardize_quaternion(out)

    return out
|
||||
|
||||
|
||||
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns torch.sqrt(torch.max(0, x))
|
||||
but with a zero subgradient where x is 0.
|
||||
"""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
if torch.is_grad_enabled():
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
else:
|
||||
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
||||
return ret
|
||||
|
||||
|
||||
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # q and -q encode the same rotation; flip whenever w (last slot) is negative.
    negative_real = quaternions[..., 3:4] < 0
    return torch.where(negative_real, -quaternions, quaternions)
|
||||
59
lingbot_map/vis/__init__.py
Normal file
59
lingbot_map/vis/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
GCT Visualization Module
|
||||
|
||||
This module provides visualization utilities for 3D reconstruction results:
|
||||
- PointCloudViewer: Interactive point cloud viewer with camera visualization
|
||||
- viser_wrapper: Quick visualization wrapper for predictions
|
||||
- predictions_to_glb: Export predictions to GLB 3D format
|
||||
- Colorization and utility functions
|
||||
|
||||
Usage:
|
||||
from lingbot_map.vis import PointCloudViewer, viser_wrapper, predictions_to_glb
|
||||
|
||||
# Interactive visualization
|
||||
viewer = PointCloudViewer(pred_dict=predictions, port=8080)
|
||||
viewer.run()
|
||||
|
||||
# Quick visualization
|
||||
viser_wrapper(predictions, port=8080)
|
||||
|
||||
# Export to GLB
|
||||
scene = predictions_to_glb(predictions)
|
||||
scene.export("output.glb")
|
||||
"""
|
||||
|
||||
from lingbot_map.vis.point_cloud_viewer import PointCloudViewer
|
||||
from lingbot_map.vis.viser_wrapper import viser_wrapper
|
||||
from lingbot_map.vis.utils import CameraState, colorize, colorize_np, get_vertical_colorbar
|
||||
from lingbot_map.vis.sky_segmentation import (
|
||||
apply_sky_segmentation,
|
||||
download_skyseg_model,
|
||||
load_or_create_sky_masks,
|
||||
segment_sky,
|
||||
)
|
||||
from lingbot_map.vis.glb_export import predictions_to_glb
|
||||
|
||||
__all__ = [
|
||||
# Main viewer
|
||||
"PointCloudViewer",
|
||||
# Quick visualization
|
||||
"viser_wrapper",
|
||||
# GLB export
|
||||
"predictions_to_glb",
|
||||
# Utilities
|
||||
"CameraState",
|
||||
"colorize",
|
||||
"colorize_np",
|
||||
"get_vertical_colorbar",
|
||||
# Sky segmentation
|
||||
"apply_sky_segmentation",
|
||||
"segment_sky",
|
||||
"download_skyseg_model",
|
||||
"load_or_create_sky_masks",
|
||||
]
|
||||
509
lingbot_map/vis/glb_export.py
Normal file
509
lingbot_map/vis/glb_export.py
Normal file
@@ -0,0 +1,509 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
GLB 3D export utilities for GCT predictions.
|
||||
"""
|
||||
|
||||
import os
|
||||
import copy
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import matplotlib
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from lingbot_map.vis.sky_segmentation import (
|
||||
_SKYSEG_INPUT_SIZE,
|
||||
_SKYSEG_SOFT_THRESHOLD,
|
||||
_mask_to_float,
|
||||
_mask_to_uint8,
|
||||
_result_map_to_non_sky_conf,
|
||||
)
|
||||
|
||||
try:
|
||||
import trimesh
|
||||
except ImportError:
|
||||
trimesh = None
|
||||
print("trimesh not found. GLB export will not work.")
|
||||
|
||||
|
||||
def predictions_to_glb(
    predictions: dict,
    conf_thres: float = 50.0,
    filter_by_frames: str = "all",
    mask_black_bg: bool = False,
    mask_white_bg: bool = False,
    show_cam: bool = True,
    mask_sky: bool = False,
    target_dir: Optional[str] = None,
    prediction_mode: str = "Predicted Pointmap",
) -> "trimesh.Scene":
    """
    Converts GCT predictions to a 3D scene represented as a GLB file.

    Args:
        predictions: Dictionary containing model predictions with keys:
            - world_points: 3D point coordinates (S, H, W, 3)
            - world_points_conf: Confidence scores (S, H, W)
            - images: Input images (S, H, W, 3) or (S, 3, H, W)
            - extrinsic: Camera extrinsic matrices (S, 3, 4)
        conf_thres: Percentage of low-confidence points to filter out.
            Interpreted as a PERCENTILE (0-100), not an absolute confidence.
        filter_by_frames: Frame filter specification ("all" or frame index,
            optionally as "idx: label" — only the part before ':' is parsed)
        mask_black_bg: Mask out black background pixels
        mask_white_bg: Mask out white background pixels
        show_cam: Include camera visualization
        mask_sky: Apply sky segmentation mask (requires target_dir)
        target_dir: Output directory for intermediate files
        prediction_mode: "Predicted Pointmap" or "Predicted Depthmap"

    Returns:
        trimesh.Scene: Processed 3D scene containing point cloud and cameras

    Raises:
        ValueError: If input predictions structure is invalid
        ImportError: If trimesh is not available
    """
    if trimesh is None:
        raise ImportError("trimesh is required for GLB export. Install with: pip install trimesh")

    if not isinstance(predictions, dict):
        raise ValueError("predictions must be a dictionary")

    # An explicit None falls back to a mild 10th-percentile filter.
    if conf_thres is None:
        conf_thres = 10.0

    print("Building GLB scene")

    # Parse frame filter; anything unparsable silently means "all frames".
    selected_frame_idx = None
    if filter_by_frames != "all" and filter_by_frames != "All":
        try:
            selected_frame_idx = int(filter_by_frames.split(":")[0])
        except (ValueError, IndexError):
            pass

    # Select prediction source
    if "Pointmap" in prediction_mode:
        print("Using Pointmap Branch")
        if "world_points" in predictions:
            pred_world_points = predictions["world_points"]
            pred_world_points_conf = predictions.get(
                "world_points_conf", np.ones_like(pred_world_points[..., 0])
            )
        else:
            print("Warning: world_points not found, falling back to depth-based points")
            pred_world_points = predictions["world_points_from_depth"]
            pred_world_points_conf = predictions.get(
                "depth_conf", np.ones_like(pred_world_points[..., 0])
            )
    else:
        print("Using Depthmap and Camera Branch")
        pred_world_points = predictions["world_points_from_depth"]
        pred_world_points_conf = predictions.get(
            "depth_conf", np.ones_like(pred_world_points[..., 0])
        )

    images = predictions["images"]
    camera_matrices = predictions["extrinsic"]

    # Apply sky segmentation if enabled
    if mask_sky and target_dir is not None:
        pred_world_points_conf = _apply_sky_mask(
            pred_world_points_conf, target_dir, images
        )

    # Apply frame filter (keep a leading axis of size 1 via [None])
    if selected_frame_idx is not None:
        pred_world_points = pred_world_points[selected_frame_idx][None]
        pred_world_points_conf = pred_world_points_conf[selected_frame_idx][None]
        images = images[selected_frame_idx][None]
        camera_matrices = camera_matrices[selected_frame_idx][None]

    # Prepare vertices and colors
    vertices_3d = pred_world_points.reshape(-1, 3)

    # Handle different image formats
    if images.ndim == 4 and images.shape[1] == 3:  # NCHW format
        colors_rgb = np.transpose(images, (0, 2, 3, 1))
    else:
        colors_rgb = images
    # assumes image values are in [0, 1] — TODO confirm against the loader
    colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)

    # Apply confidence filtering: drop the bottom `conf_thres` percent,
    # plus anything with effectively-zero confidence.
    conf = pred_world_points_conf.reshape(-1)
    conf_threshold = np.percentile(conf, conf_thres) if conf_thres > 0 else 0.0
    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)

    # Apply background masking
    if mask_black_bg:
        black_bg_mask = colors_rgb.sum(axis=1) >= 16
        conf_mask = conf_mask & black_bg_mask

    if mask_white_bg:
        white_bg_mask = ~(
            (colors_rgb[:, 0] > 240) &
            (colors_rgb[:, 1] > 240) &
            (colors_rgb[:, 2] > 240)
        )
        conf_mask = conf_mask & white_bg_mask

    vertices_3d = vertices_3d[conf_mask]
    colors_rgb = colors_rgb[conf_mask]

    # Handle empty point cloud: substitute a single white placeholder point.
    if vertices_3d is None or np.asarray(vertices_3d).size == 0:
        vertices_3d = np.array([[1, 0, 0]])
        colors_rgb = np.array([[255, 255, 255]])
        scene_scale = 1
    else:
        # Robust scene extent: span between the 5th and 95th percentile corners.
        lower_percentile = np.percentile(vertices_3d, 5, axis=0)
        upper_percentile = np.percentile(vertices_3d, 95, axis=0)
        scene_scale = np.linalg.norm(upper_percentile - lower_percentile)

    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")

    # Build scene
    scene_3d = trimesh.Scene()
    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
    scene_3d.add_geometry(point_cloud_data)

    # Prepare camera matrices: lift the 3x4 [R|t] to homogeneous 4x4.
    num_cameras = len(camera_matrices)
    extrinsics_matrices = np.zeros((num_cameras, 4, 4))
    extrinsics_matrices[:, :3, :4] = camera_matrices
    extrinsics_matrices[:, 3, 3] = 1

    # Add cameras, each tinted with a distinct rainbow color.
    if show_cam:
        for i in range(num_cameras):
            world_to_camera = extrinsics_matrices[i]
            camera_to_world = np.linalg.inv(world_to_camera)
            rgba_color = colormap(i / num_cameras)
            current_color = tuple(int(255 * x) for x in rgba_color[:3])
            integrate_camera_into_scene(scene_3d, camera_to_world, current_color, scene_scale)

    # Align scene
    scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)

    print("GLB Scene built")
    return scene_3d
|
||||
|
||||
|
||||
def _apply_sky_mask(
    conf: np.ndarray,
    target_dir: str,
    images: np.ndarray
) -> np.ndarray:
    """Apply sky segmentation mask to confidence scores.

    Loads (or generates and caches) one sky mask per frame under
    ``target_dir/sky_masks`` and zeroes the confidence of pixels whose
    non-sky score does not exceed the soft threshold.

    Args:
        conf: Confidence scores, expected shape (S, H, W).
        target_dir: Directory that contains an ``images`` subfolder.
        images: Image array used only as a shape fallback when ``conf``
            has no ``shape`` attribute.

    Returns:
        Confidence scores with sky pixels zeroed, or ``conf`` unchanged
        when prerequisites (onnxruntime, images directory) are missing.
    """
    # Sky masking is best-effort; bail out gracefully without onnxruntime.
    try:
        import onnxruntime
    except ImportError:
        print("Warning: onnxruntime not available, skipping sky masking")
        return conf

    target_dir_images = os.path.join(target_dir, "images")
    if not os.path.exists(target_dir_images):
        print(f"Warning: Images directory not found at {target_dir_images}")
        return conf

    image_list = sorted(os.listdir(target_dir_images))
    # NOTE(review): the fallback assumes images is (S, H, W, ...) — confirm callers.
    S, H, W = conf.shape if hasattr(conf, "shape") else (len(images), images.shape[1], images.shape[2])

    # Fetch the segmentation model on first use.
    skyseg_model_path = "skyseg.onnx"
    if not os.path.exists(skyseg_model_path):
        print("Downloading skyseg.onnx...")
        download_file_from_url(
            "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx",
            skyseg_model_path
        )

    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_mask_list = []

    for i, image_name in enumerate(image_list[:S]):
        image_filepath = os.path.join(target_dir_images, image_name)
        mask_filepath = os.path.join(target_dir, "sky_masks", image_name)

        # Reuse a cached mask when present; otherwise segment and cache it.
        if os.path.exists(mask_filepath):
            sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
        else:
            sky_mask = segment_sky(image_filepath, skyseg_session, mask_filepath)

        # Match the confidence-map resolution; cv2 sizes are (W, H).
        if sky_mask.shape[0] != H or sky_mask.shape[1] != W:
            sky_mask = cv2.resize(sky_mask, (W, H), interpolation=cv2.INTER_LINEAR)

        sky_mask_list.append(_mask_to_float(sky_mask))

    sky_mask_array = np.array(sky_mask_list)
    # Binarize: keep pixels whose non-sky confidence exceeds the threshold.
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    return conf * sky_mask_binary
|
||||
|
||||
|
||||
def integrate_camera_into_scene(
    scene: "trimesh.Scene",
    transform: np.ndarray,
    face_colors: Tuple[int, int, int],
    scene_scale: float,
    frustum_thickness: float = 1.0,
):
    """
    Integrates a camera mesh into the 3D scene.

    The camera is drawn as a 4-sided cone (a frustum wireframe look); the
    frustum is thickened by stacking several slightly rotated/scaled copies
    of the cone's vertices and stitching them into one mesh.

    Args:
        scene: The 3D scene to add the camera model
        transform: Transformation matrix for camera positioning
        face_colors: RGB color tuple for the camera
        scene_scale: Scale of the scene
        frustum_thickness: Multiplier for frustum edge thickness (>1 = thicker)
    """
    # Camera size is proportional to the overall scene extent.
    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1

    # Rotate the 4-sided cone by 45 degrees so its edges align like a
    # camera frustum, and push it back by its height along -Z.
    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
    rot_45_degree[2, 3] = -cam_height

    opengl_transform = get_opengl_conversion_matrix()
    complete_transform = transform @ opengl_transform @ rot_45_degree
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)

    # Build thicker frustum by stacking rotated copies
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()

    # Base pair of shells: the cone itself plus a slightly rotated, slightly
    # smaller copy.
    shell_scales = [1.0, 0.95]
    shell_transforms = [np.eye(4), slight_rotation]
    # Add extra shells for thickness
    if frustum_thickness > 1.0:
        n_extra = max(1, int(frustum_thickness - 1))
        for k in range(1, n_extra + 1):
            # Progressively rotated and scaled copies
            angle = 2.0 + k * 2.0
            scale = 1.0 + k * 0.02
            rot = np.eye(4)
            rot[:3, :3] = Rotation.from_euler("z", angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot)
            # Mirror copy rotated the opposite way, same scale.
            rot_neg = np.eye(4)
            rot_neg[:3, :3] = Rotation.from_euler("z", -angle, degrees=True).as_matrix()
            shell_scales.append(scale)
            shell_transforms.append(rot_neg)

    # Concatenate all shell vertex sets, then move the whole stack into place.
    vertices_parts = []
    for s, t_mat in zip(shell_scales, shell_transforms):
        vertices_parts.append(
            transform_points(t_mat, s * camera_cone_shape.vertices)
        )
    vertices_combined = np.concatenate(vertices_parts)
    vertices_transformed = transform_points(complete_transform, vertices_combined)

    # Stitch consecutive shells into faces and color the resulting mesh.
    mesh_faces = compute_camera_faces_multi(camera_cone_shape, len(shell_scales))
    camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)
|
||||
|
||||
|
||||
def apply_scene_alignment(
    scene_3d: "trimesh.Scene",
    extrinsics_matrices: np.ndarray
) -> "trimesh.Scene":
    """
    Aligns the 3D scene based on the extrinsics of the first camera.

    Args:
        scene_3d: The 3D scene to be aligned
        extrinsics_matrices: Camera extrinsic matrices

    Returns:
        Aligned 3D scene
    """
    # Camera-to-world pose of the reference (first) camera.
    first_cam_pose = np.linalg.inv(extrinsics_matrices[0])

    # 180-degree yaw so the scene faces the viewer under OpenGL conventions.
    yaw_180 = np.eye(4)
    yaw_180[:3, :3] = Rotation.from_euler("y", 180, degrees=True).as_matrix()

    world_alignment = first_cam_pose @ get_opengl_conversion_matrix() @ yaw_180
    scene_3d.apply_transform(world_alignment)
    return scene_3d
|
||||
|
||||
|
||||
def get_opengl_conversion_matrix() -> np.ndarray:
    """Returns the OpenGL conversion matrix (flips Y and Z axes)."""
    # Homogeneous 4x4 scale: X and W unchanged, Y and Z negated.
    return np.diag([1.0, -1.0, -1.0, 1.0])
|
||||
|
||||
|
||||
def transform_points(
    transformation: np.ndarray,
    points: np.ndarray,
    dim: Optional[int] = None
) -> np.ndarray:
    """
    Applies a 4x4 transformation to a set of points.

    Args:
        transformation: Transformation matrix
        points: Points to be transformed
        dim: Dimension for reshaping the result

    Returns:
        Transformed points
    """
    pts = np.asarray(points)
    leading_shape = pts.shape[:-1]
    out_dim = dim if dim else pts.shape[-1]

    # Transpose so points can be right-multiplied: the last row of the
    # transposed matrix is the translation, the rest the linear part.
    t_mat = transformation.swapaxes(-1, -2)
    linear_part = t_mat[..., :-1, :]
    offset_part = t_mat[..., -1:, :]
    transformed = pts @ linear_part + offset_part

    return transformed[..., :out_dim].reshape(*leading_shape, out_dim)
|
||||
|
||||
|
||||
def compute_camera_faces(cone_shape: "trimesh.Trimesh") -> np.ndarray:
    """Computes the faces for the camera mesh.

    Faces touching the cone apex (vertex 0) are skipped; each remaining
    base face is stitched to its copies offset by one and by two full
    vertex counts. Reversed duplicates are appended so the mesh is
    visible from both sides.
    """
    nv = len(cone_shape.vertices)
    stitched = []

    for a, b, c in cone_shape.faces:
        if 0 in (a, b, c):
            continue
        a1, b1, c1 = a + nv, b + nv, c + nv
        a2, b2, c2 = a + 2 * nv, b + 2 * nv, c + 2 * nv
        stitched += [
            (a, b, b1),
            (a, a1, c),
            (c1, b, c),
            (a, b, b2),
            (a, a2, c),
            (c2, b, c),
        ]

    # Double-sided: append each triangle with reversed winding.
    stitched += [(c, b, a) for a, b, c in stitched]
    return np.array(stitched)
|
||||
|
||||
|
||||
def compute_camera_faces_multi(cone_shape: "trimesh.Trimesh", num_shells: int) -> np.ndarray:
    """Computes faces for a camera mesh with multiple shells (for thicker frustums).

    Each consecutive pair of vertex shells is stitched together; faces
    touching the cone apex (vertex 0) are skipped, and reversed duplicates
    are appended so the mesh renders from both sides.
    """
    nv = len(cone_shape.vertices)
    stitched = []

    for shell_idx in range(num_shells - 1):
        base = shell_idx * nv
        upper = (shell_idx + 1) * nv
        for a, b, c in cone_shape.faces:
            if 0 in (a, b, c):
                continue
            stitched += [
                (a + base, b + base, b + upper),
                (a + base, a + upper, c + base),
                (c + upper, b + base, c + base),
            ]

    # Double-sided: append each triangle with reversed winding.
    stitched += [(c, b, a) for a, b, c in stitched]
    return np.array(stitched)
|
||||
|
||||
|
||||
def segment_sky(
    image_path: str,
    onnx_session,
    mask_filename: str
) -> np.ndarray:
    """
    Segments sky from an image using an ONNX model.

    Args:
        image_path: Path to input image
        onnx_session: ONNX runtime session with loaded model
        mask_filename: Path to save the output mask

    Returns:
        Continuous non-sky confidence map in [0, 1]

    Raises:
        ValueError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None on failure; fail loudly instead of crashing
    # later inside run_skyseg with a confusing error.
    if image is None:
        raise ValueError(f"Failed to read image: {image_path}")

    result_map = run_skyseg(onnx_session, _SKYSEG_INPUT_SIZE, image)
    # Upsample the model-resolution score map back to the image resolution.
    result_map_original = cv2.resize(
        result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR
    )
    output_mask = _result_map_to_non_sky_conf(result_map_original)

    # A bare filename has no directory component; os.makedirs("") would raise.
    mask_dir = os.path.dirname(mask_filename)
    if mask_dir:
        os.makedirs(mask_dir, exist_ok=True)
    cv2.imwrite(mask_filename, _mask_to_uint8(output_mask))
    return output_mask
|
||||
|
||||
|
||||
def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray
) -> np.ndarray:
    """
    Runs sky segmentation inference using ONNX model.

    Args:
        onnx_session: ONNX runtime session
        input_size: Target size for model input (width, height)
        image: Input image in BGR format

    Returns:
        Segmentation mask as uint8 in [0, 255] (min-max normalized,
        higher = more sky)
    """
    # cv2.resize returns a new array, so the caller's image is never modified
    # (the previous deepcopy was unnecessary).
    resize_image = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
    x = np.array(x, dtype=np.float32)
    # ImageNet normalization constants.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x = (x / 255 - mean) / std
    x = x.transpose(2, 0, 1)
    # The resized image is (H, W, 3) = (input_size[1], input_size[0], 3), so
    # the NCHW tensor must be (1, 3, H, W). The previous
    # (input_size[0], input_size[1]) order silently scrambled non-square inputs.
    x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    # Guard against a flat response map (max == min) to avoid divide-by-zero.
    denom = max(max_value - min_value, 1e-8)
    onnx_result = (onnx_result - min_value) / denom
    onnx_result *= 255
    return onnx_result.astype("uint8")
|
||||
|
||||
|
||||
def download_file_from_url(url: str, filename: str):
    """Downloads a file from a URL, handling redirects.

    Args:
        url: Source URL (may answer with a direct 200 or a 302 redirect).
        filename: Local path to write the downloaded content to.
    """
    import requests

    try:
        response = requests.get(url, allow_redirects=False)
        response.raise_for_status()

        if response.status_code == 302:
            # Follow a single redirect manually and stream the real payload.
            redirect_url = response.headers["Location"]
            response = requests.get(redirect_url, stream=True)
            response.raise_for_status()
        elif response.status_code != 200:
            # A direct 200 is fine too; anything else (e.g. 3xx variants
            # we do not handle) is unexpected — previously even 200 was
            # rejected here and nothing was downloaded.
            print(f"Unexpected status code: {response.status_code}")
            return

        with open(filename, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        # Report the actual destination instead of the garbled placeholder.
        print(f"Downloaded {filename} successfully.")

    except requests.exceptions.RequestException as e:
        print(f"Error downloading file: {e}")
|
||||
1780
lingbot_map/vis/point_cloud_viewer.py
Normal file
1780
lingbot_map/vis/point_cloud_viewer.py
Normal file
File diff suppressed because it is too large
Load Diff
473
lingbot_map/vis/sky_segmentation.py
Normal file
473
lingbot_map/vis/sky_segmentation.py
Normal file
@@ -0,0 +1,473 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Sky segmentation utilities for filtering sky points from point clouds.
|
||||
"""
|
||||
|
||||
import glob
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
try:
|
||||
import onnxruntime
|
||||
except ImportError:
|
||||
onnxruntime = None
|
||||
print("onnxruntime not found. Sky segmentation may not work.")
|
||||
|
||||
|
||||
# Model input resolution (width, height) expected by skyseg.onnx.
_SKYSEG_INPUT_SIZE = (320, 320)
# Non-sky confidence above which a pixel is kept as non-sky.
_SKYSEG_SOFT_THRESHOLD = 0.1
# Bump this string whenever mask generation changes so stale caches refresh.
_SKYSEG_CACHE_VERSION = "imagenet_norm_softmap_inverted_v3"
|
||||
|
||||
|
||||
def _get_cache_version_path(sky_mask_dir: str) -> str:
|
||||
return os.path.join(sky_mask_dir, ".skyseg_cache_version")
|
||||
|
||||
|
||||
def _prepare_sky_mask_cache(sky_mask_dir: Optional[str]) -> bool:
    """Ensure the sky-mask cache directory exists and is format-compatible.

    Args:
        sky_mask_dir: Cache directory, or None when caching is disabled.

    Returns:
        True when cached masks must be regenerated (directory is new or was
        written by an older cache format), False when the cache is reusable.
    """
    if sky_mask_dir is None:
        return False

    os.makedirs(sky_mask_dir, exist_ok=True)
    version_path = _get_cache_version_path(sky_mask_dir)
    # Assume the cache is stale until a matching version stamp proves otherwise.
    refresh_cache = True
    if os.path.exists(version_path):
        with open(version_path, "r", encoding="utf-8") as f:
            refresh_cache = f.read().strip() != _SKYSEG_CACHE_VERSION

    if refresh_cache:
        print(
            f"Sky mask cache at {sky_mask_dir} uses an older format; "
            "regenerating masks with ImageNet-normalized skyseg input"
        )
        # Stamp the directory so the next run can reuse the regenerated masks.
        with open(version_path, "w", encoding="utf-8") as f:
            f.write(_SKYSEG_CACHE_VERSION)

    return refresh_cache
|
||||
|
||||
|
||||
def run_skyseg(
    onnx_session,
    input_size: Tuple[int, int],
    image: np.ndarray,
) -> np.ndarray:
    """
    Run ONNX sky segmentation on a BGR image and return an 8-bit score map.

    Args:
        onnx_session: ONNX runtime inference session with the skyseg model.
        input_size: Model input resolution as (width, height).
        image: Input image in BGR channel order (as read by cv2.imread).

    Returns:
        Min-max normalized score map as uint8 in [0, 255], at model resolution.
    """
    resize_image = cv2.resize(image, dsize=(input_size[0], input_size[1]))
    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB).astype(np.float32)
    # ImageNet normalization constants.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    x = (x / 255.0 - mean) / std
    x = x.transpose(2, 0, 1)
    # NCHW layout: the resized image is (H, W, 3) = (input_size[1], input_size[0], 3).
    x = x.reshape(-1, 3, input_size[1], input_size[0]).astype("float32")

    input_name = onnx_session.get_inputs()[0].name
    output_name = onnx_session.get_outputs()[0].name
    onnx_result = onnx_session.run([output_name], {input_name: x})

    onnx_result = np.array(onnx_result).squeeze()
    min_value = np.min(onnx_result)
    max_value = np.max(onnx_result)
    # Guard against a flat response map to avoid division by zero.
    denom = max(max_value - min_value, 1e-8)
    onnx_result = (onnx_result - min_value) / denom
    onnx_result *= 255.0
    return onnx_result.astype(np.uint8)
|
||||
|
||||
|
||||
def _mask_to_float(mask: np.ndarray) -> np.ndarray:
|
||||
mask = mask.astype(np.float32)
|
||||
if mask.size == 0:
|
||||
return mask
|
||||
return np.clip(mask, 0.0, 1.0)
|
||||
|
||||
|
||||
def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
|
||||
mask = np.asarray(mask)
|
||||
if mask.dtype == np.uint8:
|
||||
return mask
|
||||
mask = mask.astype(np.float32)
|
||||
if mask.size > 0 and mask.max() <= 1.0:
|
||||
mask = mask * 255.0
|
||||
return np.clip(mask, 0.0, 255.0).astype(np.uint8)
|
||||
|
||||
|
||||
def _result_map_to_non_sky_conf(result_map: np.ndarray) -> np.ndarray:
    """Invert a raw skyseg score map into non-sky confidence in [0, 1].

    The raw map scores sky regions high; downstream code wants confidence
    for keeping a pixel, so the clamped map is flipped.
    """
    sky_score = _mask_to_float(result_map)
    return 1.0 - sky_score
|
||||
|
||||
|
||||
def segment_sky_from_array(
    image: np.ndarray,
    skyseg_session,
    target_h: int,
    target_w: int
) -> np.ndarray:
    """
    Segment sky from an image array using ONNX model.

    Args:
        image: Input image as numpy array (H, W, 3) or (3, H, W), values in [0, 1] or [0, 255]
        skyseg_session: ONNX runtime inference session
        target_h: Target output height
        target_w: Target output width

    Returns:
        Continuous non-sky confidence map in [0, 1].
    """
    # Normalize to channel-last uint8 RGB, then flip to BGR because
    # run_skyseg expects cv2-style BGR input.
    image_rgb = _image_to_rgb_uint8(image)
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    result_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, image_bgr)
    # Upsample the model-resolution score map to the requested output size.
    result_map = cv2.resize(result_map, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    # Invert: the raw map scores sky high; return non-sky confidence instead.
    return _result_map_to_non_sky_conf(result_map)
|
||||
|
||||
|
||||
def segment_sky(
    image_path: str,
    skyseg_session,
    output_path: Optional[str] = None
) -> np.ndarray:
    """
    Segment sky from an image using ONNX model.

    Args:
        image_path: Path to the input image
        skyseg_session: ONNX runtime inference session
        output_path: Optional path to save the mask

    Returns:
        Continuous non-sky confidence map in [0, 1].

    Raises:
        ValueError: If the image cannot be read.
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None on failure; fail loudly with the path.
    if image is None:
        raise ValueError(f"Failed to read image: {image_path}")

    result_map = run_skyseg(skyseg_session, _SKYSEG_INPUT_SIZE, image)
    # Upsample the model-resolution score map back to the image resolution.
    result_map = cv2.resize(result_map, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)
    mask = _result_map_to_non_sky_conf(result_map)

    if output_path is not None:
        output_dir = os.path.dirname(output_path)
        # A bare filename has no directory component to create.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        cv2.imwrite(output_path, _mask_to_uint8(mask))

    return mask
|
||||
|
||||
|
||||
def _list_image_files(image_folder: str) -> list[str]:
|
||||
image_files = sorted(glob.glob(os.path.join(image_folder, "*")))
|
||||
image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif", ".webp"}
|
||||
return [f for f in image_files if os.path.splitext(f.lower())[1] in image_extensions]
|
||||
|
||||
|
||||
def _image_to_rgb_uint8(image: np.ndarray) -> np.ndarray:
|
||||
if image.ndim == 3 and image.shape[0] == 3 and image.shape[-1] != 3:
|
||||
image = image.transpose(1, 2, 0)
|
||||
|
||||
if image.ndim != 3 or image.shape[2] != 3:
|
||||
raise ValueError(f"Expected image with shape (H, W, 3) or (3, H, W), got {image.shape}")
|
||||
|
||||
if image.dtype != np.uint8:
|
||||
image = image.astype(np.float32)
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = np.clip(image, 0.0, 255.0).astype(np.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def _get_mask_filename(image_paths: Optional[list[str]], index: int) -> str:
|
||||
if image_paths is not None and index < len(image_paths):
|
||||
return os.path.basename(image_paths[index])
|
||||
return f"frame_{index:06d}.png"
|
||||
|
||||
|
||||
def _save_sky_mask_visualization(
    image: np.ndarray,
    sky_mask: np.ndarray,
    output_path: str,
) -> None:
    """Write a side-by-side [image | mask | overlay] panel for inspection.

    Args:
        image: Source image, (H, W, 3) or (3, H, W).
        sky_mask: Non-sky confidence map in [0, 1]; resized to the image if needed.
        output_path: Destination file; the parent directory is created.
    """
    image_rgb = _image_to_rgb_uint8(image)
    # Nearest-neighbor keeps the sky/non-sky boundary crisp in the panel.
    if sky_mask.shape[:2] != image_rgb.shape[:2]:
        sky_mask = cv2.resize(
            sky_mask,
            (image_rgb.shape[1], image_rgb.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )

    mask_uint8 = _mask_to_uint8(sky_mask)
    mask_rgb = np.repeat(mask_uint8[..., None], 3, axis=2)
    overlay = image_rgb.astype(np.float32).copy()
    # Tint pixels classified as sky (confidence at or below the threshold) red.
    sky_pixels = _mask_to_float(sky_mask) <= _SKYSEG_SOFT_THRESHOLD
    overlay[sky_pixels] = overlay[sky_pixels] * 0.35 + np.array([255, 64, 64], dtype=np.float32) * 0.65
    overlay = np.clip(overlay, 0.0, 255.0).astype(np.uint8)

    panel = np.concatenate([image_rgb, mask_rgb, overlay], axis=1)
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    # The panel was assembled in RGB; cv2.imwrite expects BGR.
    cv2.imwrite(output_path, cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))
|
||||
|
||||
|
||||
def load_or_create_sky_masks(
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
    target_shape: Optional[Tuple[int, int]] = None,
    num_frames: Optional[int] = None,
) -> Optional[np.ndarray]:
    """
    Load cached sky masks or generate them with the ONNX model.

    Args:
        image_folder: Folder containing input images.
        image_paths: Optional explicit image file list, in the exact order to process.
        images: Optional image array with shape (S, 3, H, W) or (S, H, W, 3).
        skyseg_model_path: Path to the sky segmentation ONNX model.
        sky_mask_dir: Optional directory for cached raw masks.
        sky_mask_visualization_dir: Optional directory for side-by-side visualizations.
        target_shape: Optional output mask shape (H, W) after resizing.
        num_frames: Optional maximum number of frames to process.

    Returns:
        Sky masks with shape (S, H, W), or None if sky segmentation could not run.
    """
    if onnxruntime is None:
        print("Warning: onnxruntime not available, skipping sky segmentation")
        return None

    if image_folder is None and image_paths is None and images is None:
        print("Warning: Neither image_folder/image_paths nor images provided, skipping sky segmentation")
        return None

    # Fetch the model on first use; give up (None) if the download fails.
    if not os.path.exists(skyseg_model_path):
        print(f"Sky segmentation model not found at {skyseg_model_path}, downloading...")
        try:
            download_skyseg_model(skyseg_model_path)
        except Exception as e:
            print(f"Warning: Failed to download sky segmentation model: {e}")
            return None

    skyseg_session = onnxruntime.InferenceSession(skyseg_model_path)
    sky_masks = []

    if sky_mask_visualization_dir is not None:
        os.makedirs(sky_mask_visualization_dir, exist_ok=True)
        print(f"Saving sky mask visualizations to {sky_mask_visualization_dir}")

    if images is not None:
        # --- In-memory path: segment directly from the provided array. ---
        # File paths (when available) are only used to name the cached masks.
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)

        num_images = images.shape[0]
        if num_frames is not None:
            num_images = min(num_images, num_frames)
        if image_paths is not None:
            image_paths = image_paths[:num_images]

        if sky_mask_dir is None and image_folder is not None:
            sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
        # refresh_cache forces regeneration when the cache format changed.
        refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)

        print("Generating sky masks from image array...")
        for i in tqdm(range(num_images)):
            image_rgb = _image_to_rgb_uint8(images[i])
            image_h, image_w = image_rgb.shape[:2]
            image_name = _get_mask_filename(image_paths, i)
            mask_filepath = os.path.join(sky_mask_dir, image_name) if sky_mask_dir is not None else None

            if mask_filepath is not None and not refresh_cache and os.path.exists(mask_filepath):
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                if sky_mask is None:
                    # Cached file exists but is unreadable/corrupt.
                    print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
                    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
                elif sky_mask.shape[:2] != (image_h, image_w):
                    # Cached mask was produced at a different image resolution.
                    print(
                        f"Cached sky mask shape {sky_mask.shape[:2]} does not match resized image "
                        f"shape {(image_h, image_w)} for {image_name}; regenerating it"
                    )
                    sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))
            else:
                sky_mask = segment_sky_from_array(image_rgb, skyseg_session, image_h, image_w)
                if mask_filepath is not None:
                    cv2.imwrite(mask_filepath, _mask_to_uint8(sky_mask))

            if sky_mask_visualization_dir is not None:
                _save_sky_mask_visualization(
                    image_rgb,
                    sky_mask,
                    os.path.join(sky_mask_visualization_dir, image_name),
                )

            # Resize last so cache files keep their native resolution.
            if target_shape is not None and sky_mask.shape[:2] != target_shape:
                sky_mask = cv2.resize(
                    sky_mask,
                    (target_shape[1], target_shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )

            sky_masks.append(_mask_to_float(sky_mask))

    else:
        # --- On-disk path: read each image file and segment it. ---
        if image_paths is None and image_folder is not None:
            image_paths = _list_image_files(image_folder)

        if images is None and image_paths is not None:
            if len(image_paths) == 0:
                print("Warning: No image files provided, skipping sky segmentation")
                return None

        if num_frames is not None:
            image_paths = image_paths[:num_frames]

        if sky_mask_dir is None:
            if image_folder is None:
                image_folder = os.path.dirname(image_paths[0])
            sky_mask_dir = image_folder.rstrip("/") + "_sky_masks"
        refresh_cache = _prepare_sky_mask_cache(sky_mask_dir)

        print("Generating sky masks from image files...")
        for image_path in tqdm(image_paths):
            image_name = os.path.basename(image_path)
            mask_filepath = os.path.join(sky_mask_dir, image_name)

            if not refresh_cache and os.path.exists(mask_filepath):
                sky_mask = cv2.imread(mask_filepath, cv2.IMREAD_GRAYSCALE)
                if sky_mask is None:
                    # Cached file exists but is unreadable/corrupt.
                    print(f"Warning: Failed to read cached sky mask {mask_filepath}, regenerating it")
                    sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)
            else:
                sky_mask = segment_sky(image_path, skyseg_session, mask_filepath)

            if sky_mask is None:
                print(f"Warning: Failed to produce sky mask for {image_path}, skipping frame")
                continue

            if sky_mask_visualization_dir is not None:
                image_bgr = cv2.imread(image_path)
                if image_bgr is not None:
                    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                    _save_sky_mask_visualization(
                        image_rgb,
                        sky_mask,
                        os.path.join(sky_mask_visualization_dir, image_name),
                    )

            if target_shape is not None and sky_mask.shape[:2] != target_shape:
                sky_mask = cv2.resize(
                    sky_mask,
                    (target_shape[1], target_shape[0]),
                    interpolation=cv2.INTER_LINEAR,
                )

            sky_masks.append(_mask_to_float(sky_mask))

    if len(sky_masks) == 0:
        print("Warning: No sky masks generated, skipping sky segmentation")
        return None

    # Masks can disagree in shape when target_shape is None; fall back to an
    # object array in that case.
    try:
        return np.stack(sky_masks, axis=0)
    except ValueError:
        return np.array(sky_masks, dtype=object)
|
||||
|
||||
|
||||
def apply_sky_segmentation(
    conf: np.ndarray,
    image_folder: Optional[str] = None,
    image_paths: Optional[list[str]] = None,
    images: Optional[np.ndarray] = None,
    skyseg_model_path: str = "skyseg.onnx",
    sky_mask_dir: Optional[str] = None,
    sky_mask_visualization_dir: Optional[str] = None,
) -> np.ndarray:
    """
    Apply sky segmentation to confidence scores.

    Args:
        conf: Confidence scores with shape (S, H, W)
        image_folder: Path to the folder containing input images (optional if images provided)
        image_paths: Optional explicit image file list in processing order
        images: Image array with shape (S, 3, H, W) or (S, H, W, 3) (optional if image_folder provided)
        skyseg_model_path: Path to the sky segmentation ONNX model
        sky_mask_dir: Optional directory for cached raw masks
        sky_mask_visualization_dir: Optional directory for side-by-side mask visualization images

    Returns:
        Updated confidence scores with sky regions masked out
    """
    S, H, W = conf.shape

    sky_mask_array = load_or_create_sky_masks(
        image_folder=image_folder,
        image_paths=image_paths,
        images=images,
        skyseg_model_path=skyseg_model_path,
        sky_mask_dir=sky_mask_dir,
        sky_mask_visualization_dir=sky_mask_visualization_dir,
        target_shape=(H, W),
        num_frames=S,
    )
    if sky_mask_array is None:
        # Segmentation could not run; leave the confidence untouched.
        return conf

    if sky_mask_array.shape[0] < S:
        print(
            f"Warning: Only {sky_mask_array.shape[0]} sky masks generated for {S} frames; "
            "leaving the remaining frames unmasked"
        )
        # Pad with ONES (full non-sky confidence) so the uncovered frames
        # really are left unmasked, as the warning promises — the previous
        # zero padding silently wiped those frames' confidence to zero.
        padded = np.ones((S, H, W), dtype=sky_mask_array.dtype)
        padded[: sky_mask_array.shape[0]] = sky_mask_array
        sky_mask_array = padded
    elif sky_mask_array.shape[0] > S:
        sky_mask_array = sky_mask_array[:S]

    # Binarize the soft masks and zero out the sky pixels.
    sky_mask_binary = (sky_mask_array > _SKYSEG_SOFT_THRESHOLD).astype(np.float32)
    conf = conf * sky_mask_binary

    print("Sky segmentation applied successfully")
    return conf
|
||||
|
||||
|
||||
def download_skyseg_model(output_path: str = "skyseg.onnx") -> str:
    """
    Download sky segmentation model from HuggingFace.

    Args:
        output_path: Path to save the model

    Returns:
        Path to the downloaded model

    Raises:
        requests.HTTPError: If the download request fails.
    """
    import requests

    url = "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx"

    print(f"Downloading sky segmentation model from {url}...")
    response = requests.get(url, stream=True)
    response.raise_for_status()

    # Content-Length may be absent; tqdm then shows an unbounded progress bar.
    total_size = int(response.headers.get('content-length', 0))

    with open(output_path, 'wb') as f:
        with tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") as pbar:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                pbar.update(len(chunk))

    print(f"Model saved to {output_path}")
    return output_path
|
||||
206
lingbot_map/vis/utils.py
Normal file
206
lingbot_map/vis/utils.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Visualization utility functions for colorization and color bars.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
import matplotlib.cm as cm
|
||||
|
||||
|
||||
@dataclasses.dataclass
class CameraState:
    """Camera state for rendering."""
    # Field of view (radians; np.tan below implies radians). Presumably the
    # vertical FOV, since the focal length is derived from image height.
    fov: float
    # Width / height aspect ratio.
    aspect: float
    # Camera-to-world transform.
    c2w: np.ndarray

    def get_K(self, img_wh: Tuple[int, int]) -> np.ndarray:
        """Build the 3x3 pinhole intrinsic matrix for an image of size (W, H)."""
        width, height = img_wh
        # Focal length in pixels from the FOV and image height.
        f = height / (2.0 * np.tan(self.fov / 2.0))
        intrinsics = np.array([
            [f, 0.0, width / 2.0],
            [0.0, f, height / 2.0],
            [0.0, 0.0, 1.0],
        ])
        return intrinsics
|
||||
|
||||
|
||||
def get_vertical_colorbar(
    h: int,
    vmin: float,
    vmax: float,
    cmap_name: str = "jet",
    label: Optional[str] = None,
    cbar_precision: int = 2
) -> np.ndarray:
    """
    Create a vertical colorbar image.

    Args:
        h: Height in pixels
        vmin: Minimum value
        vmax: Maximum value
        cmap_name: Colormap name
        label: Optional label for the colorbar
        cbar_precision: Decimal precision for tick labels

    Returns:
        Colorbar image as numpy array (H, W, 3), float32 in [0, 1]
    """
    # Render off-screen with the Agg backend; no global pyplot state touched.
    from matplotlib.figure import Figure
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import matplotlib as mpl

    fig = Figure(figsize=(2, 8), dpi=100)
    fig.subplots_adjust(right=1.5)
    canvas = FigureCanvasAgg(fig)

    ax = fig.add_subplot(111)
    # NOTE(review): cm.get_cmap is deprecated/removed in newer matplotlib
    # (use matplotlib.colormaps[...]); confirm the pinned matplotlib version.
    cmap = cm.get_cmap(cmap_name)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    # Six evenly spaced ticks across the value range.
    tick_cnt = 6
    tick_loc = np.linspace(vmin, vmax, tick_cnt)
    cb1 = mpl.colorbar.ColorbarBase(
        ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
    )

    tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
    if cbar_precision == 0:
        # Strip the trailing ".0" that np.round leaves on integer labels.
        tick_label = [x[:-2] for x in tick_label]

    cb1.set_ticklabels(tick_label)
    cb1.ax.tick_params(labelsize=18, rotation=0)
    if label is not None:
        cb1.set_label(label)

    # Rasterize the figure into an RGBA byte buffer.
    canvas.draw()
    s, (width, height) = canvas.print_to_buffer()

    im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
    # Drop alpha and normalize to [0, 1] floats.
    im = im[:, :, :3].astype(np.float32) / 255.0

    # Scale to the requested height, preserving aspect ratio.
    if h != im.shape[0]:
        w = int(im.shape[1] / im.shape[0] * h)
        im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)

    return im
|
||||
|
||||
|
||||
def colorize_np(
    x: np.ndarray,
    cmap_name: str = "jet",
    mask: Optional[np.ndarray] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False,
    cbar_precision: int = 2,
) -> np.ndarray:
    """
    Turn a grayscale image into a color image.

    Args:
        x: Input grayscale image [H, W]; never modified in place.
        cmap_name: Colormap name
        mask: Optional boolean mask image [H, W]; unmasked pixels are rendered white
        range: Value range for scaling [min, max], automatic if None
        append_cbar: Whether to append colorbar
        cbar_in_image: Overwrite the image's right edge with the colorbar instead
            of widening the output
        cbar_precision: Colorbar tick precision

    Returns:
        Colorized image [H, W, 3] as float32 in [0, 1]
    """
    # Determine the display range.
    if range is not None:
        vmin, vmax = range
    elif mask is not None:
        # Smallest nonzero masked value / largest masked value.
        # NOTE(review): raises ValueError if the masked region is all zeros or
        # empty — confirm callers guarantee otherwise.
        vmin = np.min(x[mask][np.nonzero(x[mask])])
        vmax = np.max(x[mask])
        # The original wrote vmin into the unmasked pixels IN PLACE, mutating the
        # caller's array (which may even share memory with a torch tensor via
        # .numpy()); build a new array instead.
        x = np.where(mask, x, vmin)
    else:
        vmin, vmax = np.percentile(x, (1, 100))
        vmax += 1e-6  # guard against vmax == vmin when x is constant

    # Normalize to [0, 1].
    x = np.clip(x, vmin, vmax)
    x = (x - vmin) / (vmax - vmin)

    cmap = cm.get_cmap(cmap_name)
    x_new = cmap(x)[:, :, :3]  # drop the alpha channel

    if mask is not None:
        # Blend unmasked pixels to white.
        mask = np.float32(mask[:, :, np.newaxis])
        x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)

    if not append_cbar:
        return x_new

    # Only render the (expensive) matplotlib colorbar when it is actually used;
    # the original rendered it unconditionally and discarded it.
    cbar = get_vertical_colorbar(
        h=x.shape[0],
        vmin=vmin,
        vmax=vmax,
        cmap_name=cmap_name,
        cbar_precision=cbar_precision,
    )

    if cbar_in_image:
        x_new[:, -cbar.shape[1]:, :] = cbar
    else:
        x_new = np.concatenate(
            (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
        )
    return x_new
|
||||
|
||||
|
||||
def colorize(
    x: torch.Tensor,
    cmap_name: str = "jet",
    mask: Optional[torch.Tensor] = None,
    range: Optional[Tuple[float, float]] = None,
    append_cbar: bool = False,
    cbar_in_image: bool = False
) -> torch.Tensor:
    """
    Turn a grayscale image into a color image (PyTorch tensor version).

    Args:
        x: Grayscale image tensor [H, W] or [B, H, W]
        cmap_name: Colormap name
        mask: Optional mask tensor [H, W] or [B, H, W]; values > 0.99 are valid
        range: Value range for scaling
        append_cbar: Whether to append colorbar
        cbar_in_image: Put colorbar inside image

    Returns:
        Colorized float tensor on the same device as ``x``
    """
    device = x.device
    x = x.cpu().numpy()
    if mask is not None:
        mask = mask.cpu().numpy() > 0.99
    kernel = np.ones((3, 3), np.uint8)

    # Promote [H, W] inputs to a single-frame batch so one code path handles both.
    if x.ndim == 2:
        x = x[None]
        if mask is not None and mask.ndim == 2:
            mask = mask[None]

    out = []
    for i, x_ in enumerate(x):
        mask_ = None
        if mask is not None:
            # Erode this frame's mask once to shave off boundary pixels.
            # BUG FIX: the original eroded the full batch mask (B, H, W)
            # cumulatively on every iteration and passed it whole to
            # colorize_np, which expects a per-frame (H, W) mask.
            mask_ = cv2.erode(mask[i].astype(np.uint8), kernel, iterations=1).astype(bool)

        x_ = colorize_np(x_, cmap_name, mask_, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
    # squeeze(0) restores [H, W, 3] output for unbatched [H, W] input.
    out = torch.stack(out).squeeze(0)
    return out
|
||||
248
lingbot_map/vis/viser_wrapper.py
Normal file
248
lingbot_map/vis/viser_wrapper.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Quick visualization wrapper for GCT predictions using Viser.
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import viser
|
||||
import viser.transforms as tf
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from lingbot_map.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
|
||||
from lingbot_map.vis.sky_segmentation import apply_sky_segmentation
|
||||
|
||||
|
||||
def viser_wrapper(
    pred_dict: dict,
    port: int = 8080,
    init_conf_threshold: float = 50.0,
    use_point_map: bool = False,
    background_mode: bool = False,
    mask_sky: bool = False,
    image_folder: Optional[str] = None,
):
    """
    Visualize predicted 3D points and camera poses with viser.

    This is a simplified wrapper for quick visualization without the full
    PointCloudViewer controls.

    Args:
        pred_dict: Dictionary containing predictions with keys:
            - images: (S, 3, H, W) - Input images
            - world_points: (S, H, W, 3)
            - world_points_conf: (S, H, W)
            - depth: (S, H, W, 1)
            - depth_conf: (S, H, W)
            - extrinsic: (S, 3, 4)
            - intrinsic: (S, 3, 3)
        port: Port number for the viser server
        init_conf_threshold: Initial percentage of low-confidence points to filter out
        use_point_map: Whether to visualize world_points or use depth-based points
        background_mode: Whether to run the server in background thread
        mask_sky: Whether to apply sky segmentation to filter out sky points
        image_folder: Path to the folder containing input images (for sky segmentation)

    Returns:
        viser.ViserServer: The viser server instance.
        NOTE(review): in non-background mode the function blocks forever in a
        sleep loop, so this return statement is unreachable there — presumably
        intentional (serve until interrupted); confirm.
    """
    print(f"Starting viser server on port {port}")

    server = viser.ViserServer(host="0.0.0.0", port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    # Unpack prediction dict
    images = pred_dict["images"]  # (S, 3, H, W)
    world_points_map = pred_dict["world_points"]  # (S, H, W, 3)
    conf_map = pred_dict["world_points_conf"]  # (S, H, W)

    depth_map = pred_dict["depth"]  # (S, H, W, 1)
    depth_conf = pred_dict["depth_conf"]  # (S, H, W)

    extrinsics_cam = pred_dict["extrinsic"]  # (S, 3, 4)
    intrinsics_cam = pred_dict["intrinsic"]  # (S, 3, 3)

    # Compute world points from depth if not using the precomputed point map
    if not use_point_map:
        world_points = unproject_depth_map_to_point_map(depth_map, extrinsics_cam, intrinsics_cam)
        conf = depth_conf
    else:
        world_points = world_points_map
        conf = conf_map

    # Apply sky segmentation if enabled (zeroes/attenuates confidence of sky pixels)
    if mask_sky and image_folder is not None:
        conf = apply_sky_segmentation(conf, image_folder)

    # Convert images from (S, 3, H, W) to (S, H, W, 3)
    colors = images.transpose(0, 2, 3, 1)  # now (S, H, W, 3)
    shape = world_points.shape
    S: int = shape[0]
    H: int = shape[1]
    W: int = shape[2]

    # Flatten all frames into one big point list for the viewer.
    points = world_points.reshape(-1, 3)
    # assumes image colors are floats in [0, 1] — TODO confirm upstream range
    colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
    conf_flat = conf.reshape(-1)

    # Random sample points if too many (keeps the browser responsive).
    # `indices` stays None when no subsampling happened.
    indices = None
    if points.shape[0] > 6000000:
        print(f"Too many points ({points.shape[0]}), randomly sampling 6M points")
        indices = np.random.choice(points.shape[0], size=6000000, replace=False)
        points = points[indices]
        colors_flat = colors_flat[indices]
        conf_flat = conf_flat[indices]

    # Invert world->cam extrinsics to get cam->world poses for drawing frustums.
    cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)
    cam_to_world = cam_to_world_mat[:, :3, :]

    # Compute scene center and recenter both points and camera translations.
    # NOTE(review): this subtracts from cam_to_world in place (a view into
    # cam_to_world_mat) — fine here since cam_to_world_mat is local.
    scene_center = np.mean(points, axis=0)
    points_centered = points - scene_center
    cam_to_world[..., -1] -= scene_center

    # Store per-point frame indices for the frame-selector filter; apply the
    # same subsampling as the point arrays so the masks stay aligned.
    frame_indices = (
        np.repeat(np.arange(S), H * W)[indices]
        if indices is not None
        else np.repeat(np.arange(S), H * W)
    )

    # Build the viser GUI
    gui_show_frames = server.gui.add_checkbox("Show Cameras", initial_value=True)
    gui_points_conf = server.gui.add_slider(
        "Confidence Percent", min=0, max=100, step=0.1, initial_value=init_conf_threshold
    )
    gui_frame_selector = server.gui.add_dropdown(
        "Show Points from Frames",
        options=["All"] + [str(i) for i in range(S)],
        initial_value="All"
    )

    # Create the main point cloud, pre-filtered to the initial confidence
    # percentile (plus an absolute floor of 0.1 to drop near-zero confidence).
    init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
    init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1)
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points_centered[init_conf_mask],
        colors=colors_flat[init_conf_mask],
        point_size=0.0005,
        point_shape="circle",
    )

    # Handles kept so callbacks can toggle visibility / rebuild the scene.
    frames: List[viser.FrameHandle] = []
    frustums: List[viser.CameraFrustumHandle] = []

    def visualize_frames(extrinsics, images_: np.ndarray) -> None:
        """Add camera frames and frustums to the scene, replacing any existing ones."""
        for f in frames:
            f.remove()
        frames.clear()
        for fr in frustums:
            fr.remove()
        frustums.clear()

        def attach_callback(frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle) -> None:
            # Clicking a frustum snaps every connected client's camera to that pose.
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        for img_id in tqdm(range(S)):
            cam2world_3x4 = extrinsics[img_id]
            T_world_camera = tf.SE3.from_matrix(cam2world_3x4)

            frame_axis = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05,
                axes_radius=0.002,
                origin_radius=0.002,
            )
            frames.append(frame_axis)

            # Convert this frame's image (3, H, W) float -> (H, W, 3) uint8.
            img = images_[img_id]
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            h, w = img.shape[:2]

            # Approximate focal length (fy = 1.1 * h) just for frustum display;
            # the true intrinsics are not used here.
            fy = 1.1 * h
            fov = 2 * np.arctan2(h / 2, fy)

            frustum_cam = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum",
                fov=fov,
                aspect=w / h,
                scale=0.05,
                image=img,
                line_width=1.0
            )
            frustums.append(frustum_cam)
            attach_callback(frustum_cam, frame_axis)

    def update_point_cloud() -> None:
        """Update point cloud based on current GUI selections."""
        # Percentile-based threshold from the slider, with a small absolute floor.
        current_percentage = gui_points_conf.value
        threshold_val = np.percentile(conf_flat, current_percentage)
        print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")

        conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)

        # Optionally restrict to a single source frame.
        if gui_frame_selector.value == "All":
            frame_mask = np.ones_like(conf_mask, dtype=bool)
        else:
            selected_idx = int(gui_frame_selector.value)
            frame_mask = frame_indices == selected_idx

        combined_mask = conf_mask & frame_mask
        point_cloud.points = points_centered[combined_mask]
        point_cloud.colors = colors_flat[combined_mask]

    @gui_points_conf.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_frame_selector.on_update
    def _(_) -> None:
        update_point_cloud()

    @gui_show_frames.on_update
    def _(_) -> None:
        # Toggle visibility of all camera axes and frustums together.
        for f in frames:
            f.visible = gui_show_frames.value
        for fr in frustums:
            fr.visible = gui_show_frames.value

    # Add camera frames. Poses may arrive as torch tensors or numpy arrays;
    # normalize to numpy before drawing.
    import torch
    if torch.is_tensor(cam_to_world):
        cam_to_world_np = cam_to_world.cpu().numpy()
    else:
        cam_to_world_np = cam_to_world
    visualize_frames(cam_to_world_np, images)

    print("Starting viser server...")
    if background_mode:
        # NOTE(review): viser serves from its own threads; this daemon thread
        # only sleeps and appears to be a placeholder keep-alive — confirm it
        # is actually needed.
        def server_loop():
            while True:
                time.sleep(0.001)

        thread = threading.Thread(target=server_loop, daemon=True)
        thread.start()
    else:
        # Block the calling thread forever so the server stays up.
        while True:
            time.sleep(0.01)

    return server
|
||||
27
pyproject.toml
Normal file
27
pyproject.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
# Package metadata and build configuration for lingbot-map.
[project]
name = "lingbot-map"
version = "0.1.0"
description = "LingBot-Map: Geometric Context Transformer for Streaming 3D Reconstruction"
requires-python = ">= 3.10"
# Core runtime dependencies. NOTE(review): torch itself is not listed (only
# torchvision) — presumably expected to be preinstalled; confirm.
dependencies = [
    "Pillow",
    "huggingface_hub",
    "einops",
    "safetensors",
    "opencv-python",
    "tqdm",
    "scipy",
    "torchvision",
]

[project.optional-dependencies]
# Extras for interactive visualization (viser server, meshes, plots, sky segmentation).
vis = ["viser>=0.2.23", "trimesh", "matplotlib", "onnxruntime", "requests"]
# The demo extra just pulls in the visualization stack.
demo = ["lingbot-map[vis]"]

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

# Ship only the lingbot_map package (and its subpackages) from the repo root.
[tool.setuptools.packages.find]
where = ["."]
include = ["lingbot_map*"]
|
||||
Reference in New Issue
Block a user