diff --git a/jaxlie/_base.py b/jaxlie/_base.py index 973ca65..c273f2a 100644 --- a/jaxlie/_base.py +++ b/jaxlie/_base.py @@ -3,13 +3,10 @@ import jax import numpy as onp -from typing_extensions import final, override +from typing_extensions import Self, final, override from . import hints -GroupType = TypeVar("GroupType", bound="MatrixLieGroup") -SEGroupType = TypeVar("SEGroupType", bound="SEBase") - class MatrixLieGroup(abc.ABC): """Interface definition for matrix Lie groups.""" @@ -44,14 +41,12 @@ def __init__( # Shared implementations. @overload - def __matmul__(self: GroupType, other: GroupType) -> GroupType: ... + def __matmul__(self, other: Self) -> Self: ... @overload def __matmul__(self, other: hints.Array) -> jax.Array: ... - def __matmul__( - self: GroupType, other: Union[GroupType, hints.Array] - ) -> Union[GroupType, jax.Array]: + def __matmul__(self, other: Union[Self, hints.Array]) -> Union[Self, jax.Array]: """Overload for the `@` operator. Switches between the group action (`.apply()`) and multiplication @@ -69,7 +64,7 @@ def __matmul__( @classmethod @abc.abstractmethod - def identity(cls: Type[GroupType]) -> GroupType: + def identity(cls) -> Self: """Returns identity element. Returns: @@ -78,7 +73,7 @@ def identity(cls: Type[GroupType]) -> GroupType: @classmethod @abc.abstractmethod - def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType: + def from_matrix(cls, matrix: hints.Array) -> Self: """Get group member from matrix representation. Args: @@ -112,7 +107,7 @@ def apply(self, target: hints.Array) -> jax.Array: """ @abc.abstractmethod - def multiply(self: GroupType, other: GroupType) -> GroupType: + def multiply(self, other: Self) -> Self: """Composes this transformation with another. Returns: @@ -121,7 +116,7 @@ def multiply(self: GroupType, other: GroupType) -> GroupType: @classmethod @abc.abstractmethod - def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType: + def exp(cls, tangent: hints.Array) -> Self: """Computes `expm(wedge(tangent))`. Args: @@ -157,7 +152,7 @@ def adjoint(self) -> jax.Array: """ @abc.abstractmethod - def inverse(self: GroupType) -> GroupType: + def inverse(self) -> Self: """Computes the inverse of our transform. Returns: @@ -165,16 +160,16 @@ def inverse(self: GroupType) -> GroupType: """ @abc.abstractmethod - def normalize(self: GroupType) -> GroupType: + def normalize(self) -> Self: """Normalize/projects values and returns. Returns: - GroupType: Normalized group member. + Normalized group member. """ @classmethod @abc.abstractmethod - def sample_uniform(cls: Type[GroupType], key: jax.Array) -> GroupType: + def sample_uniform(cls, key: jax.Array) -> Self: """Draw a uniform sample from the group. Translations (if applicable) are in the range [-1, 1]. @@ -213,10 +208,10 @@ class SEBase(Generic[ContainedSOType], MatrixLieGroup): @classmethod @abc.abstractmethod def from_rotation_and_translation( - cls: Type[SEGroupType], + cls, rotation: ContainedSOType, translation: hints.Array, - ) -> SEGroupType: + ) -> Self: """Construct a rigid transform from a rotation and a translation. Args: @@ -229,7 +224,7 @@ def from_rotation_and_translation( @final @classmethod - def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType: + def from_rotation(cls, rotation: ContainedSOType) -> Self: return cls.from_rotation_and_translation( rotation=rotation, translation=onp.zeros(cls.space_dim, dtype=rotation.parameters().dtype), @@ -252,7 +247,7 @@ def apply(self, target: hints.Array) -> jax.Array: @final @override - def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType: + def multiply(self, other: Self) -> Self: return type(self).from_rotation_and_translation( rotation=self.rotation() @ other.rotation(), translation=(self.rotation() @ other.translation()) + self.translation(), @@ -260,7 +255,7 @@ def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType: @final @override - def inverse(self: SEGroupType) -> SEGroupType: + def inverse(self) -> Self: R_inv = self.rotation().inverse() return type(self).from_rotation_and_translation( rotation=R_inv, @@ -269,7 +264,7 @@ def inverse(self: SEGroupType) -> SEGroupType: @final @override - def normalize(self: SEGroupType) -> SEGroupType: + def normalize(self) -> Self: return type(self).from_rotation_and_translation( rotation=self.rotation().normalize(), translation=self.translation(), diff --git a/jaxlie/_se2.py b/jaxlie/_se2.py index b6c412e..e7a4ce9 100644 --- a/jaxlie/_se2.py +++ b/jaxlie/_se2.py @@ -51,9 +51,10 @@ def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2 # SE-specific. - @staticmethod + @classmethod @override def from_rotation_and_translation( + cls, rotation: SO2, translation: hints.Array, ) -> "SE2": @@ -72,14 +73,14 @@ def translation(self) -> jax.Array: # Factory. - @staticmethod + @classmethod @override - def identity() -> "SE2": + def identity(cls) -> "SE2": return SE2(unit_complex_xy=jnp.array([1.0, 0.0, 0.0, 0.0])) - @staticmethod + @classmethod @override - def from_matrix(matrix: hints.Array) -> "SE2": + def from_matrix(cls, matrix: hints.Array) -> "SE2": assert matrix.shape == (3, 3) # Currently assumes bottom row is [0, 0, 1]. return SE2.from_rotation_and_translation( @@ -106,9 +107,9 @@ def as_matrix(self) -> jax.Array: # Operations. - @staticmethod + @classmethod @override - def exp(tangent: hints.Array) -> "SE2": + def exp(cls, tangent: hints.Array) -> "SE2": # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558 # Also see: @@ -210,9 +211,9 @@ def adjoint(self: "SE2") -> jax.Array: ] ) - @staticmethod + @classmethod @override - def sample_uniform(key: jax.Array) -> "SE2": + def sample_uniform(cls, key: jax.Array) -> "SE2": key0, key1 = jax.random.split(key) return SE2.from_rotation_and_translation( rotation=SO2.sample_uniform(key0), diff --git a/jaxlie/_se3.py b/jaxlie/_se3.py index 6c4636f..0de26df 100644 --- a/jaxlie/_se3.py +++ b/jaxlie/_se3.py @@ -56,9 +56,10 @@ def __repr__(self) -> str: # SE-specific. - @staticmethod + @classmethod @override def from_rotation_and_translation( + cls, rotation: SO3, translation: hints.Array, ) -> SE3: @@ -75,14 +76,14 @@ def translation(self) -> jax.Array: # Factory. - @staticmethod + @classmethod @override - def identity() -> SE3: + def identity(cls) -> SE3: return SE3(wxyz_xyz=jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])) - @staticmethod + @classmethod @override - def from_matrix(matrix: hints.Array) -> SE3: + def from_matrix(cls, matrix: hints.Array) -> SE3: assert matrix.shape == (4, 4) # Currently assumes bottom row is [0, 0, 0, 1]. return SE3.from_rotation_and_translation( @@ -108,9 +109,9 @@ def parameters(self) -> jax.Array: # Operations. - @staticmethod + @classmethod @override - def exp(tangent: hints.Array) -> SE3: + def exp(cls, tangent: hints.Array) -> SE3: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761 @@ -202,9 +203,9 @@ def adjoint(self) -> jax.Array: ] ) - @staticmethod + @classmethod @override - def sample_uniform(key: jax.Array) -> SE3: + def sample_uniform(cls, key: jax.Array) -> SE3: key0, key1 = jax.random.split(key) return SE3.from_rotation_and_translation( rotation=SO3.sample_uniform(key0), diff --git a/jaxlie/_so2.py b/jaxlie/_so2.py index 31678db..49c5f71 100644 --- a/jaxlie/_so2.py +++ b/jaxlie/_so2.py @@ -50,14 +50,14 @@ def as_radians(self) -> jax.Array: # Factory. - @staticmethod + @classmethod @override - def identity() -> SO2: + def identity(cls) -> SO2: return SO2(unit_complex=jnp.array([1.0, 0.0])) - @staticmethod + @classmethod @override - def from_matrix(matrix: hints.Array) -> SO2: + def from_matrix(cls, matrix: hints.Array) -> SO2: assert matrix.shape == (2, 2) return SO2(unit_complex=jnp.asarray(matrix[:, 0])) @@ -92,9 +92,9 @@ def apply(self, target: hints.Array) -> jax.Array: def multiply(self, other: SO2) -> SO2: return SO2(unit_complex=self.as_matrix() @ other.unit_complex) - @staticmethod + @classmethod @override - def exp(tangent: hints.Array) -> SO2: + def exp(cls, tangent: hints.Array) -> SO2: (theta,) = tangent cos = jnp.cos(theta) sin = jnp.sin(theta) @@ -118,9 +118,9 @@ def inverse(self) -> SO2: def normalize(self) -> SO2: return SO2(unit_complex=self.unit_complex / jnp.linalg.norm(self.unit_complex)) - @staticmethod + @classmethod @override - def sample_uniform(key: jax.Array) -> SO2: + def sample_uniform(cls, key: jax.Array) -> SO2: return SO2.from_radians( jax.random.uniform(key=key, minval=0.0, maxval=2.0 * jnp.pi) ) diff --git a/jaxlie/_so3.py b/jaxlie/_so3.py index de6e569..a7dca9b 100644 --- a/jaxlie/_so3.py +++ b/jaxlie/_so3.py @@ -160,14 +160,14 @@ def compute_yaw_radians(self) -> jax.Array: # Factory. - @staticmethod + @classmethod @override - def identity() -> SO3: + def identity(cls) -> SO3: return SO3(wxyz=jnp.array([1.0, 0.0, 0.0, 0.0])) - @staticmethod + @classmethod @override - def from_matrix(matrix: hints.Array) -> SO3: + def from_matrix(cls, matrix: hints.Array) -> SO3: assert matrix.shape == (3, 3) # Modified from: @@ -308,9 +308,9 @@ def multiply(self, other: SO3) -> SO3: ) ) - @staticmethod + @classmethod @override - def exp(tangent: hints.Array) -> SO3: + def exp(cls, tangent: hints.Array) -> SO3: # Reference: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583 @@ -400,9 +400,9 @@ def inverse(self) -> SO3: def normalize(self) -> SO3: return SO3(wxyz=self.wxyz / jnp.linalg.norm(self.wxyz)) - @staticmethod + @classmethod @override - def sample_uniform(key: jax.Array) -> SO3: + def sample_uniform(cls, key: jax.Array) -> SO3: # Uniformly sample over S^3. # > Reference: http://planning.cs.uiuc.edu/node198.html u1, u2, u3 = jax.random.uniform( diff --git a/setup.py b/setup.py index 2b9587a..285b060 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ "jax>=0.3.18", # For jax.Array. "jax_dataclasses>=1.4.4", "numpy", + "typing_extensions>=4.0.0", "tyro", # Only used in examples. ], extras_require={