Skip to content

Commit

Permalink
Remove deprecated KeyArray type (fixes #13)
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Dec 11, 2023
1 parent 8d0c242 commit 99ae901
Show file tree
Hide file tree
Showing 7 changed files with 6 additions and 15 deletions.
2 changes: 1 addition & 1 deletion jaxlie/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def normalize(self: GroupType) -> GroupType:

@classmethod
@abc.abstractmethod
def sample_uniform(cls: Type[GroupType], key: hints.KeyArray) -> GroupType:
def sample_uniform(cls: Type[GroupType], key: jax.Array) -> GroupType:
"""Draw a uniform sample from the group. Translations (if applicable) are in the
range [-1, 1].
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def adjoint(self: "SE2") -> jax.Array:

@staticmethod
@override
def sample_uniform(key: hints.KeyArray) -> "SE2":
def sample_uniform(key: jax.Array) -> "SE2":
key0, key1 = jax.random.split(key)
return SE2.from_rotation_and_translation(
rotation=SO2.sample_uniform(key0),
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def adjoint(self) -> jax.Array:

@staticmethod
@override
def sample_uniform(key: hints.KeyArray) -> SE3:
def sample_uniform(key: jax.Array) -> SE3:
key0, key1 = jax.random.split(key)
return SE3.from_rotation_and_translation(
rotation=SO3.sample_uniform(key0),
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def normalize(self) -> SO2:

@staticmethod
@override
def sample_uniform(key: hints.KeyArray) -> SO2:
def sample_uniform(key: jax.Array) -> SO2:
return SO2.from_radians(
jax.random.uniform(key=key, minval=0.0, maxval=2.0 * jnp.pi)
)
2 changes: 1 addition & 1 deletion jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def normalize(self) -> SO3:

@staticmethod
@override
def sample_uniform(key: hints.KeyArray) -> SO3:
def sample_uniform(key: jax.Array) -> SO3:
# Uniformly sample over S^3.
# > Reference: http://planning.cs.uiuc.edu/node198.html
u1, u2, u3 = jax.random.uniform(
Expand Down
9 changes: 0 additions & 9 deletions jaxlie/hints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,8 @@ class RollPitchYaw(NamedTuple):
yaw: Scalar


try:
# This is only exposed in `jax>=0.2.21`.
from jax.random import KeyArray
except ImportError:
KeyArray = Any # type: ignore
"""Backward-compatible alias for `jax.random.KeyArray`."""


__all__ = [
"Array",
"Scalar",
"RollPitchYaw",
"KeyArray",
]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="jaxlie",
version="1.3.3",
version="1.3.4",
description="Matrix Lie groups in JAX",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 99ae901

Please sign in to comment.