first commit
This commit is contained in:
  lingbot_map/models/gct_base.py  (new file, +359 lines)
@@ -0,0 +1,359 @@
"""
GCTBase - Base class for GCT model implementations.

Provides shared functionality:
- Prediction heads (camera, depth, point)
- Forward pass structure
- Model hub mixin (PyTorchModelHubMixin)
"""

import logging
import numpy as np
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Union
from huggingface_hub import PyTorchModelHubMixin

from lingbot_map.heads.dpt_head import DPTHead
from lingbot_map.utils.pose_enc import pose_encoding_to_extri_intri
from lingbot_map.utils.geometry import closed_form_inverse_se3

logger = logging.getLogger(__name__)

class GCTBase(nn.Module, PyTorchModelHubMixin, ABC):
    """
    Base class for GCT model implementations.

    Handles shared components:
    - Prediction heads (camera, depth, point)
    - Forward pass structure
    - Input normalization

    Subclasses must implement:
    - _build_aggregator(): Create mode-specific aggregator
    - _build_camera_head(): Create mode-specific camera head
    """

    def __init__(
        self,
        # Architecture parameters
        img_size: int = 518,
        patch_size: int = 14,
        embed_dim: int = 1024,
        patch_embed: str = 'dinov2_vitl14_reg',
        disable_global_rope: bool = False,
        # Head configuration
        enable_camera: bool = True,
        enable_point: bool = True,
        enable_local_point: bool = False,
        enable_depth: bool = True,
        enable_track: bool = False,
        # Camera head sliding window
        enable_camera_sliding_window: bool = False,
        # 3D RoPE
        enable_3d_rope: bool = False,
        # Context Parallelism (kept for checkpoint compatibility but not used)
        enable_ulysses_cp: bool = False,
        # Normalization
        enable_normalize: bool = False,
        # Prediction normalization
        pred_normalization: bool = False,
        pred_normalization_detach_scale: bool = False,
        # Gradient checkpointing
        use_gradient_checkpoint: bool = True,
    ):
        super().__init__()

        # Store configuration
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embed = patch_embed
        self.disable_global_rope = disable_global_rope

        self.enable_ulysses_cp = False  # CP disabled in standalone package
        self.enable_normalize = enable_normalize
        self.pred_normalization = pred_normalization
        self.pred_normalization_detach_scale = pred_normalization_detach_scale
        self.use_gradient_checkpoint = use_gradient_checkpoint

        # Head flags
        self.enable_camera = enable_camera
        self.enable_point = enable_point
        self.enable_local_point = enable_local_point
        self.enable_depth = enable_depth
        self.enable_track = enable_track
        self.enable_camera_sliding_window = enable_camera_sliding_window
        self.enable_3d_rope = enable_3d_rope

        # Build aggregator (subclass-specific)
        self.aggregator = self._build_aggregator()

        # Build prediction heads (subclass-specific)
        self.camera_head = self._build_camera_head() if enable_camera else None
        self.point_head = self._build_point_head() if enable_point else None
        self.local_point_head = self._build_local_point_head() if enable_local_point else None
        self.depth_head = self._build_depth_head() if enable_depth else None

    @abstractmethod
    def _build_aggregator(self) -> nn.Module:
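        """Build the mode-specific feature aggregator (implemented by subclasses)."""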
        pass

    @abstractmethod
    def _build_camera_head(self) -> nn.Module:
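        """Build the mode-specific camera head (implemented by subclasses)."""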
        pass

    def _build_depth_head(self) -> nn.Module:
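        # DPTHead with 2 output channels; _predict_depth unpacks its output into
        # a depth map and a per-pixel confidence map.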
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=2,
            activation="exp",
            conf_activation="expp1"
        )

    def _build_point_head(self) -> nn.Module:
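        # DPTHead with 4 output channels (presumably XYZ plus confidence);
        # _predict_points unpacks its output into world_points and world_points_conf.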
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1"
        )

    def _build_local_point_head(self) -> nn.Module:
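        # Local point head: same DPTHead configuration as the world-point head;
        # its outputs are returned as camera-frame points ("cam_points") by
        # _predict_local_points.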
        return DPTHead(
            dim_in=2 * self.embed_dim,
            patch_size=self.patch_size,
            output_dim=4,
            activation="inv_log",
            conf_activation="expp1"
        )

    def _normalize_input(self, images: torch.Tensor, query_points=None):
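        # Add a leading batch dimension to unbatched inputs:
        # [S, 3, H, W] -> [1, S, 3, H, W] and [N, 2] -> [1, N, 2].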
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)
        return images, query_points

    @abstractmethod
    def _aggregate_features(
        self,
        images: torch.Tensor,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        view_graphs: Optional[torch.Tensor] = None,
        causal_graphs: Optional[Union[torch.Tensor, List[np.ndarray]]] = None,
        ordered_video: Optional[torch.Tensor] = None,
        is_cp_sliced: bool = False,
    ) -> tuple:
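        """Run the aggregator over the frames; returns (aggregated_tokens_list, patch_start_idx)."""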
        pass

    def _predict_camera(
        self,
        aggregated_tokens_list: list,
        mask: Optional[torch.Tensor] = None,
        causal_inference: bool = False,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        if self.camera_head is None:
            return {}

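        # The camera head runs in full precision: cast the aggregated tokens to
        # fp32 and disable autocast around the call.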
        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]

        camera_sliding_window = sliding_window_size if self.enable_camera_sliding_window else -1

        with torch.amp.autocast('cuda', enabled=False):
            pose_enc_list = self.camera_head(
                aggregated_tokens_list_fp32,
                mask=mask,
                causal_inference=causal_inference,
                num_frame_for_scale=num_frame_for_scale if num_frame_for_scale is not None else -1,
                sliding_window_size=camera_sliding_window,
                num_frame_per_block=num_frame_per_block,
            )

        return {
            "pose_enc": pose_enc_list[-1],
            "pose_enc_list": pose_enc_list,
        }

    def _predict_depth(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        if self.depth_head is None:
            return {}

        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        images_fp32 = images.float()

        with torch.amp.autocast('cuda', enabled=False):
            depth, depth_conf = self.depth_head(
                aggregated_tokens_list_fp32,
                images=images_fp32,
                patch_start_idx=patch_start_idx
            )

        return {"depth": depth, "depth_conf": depth_conf}

    def _predict_points(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        if self.point_head is None:
            return {}

        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        images_fp32 = images.float()

        with torch.amp.autocast('cuda', enabled=False):
            pts3d, pts3d_conf = self.point_head(
                aggregated_tokens_list_fp32,
                images=images_fp32,
                patch_start_idx=patch_start_idx
            )

        return {"world_points": pts3d, "world_points_conf": pts3d_conf}

    def _predict_local_points(
        self,
        aggregated_tokens_list: list,
        images: torch.Tensor,
        patch_start_idx: int,
        gather_outputs: bool = True,
    ) -> Dict[str, torch.Tensor]:
        if self.local_point_head is None:
            return {}

        aggregated_tokens_list_fp32 = [t.float() for t in aggregated_tokens_list]
        images_fp32 = images.float()

        with torch.amp.autocast('cuda', enabled=False):
            pts3d, pts3d_conf = self.local_point_head(
                aggregated_tokens_list_fp32,
                images=images_fp32,
                patch_start_idx=patch_start_idx
            )

        return {"cam_points": pts3d, "cam_points_conf": pts3d_conf}

    def _unproject_depth_to_world(
        self,
        depth: torch.Tensor,
        pose_enc: torch.Tensor,
    ) -> torch.Tensor:
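        """
        Unproject per-pixel depth [B, S, H, W, 1] into world coordinates [B, S, H, W, 3].

        Decodes the pose encoding into extrinsics/intrinsics, inverts the
        extrinsics to camera-to-world, back-projects the pixel grid through the
        inverse intrinsics, scales by depth, and applies the camera-to-world
        transform.
        """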
        B, S, H, W, _ = depth.shape
        device = depth.device
        dtype = depth.dtype

        image_size_hw = (H, W)
        extrinsics, intrinsics = pose_encoding_to_extri_intri(
            pose_enc, image_size_hw=image_size_hw, build_intrinsics=True
        )

        extrinsics_flat = extrinsics.view(B * S, 3, 4)
        extrinsics_4x4 = torch.zeros(B * S, 4, 4, device=device, dtype=dtype)
        extrinsics_4x4[:, :3, :] = extrinsics_flat
        extrinsics_4x4[:, 3, 3] = 1.0
        c2w = closed_form_inverse_se3(extrinsics_4x4).view(B, S, 4, 4)

        y_grid, x_grid = torch.meshgrid(
            torch.arange(H, device=device, dtype=dtype),
            torch.arange(W, device=device, dtype=dtype),
            indexing='ij'
        )
        pixel_coords = torch.stack([x_grid, y_grid, torch.ones_like(x_grid)], dim=-1)

        intrinsics_inv = torch.inverse(intrinsics)
        camera_coords = torch.einsum('bsij,hwj->bshwi', intrinsics_inv, pixel_coords)
        camera_points = camera_coords * depth

        ones = torch.ones_like(camera_points[..., :1])
        camera_points_h = torch.cat([camera_points, ones], dim=-1)
        world_points_h = torch.einsum('bsij,bshwj->bshwi', c2w, camera_points_h)

        return world_points_h[..., :3]

    def forward(
        self,
        images: torch.Tensor,
        query_points: Optional[torch.Tensor] = None,
        num_frame_for_scale: Optional[int] = None,
        sliding_window_size: Optional[int] = None,
        num_frame_per_block: int = 1,
        mask: Optional[torch.Tensor] = None,
        causal_inference: bool = False,
        ordered_video: Optional[torch.Tensor] = None,
        gather_outputs: bool = True,
        point_masks: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the GCT model.

        Args:
            images: Input images [S, 3, H, W] or [B, S, 3, H, W], in range [0, 1]
            query_points: Optional query points [N, 2] or [B, N, 2]

        Returns:
            Dictionary containing predictions:
            - pose_enc: Camera pose encoding [B, S, 9]
            - depth: Depth maps [B, S, H, W, 1]
            - depth_conf: Depth confidence [B, S, H, W]
            - world_points: 3D world coordinates [B, S, H, W, 3]
            - world_points_conf: Point confidence [B, S, H, W]
        """
        images, query_points = self._normalize_input(images, query_points)

        aggregated_tokens_list, patch_start_idx = self._aggregate_features(
            images,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
        )

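        # Each enabled head contributes its outputs to `predictions`; a disabled
        # head (attribute set to None) simply returns an empty dict.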
        predictions = {}

        predictions.update(self._predict_camera(
            aggregated_tokens_list,
            mask=ordered_video,
            causal_inference=causal_inference,
            num_frame_for_scale=num_frame_for_scale,
            sliding_window_size=sliding_window_size,
            num_frame_per_block=num_frame_per_block,
            gather_outputs=gather_outputs,
        ))

        predictions.update(self._predict_depth(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))

        predictions.update(self._predict_points(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))

        predictions.update(self._predict_local_points(
            aggregated_tokens_list, images, patch_start_idx,
            gather_outputs=gather_outputs,
        ))

        if not self.training:
            predictions["images"] = images

        return predictions
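
# Example usage (a minimal sketch; `GCTModel` stands in for a hypothetical concrete
# subclass that implements _build_aggregator, _build_camera_head and
# _aggregate_features, and "org/gct-checkpoint" is a placeholder repo id):
#
#     model = GCTModel.from_pretrained("org/gct-checkpoint")  # via PyTorchModelHubMixin
#     model = model.eval().cuda()
#     images = torch.rand(8, 3, 518, 518, device="cuda")      # [S, 3, H, W] in [0, 1]
#     with torch.no_grad():
#         preds = model(images)
#     # preds contains e.g. "pose_enc" [1, 8, 9], "depth" [1, 8, H, W, 1],
#     # "world_points" [1, 8, H, W, 3] and the matching confidence maps.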