diff --git a/jaxlie/_base.py b/jaxlie/_base.py index 6e7aca3..4453a92 100644 --- a/jaxlie/_base.py +++ b/jaxlie/_base.py @@ -176,7 +176,7 @@ def normalize(self: GroupType) -> GroupType: @classmethod @abc.abstractmethod - def sample_uniform(cls: Type[GroupType], key: hints.KeyArray) -> GroupType: + def sample_uniform(cls: Type[GroupType], key: jax.Array) -> GroupType: """Draw a uniform sample from the group. Translations (if applicable) are in the range [-1, 1]. diff --git a/jaxlie/_se2.py b/jaxlie/_se2.py index 42dc84c..b6c412e 100644 --- a/jaxlie/_se2.py +++ b/jaxlie/_se2.py @@ -212,7 +212,7 @@ def adjoint(self: "SE2") -> jax.Array: @staticmethod @override - def sample_uniform(key: hints.KeyArray) -> "SE2": + def sample_uniform(key: jax.Array) -> "SE2": key0, key1 = jax.random.split(key) return SE2.from_rotation_and_translation( rotation=SO2.sample_uniform(key0), diff --git a/jaxlie/_se3.py b/jaxlie/_se3.py index 83246b5..6c4636f 100644 --- a/jaxlie/_se3.py +++ b/jaxlie/_se3.py @@ -204,7 +204,7 @@ def adjoint(self) -> jax.Array: @staticmethod @override - def sample_uniform(key: hints.KeyArray) -> SE3: + def sample_uniform(key: jax.Array) -> SE3: key0, key1 = jax.random.split(key) return SE3.from_rotation_and_translation( rotation=SO3.sample_uniform(key0), diff --git a/jaxlie/_so2.py b/jaxlie/_so2.py index c83a329..31678db 100644 --- a/jaxlie/_so2.py +++ b/jaxlie/_so2.py @@ -120,7 +120,7 @@ def normalize(self) -> SO2: @staticmethod @override - def sample_uniform(key: hints.KeyArray) -> SO2: + def sample_uniform(key: jax.Array) -> SO2: return SO2.from_radians( jax.random.uniform(key=key, minval=0.0, maxval=2.0 * jnp.pi) ) diff --git a/jaxlie/_so3.py b/jaxlie/_so3.py index 36e826f..de6e569 100644 --- a/jaxlie/_so3.py +++ b/jaxlie/_so3.py @@ -402,7 +402,7 @@ def normalize(self) -> SO3: @staticmethod @override - def sample_uniform(key: hints.KeyArray) -> SO3: + def sample_uniform(key: jax.Array) -> SO3: # Uniformly sample over S^3. # > Reference: http://planning.cs.uiuc.edu/node198.html u1, u2, u3 = jax.random.uniform( diff --git a/jaxlie/hints/__init__.py b/jaxlie/hints/__init__.py index cb38980..566a762 100644 --- a/jaxlie/hints/__init__.py +++ b/jaxlie/hints/__init__.py @@ -20,17 +20,8 @@ class RollPitchYaw(NamedTuple): yaw: Scalar -try: - # This is only exposed in `jax>=0.2.21`. - from jax.random import KeyArray -except ImportError: - KeyArray = Any # type: ignore - """Backward-compatible alias for `jax.random.KeyArray`.""" - - __all__ = [ "Array", "Scalar", "RollPitchYaw", - "KeyArray", ] diff --git a/setup.py b/setup.py index b87d756..6faa80f 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="jaxlie", - version="1.3.3", + version="1.3.4", description="Matrix Lie groups in JAX", long_description=long_description, long_description_content_type="text/markdown",