diff --git a/lerf/lerf_field.py b/lerf/lerf_field.py index 68178bc..c5fb7af 100644 --- a/lerf/lerf_field.py +++ b/lerf/lerf_field.py @@ -3,9 +3,9 @@ import numpy as np import torch from lerf.lerf_fieldheadnames import LERFFieldHeadNames -from torch import nn +from torch import nn, Tensor from torch.nn.parameter import Parameter -from torchtyping import TensorType +from jaxtyping import Float from nerfstudio.cameras.rays import RaySamples from nerfstudio.data.scene_box import SceneBox @@ -39,7 +39,8 @@ def __init__( [ LERFField._get_encoding( grid_resolutions[i][0], grid_resolutions[i][1], grid_layers[i], indim=3, hash_size=grid_sizes[i] - ) for i in range(len(grid_layers)) + ) + for i in range(len(grid_layers)) ] ) tot_out_dims = sum([e.n_output_dims for e in self.clip_encs]) @@ -84,7 +85,7 @@ def _get_encoding(start_res, end_res, levels, indim=3, hash_size=19): ) return enc - def get_outputs(self, ray_samples: RaySamples, clip_scales) -> Dict[LERFFieldHeadNames, TensorType]: + def get_outputs(self, ray_samples: RaySamples, clip_scales) -> Dict[LERFFieldHeadNames, Float[Tensor, "bs dim"]]: # random scales, one scale outputs = {} diff --git a/lerf/lerf_renderers.py b/lerf/lerf_renderers.py index 92f6fc4..12d3021 100644 --- a/lerf/lerf_renderers.py +++ b/lerf/lerf_renderers.py @@ -1,6 +1,7 @@ import torch -from torch import nn -from torchtyping import TensorType +from torch import nn, Tensor +from jaxtyping import Float + class CLIPRenderer(nn.Module): """Calculate CLIP embeddings along ray.""" @@ -8,9 +9,9 @@ class CLIPRenderer(nn.Module): @classmethod def forward( cls, - embeds: TensorType["bs":..., "num_samples", "num_classes"], - weights: TensorType["bs":..., "num_samples", 1], - ) -> TensorType["bs":..., "num_classes"]: + embeds: Float[Tensor, "bs num_samples num_classes"], + weights: Float[Tensor, "bs num_samples 1"], + ) -> Float[Tensor, "bs num_classes"]: """Calculate semantics along the ray.""" output = torch.sum(weights * embeds, dim=-2) output = output / torch.linalg.norm(output, dim=-1, keepdim=True) @@ -23,9 +24,9 @@ class MeanRenderer(nn.Module): @classmethod def forward( cls, - embeds: TensorType["bs":..., "num_samples", "num_classes"], - weights: TensorType["bs":..., "num_samples", 1], - ) -> TensorType["bs":..., "num_classes"]: + embeds: Float[Tensor, "bs num_samples num_classes"], + weights: Float[Tensor, "bs num_samples 1"], + ) -> Float[Tensor, "bs num_classes"]: """Calculate semantics along the ray.""" output = torch.sum(weights * embeds, dim=-2) - return output \ No newline at end of file + return output