Skip to content

Commit

Permalink
Merge pull request #42 from kerrj/justin/jaxtyping
Browse files Browse the repository at this point in the history
Switch to jaxtyping to be consistent with nerfstudio 0.3.1+
  • Loading branch information
kerrj authored Aug 15, 2023
2 parents 8896088 + a6f3675 commit ef934c1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
9 changes: 5 additions & 4 deletions lerf/lerf_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import numpy as np
import torch
from lerf.lerf_fieldheadnames import LERFFieldHeadNames
from torch import nn
from torch import nn, Tensor
from torch.nn.parameter import Parameter
from torchtyping import TensorType
from jaxtyping import Float

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.data.scene_box import SceneBox
Expand Down Expand Up @@ -39,7 +39,8 @@ def __init__(
[
LERFField._get_encoding(
grid_resolutions[i][0], grid_resolutions[i][1], grid_layers[i], indim=3, hash_size=grid_sizes[i]
) for i in range(len(grid_layers))
)
for i in range(len(grid_layers))
]
)
tot_out_dims = sum([e.n_output_dims for e in self.clip_encs])
Expand Down Expand Up @@ -84,7 +85,7 @@ def _get_encoding(start_res, end_res, levels, indim=3, hash_size=19):
)
return enc

def get_outputs(self, ray_samples: RaySamples, clip_scales) -> Dict[LERFFieldHeadNames, TensorType]:
def get_outputs(self, ray_samples: RaySamples, clip_scales) -> Dict[LERFFieldHeadNames, Float[Tensor, "bs dim"]]:
# random scales, one scale
outputs = {}

Expand Down
19 changes: 10 additions & 9 deletions lerf/lerf_renderers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import torch
from torch import nn
from torchtyping import TensorType
from torch import nn, Tensor
from jaxtyping import Float


class CLIPRenderer(nn.Module):
"""Calculate CLIP embeddings along ray."""

@classmethod
def forward(
cls,
embeds: TensorType["bs":..., "num_samples", "num_classes"],
weights: TensorType["bs":..., "num_samples", 1],
) -> TensorType["bs":..., "num_classes"]:
embeds: Float[Tensor, "bs num_samples num_classes"],
weights: Float[Tensor, "bs num_samples 1"],
) -> Float[Tensor, "bs num_classes"]:
"""Calculate semantics along the ray."""
output = torch.sum(weights * embeds, dim=-2)
output = output / torch.linalg.norm(output, dim=-1, keepdim=True)
Expand All @@ -23,9 +24,9 @@ class MeanRenderer(nn.Module):
@classmethod
def forward(
cls,
embeds: TensorType["bs":..., "num_samples", "num_classes"],
weights: TensorType["bs":..., "num_samples", 1],
) -> TensorType["bs":..., "num_classes"]:
embeds: Float[Tensor, "bs num_samples num_classes"],
weights: Float[Tensor, "bs num_samples 1"],
) -> Float[Tensor, "bs num_classes"]:
"""Calculate semantics along the ray."""
output = torch.sum(weights * embeds, dim=-2)
return output
return output

0 comments on commit ef934c1

Please sign in to comment.