Skip to content

Commit

Permalink
Improve types (#18)
Browse files Browse the repository at this point in the history
* Improve typing

* Add typing_extensions to dependencies
  • Loading branch information
brentyi authored May 5, 2024
1 parent 6cf00ce commit 7374df2
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 56 deletions.
39 changes: 17 additions & 22 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@

import jax
import numpy as onp
from typing_extensions import final, override
from typing_extensions import Self, final, override

from . import hints

GroupType = TypeVar("GroupType", bound="MatrixLieGroup")
SEGroupType = TypeVar("SEGroupType", bound="SEBase")


class MatrixLieGroup(abc.ABC):
"""Interface definition for matrix Lie groups."""
Expand Down Expand Up @@ -44,14 +41,12 @@ def __init__(
# Shared implementations.

@overload
def __matmul__(self: GroupType, other: GroupType) -> GroupType: ...
def __matmul__(self, other: Self) -> Self: ...

@overload
def __matmul__(self, other: hints.Array) -> jax.Array: ...

def __matmul__(
self: GroupType, other: Union[GroupType, hints.Array]
) -> Union[GroupType, jax.Array]:
def __matmul__(self, other: Union[Self, hints.Array]) -> Union[Self, jax.Array]:
"""Overload for the `@` operator.
Switches between the group action (`.apply()`) and multiplication
Expand All @@ -69,7 +64,7 @@ def __matmul__(

@classmethod
@abc.abstractmethod
def identity(cls: Type[GroupType]) -> GroupType:
def identity(cls) -> Self:
"""Returns identity element.
Returns:
Expand All @@ -78,7 +73,7 @@ def identity(cls: Type[GroupType]) -> GroupType:

@classmethod
@abc.abstractmethod
def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType:
def from_matrix(cls, matrix: hints.Array) -> Self:
"""Get group member from matrix representation.
Args:
Expand Down Expand Up @@ -112,7 +107,7 @@ def apply(self, target: hints.Array) -> jax.Array:
"""

@abc.abstractmethod
def multiply(self: GroupType, other: GroupType) -> GroupType:
def multiply(self, other: Self) -> Self:
"""Composes this transformation with another.
Returns:
Expand All @@ -121,7 +116,7 @@ def multiply(self: GroupType, other: GroupType) -> GroupType:

@classmethod
@abc.abstractmethod
def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType:
def exp(cls, tangent: hints.Array) -> Self:
"""Computes `expm(wedge(tangent))`.
Args:
Expand Down Expand Up @@ -157,24 +152,24 @@ def adjoint(self) -> jax.Array:
"""

@abc.abstractmethod
def inverse(self: GroupType) -> GroupType:
def inverse(self) -> Self:
"""Computes the inverse of our transform.
Returns:
Output.
"""

@abc.abstractmethod
def normalize(self: GroupType) -> GroupType:
def normalize(self) -> Self:
"""Normalize/projects values and returns.
Returns:
GroupType: Normalized group member.
Normalized group member.
"""

@classmethod
@abc.abstractmethod
def sample_uniform(cls: Type[GroupType], key: jax.Array) -> GroupType:
def sample_uniform(cls, key: jax.Array) -> Self:
"""Draw a uniform sample from the group. Translations (if applicable) are in the
range [-1, 1].
Expand Down Expand Up @@ -213,10 +208,10 @@ class SEBase(Generic[ContainedSOType], MatrixLieGroup):
@classmethod
@abc.abstractmethod
def from_rotation_and_translation(
cls: Type[SEGroupType],
cls,
rotation: ContainedSOType,
translation: hints.Array,
) -> SEGroupType:
) -> Self:
"""Construct a rigid transform from a rotation and a translation.
Args:
Expand All @@ -229,7 +224,7 @@ def from_rotation_and_translation(

@final
@classmethod
def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType:
def from_rotation(cls, rotation: ContainedSOType) -> Self:
return cls.from_rotation_and_translation(
rotation=rotation,
translation=onp.zeros(cls.space_dim, dtype=rotation.parameters().dtype),
Expand All @@ -252,15 +247,15 @@ def apply(self, target: hints.Array) -> jax.Array:

@final
@override
def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType:
def multiply(self, other: Self) -> Self:
return type(self).from_rotation_and_translation(
rotation=self.rotation() @ other.rotation(),
translation=(self.rotation() @ other.translation()) + self.translation(),
)

@final
@override
def inverse(self: SEGroupType) -> SEGroupType:
def inverse(self) -> Self:
R_inv = self.rotation().inverse()
return type(self).from_rotation_and_translation(
rotation=R_inv,
Expand All @@ -269,7 +264,7 @@ def inverse(self: SEGroupType) -> SEGroupType:

@final
@override
def normalize(self: SEGroupType) -> SEGroupType:
def normalize(self) -> Self:
return type(self).from_rotation_and_translation(
rotation=self.rotation().normalize(),
translation=self.translation(),
Expand Down
19 changes: 10 additions & 9 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2

# SE-specific.

@staticmethod
@classmethod
@override
def from_rotation_and_translation(
cls,
rotation: SO2,
translation: hints.Array,
) -> "SE2":
Expand All @@ -72,14 +73,14 @@ def translation(self) -> jax.Array:

# Factory.

@staticmethod
@classmethod
@override
def identity() -> "SE2":
def identity(cls) -> "SE2":
return SE2(unit_complex_xy=jnp.array([1.0, 0.0, 0.0, 0.0]))

@staticmethod
@classmethod
@override
def from_matrix(matrix: hints.Array) -> "SE2":
def from_matrix(cls, matrix: hints.Array) -> "SE2":
assert matrix.shape == (3, 3)
# Currently assumes bottom row is [0, 0, 1].
return SE2.from_rotation_and_translation(
Expand All @@ -106,9 +107,9 @@ def as_matrix(self) -> jax.Array:

# Operations.

@staticmethod
@classmethod
@override
def exp(tangent: hints.Array) -> "SE2":
def exp(cls, tangent: hints.Array) -> "SE2":
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se2.hpp#L558
# Also see:
Expand Down Expand Up @@ -210,9 +211,9 @@ def adjoint(self: "SE2") -> jax.Array:
]
)

@staticmethod
@classmethod
@override
def sample_uniform(key: jax.Array) -> "SE2":
def sample_uniform(cls, key: jax.Array) -> "SE2":
key0, key1 = jax.random.split(key)
return SE2.from_rotation_and_translation(
rotation=SO2.sample_uniform(key0),
Expand Down
19 changes: 10 additions & 9 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ def __repr__(self) -> str:

# SE-specific.

@staticmethod
@classmethod
@override
def from_rotation_and_translation(
cls,
rotation: SO3,
translation: hints.Array,
) -> SE3:
Expand All @@ -75,14 +76,14 @@ def translation(self) -> jax.Array:

# Factory.

@staticmethod
@classmethod
@override
def identity() -> SE3:
def identity(cls) -> SE3:
return SE3(wxyz_xyz=jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]))

@staticmethod
@classmethod
@override
def from_matrix(matrix: hints.Array) -> SE3:
def from_matrix(cls, matrix: hints.Array) -> SE3:
assert matrix.shape == (4, 4)
# Currently assumes bottom row is [0, 0, 0, 1].
return SE3.from_rotation_and_translation(
Expand All @@ -108,9 +109,9 @@ def parameters(self) -> jax.Array:

# Operations.

@staticmethod
@classmethod
@override
def exp(tangent: hints.Array) -> SE3:
def exp(cls, tangent: hints.Array) -> SE3:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761

Expand Down Expand Up @@ -202,9 +203,9 @@ def adjoint(self) -> jax.Array:
]
)

@staticmethod
@classmethod
@override
def sample_uniform(key: jax.Array) -> SE3:
def sample_uniform(cls, key: jax.Array) -> SE3:
key0, key1 = jax.random.split(key)
return SE3.from_rotation_and_translation(
rotation=SO3.sample_uniform(key0),
Expand Down
16 changes: 8 additions & 8 deletions jaxlie/_so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def as_radians(self) -> jax.Array:

# Factory.

@staticmethod
@classmethod
@override
def identity() -> SO2:
def identity(cls) -> SO2:
return SO2(unit_complex=jnp.array([1.0, 0.0]))

@staticmethod
@classmethod
@override
def from_matrix(matrix: hints.Array) -> SO2:
def from_matrix(cls, matrix: hints.Array) -> SO2:
assert matrix.shape == (2, 2)
return SO2(unit_complex=jnp.asarray(matrix[:, 0]))

Expand Down Expand Up @@ -92,9 +92,9 @@ def apply(self, target: hints.Array) -> jax.Array:
def multiply(self, other: SO2) -> SO2:
return SO2(unit_complex=self.as_matrix() @ other.unit_complex)

@staticmethod
@classmethod
@override
def exp(tangent: hints.Array) -> SO2:
def exp(cls, tangent: hints.Array) -> SO2:
(theta,) = tangent
cos = jnp.cos(theta)
sin = jnp.sin(theta)
Expand All @@ -118,9 +118,9 @@ def inverse(self) -> SO2:
def normalize(self) -> SO2:
return SO2(unit_complex=self.unit_complex / jnp.linalg.norm(self.unit_complex))

@staticmethod
@classmethod
@override
def sample_uniform(key: jax.Array) -> SO2:
def sample_uniform(cls, key: jax.Array) -> SO2:
return SO2.from_radians(
jax.random.uniform(key=key, minval=0.0, maxval=2.0 * jnp.pi)
)
16 changes: 8 additions & 8 deletions jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ def compute_yaw_radians(self) -> jax.Array:

# Factory.

@staticmethod
@classmethod
@override
def identity() -> SO3:
def identity(cls) -> SO3:
return SO3(wxyz=jnp.array([1.0, 0.0, 0.0, 0.0]))

@staticmethod
@classmethod
@override
def from_matrix(matrix: hints.Array) -> SO3:
def from_matrix(cls, matrix: hints.Array) -> SO3:
assert matrix.shape == (3, 3)

# Modified from:
Expand Down Expand Up @@ -308,9 +308,9 @@ def multiply(self, other: SO3) -> SO3:
)
)

@staticmethod
@classmethod
@override
def exp(tangent: hints.Array) -> SO3:
def exp(cls, tangent: hints.Array) -> SO3:
# Reference:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583

Expand Down Expand Up @@ -400,9 +400,9 @@ def inverse(self) -> SO3:
def normalize(self) -> SO3:
return SO3(wxyz=self.wxyz / jnp.linalg.norm(self.wxyz))

@staticmethod
@classmethod
@override
def sample_uniform(key: jax.Array) -> SO3:
def sample_uniform(cls, key: jax.Array) -> SO3:
# Uniformly sample over S^3.
# > Reference: http://planning.cs.uiuc.edu/node198.html
u1, u2, u3 = jax.random.uniform(
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"jax>=0.3.18", # For jax.Array.
"jax_dataclasses>=1.4.4",
"numpy",
"typing_extensions>=4.0.0",
"tyro", # Only used in examples.
],
extras_require={
Expand Down

0 comments on commit 7374df2

Please sign in to comment.