From 2c7cf893c3338d2fc6ce07198f85c646337cef78 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 5 May 2024 00:19:09 -0700 Subject: [PATCH] Add batch axes to tests + bug fixes --- jaxlie/_se2.py | 48 ++++++++++++++++------ jaxlie/_se3.py | 24 ++++++----- jaxlie/_so2.py | 6 +-- jaxlie/_so3.py | 35 +++++++++------- jaxlie/manifold/_deltas.py | 65 ++++++++++++++++++----------- jaxlie/utils/_utils.py | 16 ++++---- tests/test_autodiff.py | 55 ++++++++++++++++++------- tests/test_group_axioms.py | 55 ++++++++++++++----------- tests/test_manifold.py | 34 ++++++++++------ tests/test_operations.py | 81 ++++++++++++++++++++++++------------- tests/test_serialization.py | 20 +++++---- tests/utils.py | 48 +++++++++++++++------- 12 files changed, 315 insertions(+), 172 deletions(-) diff --git a/jaxlie/_se2.py b/jaxlie/_se2.py index 6b0baa5..12f2fb4 100644 --- a/jaxlie/_se2.py +++ b/jaxlie/_se2.py @@ -73,7 +73,7 @@ def translation(self) -> jax.Array: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> "SE2": + def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> "SE2": return SE2( unit_complex_xy=jnp.broadcast_to( jnp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4) @@ -99,9 +99,21 @@ def parameters(self) -> jax.Array: @override def as_matrix(self) -> jax.Array: cos, sin, x, y = jnp.moveaxis(self.unit_complex_xy, -1, 0) - return jnp.stack([cos, -sin, x, sin, cos, y, 0.0, 0.0, 1.0], axis=-1).reshape( - (*self.get_batch_axes(), 3, 3) - ) + out = jnp.stack( + [ + cos, + -sin, + x, + sin, + cos, + y, + jnp.zeros_like(x), + jnp.zeros_like(x), + jnp.ones_like(x), + ], + axis=-1, + ).reshape((*self.get_batch_axes(), 3, 3)) + return out # Operations. @@ -205,24 +217,34 @@ def log(self) -> jax.Array: [ jnp.einsum("...ij,...j->...i", V_inv, self.translation()), theta[..., None], - ] + ], + axis=-1, ) return tangent @override def adjoint(self: "SE2") -> jax.Array: - cos, sin, x, y = self.unit_complex_xy - return jnp.array( + cos, sin, x, y = jnp.moveaxis(self.unit_complex_xy, -1, 0) + return jnp.stack( [ - [cos, -sin, y], - [sin, cos, -x], - [0.0, 0.0, 1.0], - ] - ) + cos, + -sin, + y, + sin, + cos, + -x, + jnp.zeros_like(x), + jnp.zeros_like(x), + jnp.ones_like(x), + ], + axis=-1, + ).reshape((*self.get_batch_axes(), 3, 3)) @classmethod @override - def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> "SE2": + def sample_uniform( + cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = () + ) -> "SE2": key0, key1 = jax.random.split(key) return SE2.from_rotation_and_translation( rotation=SO2.sample_uniform(key0, batch_axes=batch_axes), diff --git a/jaxlie/_se3.py b/jaxlie/_se3.py index e5c2cda..3bc23f6 100644 --- a/jaxlie/_se3.py +++ b/jaxlie/_se3.py @@ -16,8 +16,9 @@ def _skew(omega: hints.Array) -> jax.Array: """Returns the skew-symmetric form of a length-3 vector.""" wx, wy, wz = jnp.moveaxis(omega, -1, 0) + zeros = jnp.zeros_like(wx) return jnp.stack( - [0.0, -wz, wy, wz, 0.0, -wx, -wy, wx, 0.0], + [zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros], axis=-1, ).reshape((*omega.shape[:-1], 3, 3)) @@ -71,7 +72,7 @@ def translation(self) -> jax.Array: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SE3: + def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SE3: return SE3( wxyz_xyz=jnp.broadcast_to( jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (*batch_axes, 7) @@ -92,13 +93,16 @@ def from_matrix(cls, matrix: hints.Array) -> SE3: @override def as_matrix(self) -> jax.Array: - return ( - jnp.eye(4) - .at[..., :3, :3] + out = jnp.zeros((*self.get_batch_axes(), 4, 4)) + out = ( + out.at[..., :3, :3] .set(self.rotation().as_matrix()) .at[..., :3, 3] .set(self.translation()) + .at[..., 3, 3] + .set(1.0) ) + return out @override def parameters(self) -> jax.Array: @@ -117,7 +121,7 @@ def exp(cls, tangent: hints.Array) -> SE3: rotation = SO3.exp(tangent[..., 3:]) - theta_squared = jnp.sum(jnp.square(tangent[3:]), axis=-1) + theta_squared = jnp.sum(jnp.square(tangent[..., 3:]), axis=-1) use_taylor = theta_squared < get_epsilon(theta_squared.dtype) # Shim to avoid NaNs in jnp.where branches, which cause failures for @@ -191,7 +195,7 @@ def log(self) -> jax.Array: ), ) return jnp.concatenate( - [jnp.einsum("...ij,...j->...i", V_inv, self.translation()), omega] + [jnp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1 ) @override @@ -207,12 +211,14 @@ def adjoint(self) -> jax.Array: [jnp.zeros((*self.get_batch_axes(), 3, 3)), R], axis=-1 ), ], - axis=-1, + axis=-2, ) @classmethod @override - def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> SE3: + def sample_uniform( + cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = () + ) -> SE3: key0, key1 = jax.random.split(key) return SE3.from_rotation_and_translation( rotation=SO3.sample_uniform(key0, batch_axes=batch_axes), diff --git a/jaxlie/_so2.py b/jaxlie/_so2.py index 2adcad7..17a0d96 100644 --- a/jaxlie/_so2.py +++ b/jaxlie/_so2.py @@ -50,7 +50,7 @@ def as_radians(self) -> jax.Array: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2: + def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO2: return SO2( unit_complex=jnp.stack( [jnp.ones(batch_axes), jnp.zeros(batch_axes)], axis=-1 @@ -60,7 +60,7 @@ def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2: @classmethod @override def from_matrix(cls, matrix: hints.Array) -> SO2: - assert matrix.shape == (2, 2) + assert matrix.shape[-2:] == (2, 2) return SO2(unit_complex=jnp.asarray(matrix[..., :, 0])) # Accessors. @@ -130,7 +130,7 @@ def normalize(self) -> SO2: @classmethod @override - def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> SO2: + def sample_uniform(cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO2: out = SO2.from_radians( jax.random.uniform( key=key, shape=batch_axes, minval=0.0, maxval=2.0 * jnp.pi diff --git a/jaxlie/_so3.py b/jaxlie/_so3.py index 2ada174..c3a2d75 100644 --- a/jaxlie/_so3.py +++ b/jaxlie/_so3.py @@ -103,17 +103,17 @@ def from_quaternion_xyzw(xyzw: hints.Array) -> SO3: constructor. Args: - xyzw: xyzw quaternion. Shape should be (4,). + xyzw: xyzw quaternion. Shape should be (*, 4). Returns: Output. """ - assert xyzw.shape == (4,) - return SO3(jnp.roll(xyzw, shift=1)) + assert xyzw.shape[-1:] == (4,) + return SO3(jnp.roll(xyzw, axis=-1, shift=1)) def as_quaternion_xyzw(self) -> jax.Array: """Grab parameters as xyzw quaternion.""" - return jnp.roll(self.wxyz, shift=-1) + return jnp.roll(self.wxyz, axis=-1, shift=-1) def as_rpy_radians(self) -> hints.RollPitchYaw: """Computes roll, pitch, and yaw angles. Uses the ZYX mobile robot convention. @@ -161,7 +161,7 @@ def compute_yaw_radians(self) -> jax.Array: @classmethod @override - def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO3: + def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO3: return SO3( wxyz=jnp.broadcast_to(jnp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)) ) @@ -363,7 +363,8 @@ def exp(cls, tangent: hints.Array) -> SO3: [ real_factor[..., None], imaginary_factor[..., None] * tangent, - ] + ], + axis=-1, ) ) @@ -373,7 +374,7 @@ def log(self) -> jax.Array: # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247 w = self.wxyz[..., 0] - norm_sq = self.wxyz[..., 1:] @ self.wxyz[..., 1:] + norm_sq = jnp.sum(jnp.square(self.wxyz[..., 1:]), axis=-1) use_taylor = norm_sq < get_epsilon(norm_sq.dtype) # Shim to avoid NaNs in jnp.where branches, which cause failures for @@ -400,7 +401,7 @@ def log(self) -> jax.Array: ), ) - return atan_factor * self.wxyz[1:] + return atan_factor * self.wxyz[..., 1:] @override def adjoint(self) -> jax.Array: @@ -417,14 +418,20 @@ def normalize(self) -> SO3: @classmethod @override - def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> SO3: + def sample_uniform( + cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = () + ) -> SO3: # Uniformly sample over S^3. # > Reference: http://planning.cs.uiuc.edu/node198.html - u1, u2, u3 = jax.random.uniform( - key=key, - shape=(3, *batch_axes), - minval=jnp.zeros(3), - maxval=jnp.array([1.0, 2.0 * jnp.pi, 2.0 * jnp.pi]), + u1, u2, u3 = jnp.moveaxis( + jax.random.uniform( + key=key, + shape=(*batch_axes, 3), + minval=jnp.zeros(3), + maxval=jnp.array([1.0, 2.0 * jnp.pi, 2.0 * jnp.pi]), + ), + -1, + 0, ) a = jnp.sqrt(1.0 - u1) b = jnp.sqrt(u1) diff --git a/jaxlie/manifold/_deltas.py b/jaxlie/manifold/_deltas.py index 7b8b6c9..3bc9a7d 100644 --- a/jaxlie/manifold/_deltas.py +++ b/jaxlie/manifold/_deltas.py @@ -49,14 +49,16 @@ def _rplus(transform: GroupType, delta: jax.Array) -> GroupType: def rplus( transform: GroupType, delta: hints.Array, -) -> GroupType: ... +) -> GroupType: + ... @overload def rplus( transform: PytreeType, delta: _tree_utils.TangentPytree, -) -> PytreeType: ... +) -> PytreeType: + ... # Using our typevars in the overloaded signature will cause errors. @@ -79,11 +81,13 @@ def _rminus(a: GroupType, b: GroupType) -> jax.Array: @overload -def rminus(a: GroupType, b: GroupType) -> jax.Array: ... +def rminus(a: GroupType, b: GroupType) -> jax.Array: + ... @overload -def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ... +def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: + ... # Using our typevars in the overloaded signature will cause errors. @@ -129,23 +133,23 @@ def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array: # Jacobian col indices: theta transform_so2 = cast(SO2, transform) - J = jnp.zeros((2, 1)) + J = jnp.zeros((*transform.get_batch_axes(), 2, 1)) - cos, sin = transform_so2.unit_complex - J = J.at[0].set(-sin).at[1].set(cos) + cos, sin = jnp.moveaxis(transform_so2.unit_complex, -1, 0) + J = J.at[..., 0].set(-sin).at[..., 1].set(cos) elif type(transform) is SE2: # Jacobian row indices: cos, sin, x, y # Jacobian col indices: vx, vy, omega transform_se2 = cast(SE2, transform) - J = jnp.zeros((4, 3)) + J = jnp.zeros((*transform.get_batch_axes(), 4, 3)) # Translation terms. - J = J.at[2:, :2].set(transform_se2.rotation().as_matrix()) + J = J.at[..., 2:, :2].set(transform_se2.rotation().as_matrix()) # Rotation terms. - J = J.at[:2, 2:3].set( + J = J.at[..., :2, 2:3].set( rplus_jacobian_parameters_wrt_delta(transform_se2.rotation()) ) @@ -155,18 +159,29 @@ def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array: transform_so3 = cast(SO3, transform) - w, x, y, z = transform_so3.wxyz - _unused_neg_w, neg_x, neg_y, neg_z = -transform_so3.wxyz + w, x, y, z = jnp.moveaxis(transform_so3.wxyz, -1, 0) + neg_x = -x + neg_y = -y + neg_z = -z J = ( - jnp.array( + jnp.stack( [ - [neg_x, neg_y, neg_z], - [w, neg_z, y], - [z, w, neg_x], - [neg_y, x, w], - ] - ) + neg_x, + neg_y, + neg_z, + w, + neg_z, + y, + z, + w, + neg_x, + neg_y, + x, + w, + ], + axis=-1, + ).reshape((*transform.get_batch_axes(), 4, 3)) / 2.0 ) @@ -175,18 +190,22 @@ def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array: # Jacobian col indices: vx, vy, vz, omega x, omega y, omega z transform_se3 = cast(SE3, transform) - J = jnp.zeros((7, 6)) + J = jnp.zeros((*transform.get_batch_axes(), 7, 6)) # Translation terms. - J = J.at[4:, :3].set(transform_se3.rotation().as_matrix()) + J = J.at[..., 4:, :3].set(transform_se3.rotation().as_matrix()) # Rotation terms. - J = J.at[:4, 3:6].set( + J = J.at[..., :4, 3:6].set( rplus_jacobian_parameters_wrt_delta(transform_se3.rotation()) ) else: assert False, f"Unsupported type: {type(transform)}" - assert J.shape == (transform.parameters_dim, transform.tangent_dim) + assert J.shape == ( + *transform.get_batch_axes(), + transform.parameters_dim, + transform.tangent_dim, + ) return J diff --git a/jaxlie/utils/_utils.py b/jaxlie/utils/_utils.py index e39bfdf..1f45b00 100644 --- a/jaxlie/utils/_utils.py +++ b/jaxlie/utils/_utils.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Callable, Type, TypeVar -import jax +import jax_dataclasses as jdc from jax import numpy as jnp if TYPE_CHECKING: @@ -45,13 +45,13 @@ def _wrap(cls: Type[T]) -> Type[T]: cls.space_dim = space_dim # JIT all methods. - for f in filter( - lambda f: not f.startswith("_") - and callable(getattr(cls, f)) - and f != "get_batch_axes", # Avoid returning tracers. - dir(cls), - ): - setattr(cls, f, jax.jit(getattr(cls, f))) + # for f in filter( + # lambda f: not f.startswith("_") + # and callable(getattr(cls, f)) + # and f != "get_batch_axes", # Avoid returning tracers. + # dir(cls), + # ): + # setattr(cls, f, jdc.jit(getattr(cls, f))) return cls diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index bce6f36..a7ed04f 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -1,14 +1,14 @@ """Compare forward- and reverse-mode Jacobians with a numerical Jacobian.""" from functools import lru_cache -from typing import Callable, Type, cast +from typing import Callable, Tuple, Type, cast import jax +import jaxlie import numpy as onp from jax import numpy as jnp -from utils import assert_arrays_close, general_group_test, jacnumerical -import jaxlie +from utils import assert_arrays_close, general_group_test, jacnumerical # We cache JITed Jacobians to improve runtime. cached_jacfwd = lru_cache(maxsize=None)( @@ -56,16 +56,18 @@ def func(x): @general_group_test -def test_exp_random(Group: Type[jaxlie.MatrixLieGroup]): +def test_exp_random(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check that exp Jacobians are consistent, with randomly sampled transforms.""" + del batch_axes # Not used for autodiff tests. generator = onp.random.randn(Group.tangent_dim) _assert_jacobians_close(Group=Group, f=_exp, primal=generator) @general_group_test -def test_exp_identity(Group: Type[jaxlie.MatrixLieGroup]): +def test_exp_identity(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check that exp Jacobians are consistent, with transforms close to the identity.""" + del batch_axes # Not used for autodiff tests. generator = onp.random.randn(Group.tangent_dim) * 1e-6 _assert_jacobians_close(Group=Group, f=_exp, primal=generator) @@ -76,14 +78,15 @@ def _log(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Array: @general_group_test -def test_log_random(Group: Type[jaxlie.MatrixLieGroup]): +def test_log_random(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check that log Jacobians are consistent, with randomly sampled transforms.""" + del batch_axes # Not used for autodiff tests. params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() _assert_jacobians_close(Group=Group, f=_log, primal=params) @general_group_test -def test_log_identity(Group: Type[jaxlie.MatrixLieGroup]): +def test_log_identity(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check that log Jacobians are consistent, with transforms close to the identity.""" params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() @@ -96,16 +99,22 @@ def _adjoint(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Array @general_group_test -def test_adjoint_random(Group: Type[jaxlie.MatrixLieGroup]): +def test_adjoint_random( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check that adjoint Jacobians are consistent, with randomly sampled transforms.""" + del batch_axes # Not used for autodiff tests. params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() _assert_jacobians_close(Group=Group, f=_adjoint, primal=params) @general_group_test -def test_adjoint_identity(Group: Type[jaxlie.MatrixLieGroup]): +def test_adjoint_identity( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check that adjoint Jacobians are consistent, with transforms close to the identity.""" + del batch_axes # Not used for autodiff tests. params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() _assert_jacobians_close(Group=Group, f=_adjoint, primal=params) @@ -116,16 +125,20 @@ def _apply(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Array: @general_group_test -def test_apply_random(Group: Type[jaxlie.MatrixLieGroup]): +def test_apply_random(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check that apply Jacobians are consistent, with randomly sampled transforms.""" + del batch_axes # Not used for autodiff tests. params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() _assert_jacobians_close(Group=Group, f=_apply, primal=params) @general_group_test -def test_apply_identity(Group: Type[jaxlie.MatrixLieGroup]): +def test_apply_identity( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check that apply Jacobians are consistent, with transforms close to the identity.""" + del batch_axes # Not used for autodiff tests. params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() _assert_jacobians_close(Group=Group, f=_apply, primal=params) @@ -136,17 +149,23 @@ def _multiply(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Arra @general_group_test -def test_multiply_random(Group: Type[jaxlie.MatrixLieGroup]): +def test_multiply_random( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check that multiply Jacobians are consistent, with randomly sampled transforms.""" + del batch_axes # Not used for autodiff tests. params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() _assert_jacobians_close(Group=Group, f=_multiply, primal=params) @general_group_test -def test_multiply_identity(Group: Type[jaxlie.MatrixLieGroup]): +def test_multiply_identity( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check that multiply Jacobians are consistent, with transforms close to the identity.""" + del batch_axes # Not used for autodiff tests. params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() _assert_jacobians_close(Group=Group, f=_multiply, primal=params) @@ -157,15 +176,21 @@ def _inverse(Group: Type[jaxlie.MatrixLieGroup], params: jax.Array) -> jax.Array @general_group_test -def test_inverse_random(Group: Type[jaxlie.MatrixLieGroup]): +def test_inverse_random( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check that inverse Jacobians are consistent, with randomly sampled transforms.""" + del batch_axes # Not used for autodiff tests. params = Group.exp(onp.random.randn(Group.tangent_dim)).parameters() _assert_jacobians_close(Group=Group, f=_inverse, primal=params) @general_group_test -def test_inverse_identity(Group: Type[jaxlie.MatrixLieGroup]): +def test_inverse_identity( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check that inverse Jacobians are consistent, with transforms close to the identity.""" + del batch_axes # Not used for autodiff tests. params = Group.exp(onp.random.randn(Group.tangent_dim) * 1e-6).parameters() _assert_jacobians_close(Group=Group, f=_inverse, primal=params) diff --git a/tests/test_group_axioms.py b/tests/test_group_axioms.py index 168fcbf..7b4e70d 100644 --- a/tests/test_group_axioms.py +++ b/tests/test_group_axioms.py @@ -2,10 +2,11 @@ https://proofwiki.org/wiki/Definition:Group_Axioms """ +from typing import Tuple, Type -from typing import Type - +import jaxlie import numpy as onp + from utils import ( assert_arrays_close, assert_transforms_close, @@ -13,14 +14,12 @@ sample_transform, ) -import jaxlie - @general_group_test -def test_closure(Group: Type[jaxlie.MatrixLieGroup]): +def test_closure(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check closure property.""" - transform_a = sample_transform(Group) - transform_b = sample_transform(Group) + transform_a = sample_transform(Group, batch_axes) + transform_b = sample_transform(Group, batch_axes) composed = transform_a @ transform_b assert_transforms_close(composed, composed.normalize()) @@ -33,45 +32,55 @@ def test_closure(Group: Type[jaxlie.MatrixLieGroup]): @general_group_test -def test_identity(Group: Type[jaxlie.MatrixLieGroup]): +def test_identity(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check identity property.""" - transform = sample_transform(Group) - identity = Group.identity() + transform = sample_transform(Group, batch_axes) + identity = Group.identity(batch_axes) assert_transforms_close(transform, identity @ transform) assert_transforms_close(transform, transform @ identity) assert_arrays_close( - transform.as_matrix(), identity.as_matrix() @ transform.as_matrix() + transform.as_matrix(), + onp.einsum("...ij,...jk->...ik", identity.as_matrix(), transform.as_matrix()), ) assert_arrays_close( - transform.as_matrix(), transform.as_matrix() @ identity.as_matrix() + transform.as_matrix(), + onp.einsum("...ij,...jk->...ik", transform.as_matrix(), identity.as_matrix()), ) @general_group_test -def test_inverse(Group: Type[jaxlie.MatrixLieGroup]): +def test_inverse(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check inverse property.""" - transform = sample_transform(Group) - identity = Group.identity() + transform = sample_transform(Group, batch_axes) + identity = Group.identity(batch_axes) assert_transforms_close(identity, transform @ transform.inverse()) assert_transforms_close(identity, transform.inverse() @ transform) assert_transforms_close(identity, Group.multiply(transform, transform.inverse())) assert_transforms_close(identity, Group.multiply(transform.inverse(), transform)) assert_arrays_close( - onp.eye(Group.matrix_dim), - transform.as_matrix() @ transform.inverse().as_matrix(), + onp.broadcast_to(onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim)), + onp.einsum( + "...ij,...jk->...ik", + transform.as_matrix(), + transform.inverse().as_matrix(), + ), ) assert_arrays_close( - onp.eye(Group.matrix_dim), - transform.inverse().as_matrix() @ transform.as_matrix(), + onp.broadcast_to(onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim)), + onp.einsum( + "...ij,...jk->...ik", + transform.inverse().as_matrix(), + transform.as_matrix(), + ), ) @general_group_test -def test_associative(Group: Type[jaxlie.MatrixLieGroup]): +def test_associative(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check associative property.""" - transform_a = sample_transform(Group) - transform_b = sample_transform(Group) - transform_c = sample_transform(Group) + transform_a = sample_transform(Group, batch_axes) + transform_b = sample_transform(Group, batch_axes) + transform_c = sample_transform(Group, batch_axes) assert_transforms_close( (transform_a @ transform_b) @ transform_c, transform_a @ (transform_b @ transform_c), diff --git a/tests/test_manifold.py b/tests/test_manifold.py index ca31f30..6b31356 100644 --- a/tests/test_manifold.py +++ b/tests/test_manifold.py @@ -1,12 +1,14 @@ """Test manifold helpers.""" -from typing import Type +from typing import Tuple, Type import jax +import jaxlie import numpy as onp import pytest from jax import numpy as jnp from jax import tree_util + from utils import ( assert_arrays_close, assert_transforms_close, @@ -15,14 +17,12 @@ sample_transform, ) -import jaxlie - @general_group_test -def test_rplus_rminus(Group: Type[jaxlie.MatrixLieGroup]): +def test_rplus_rminus(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check rplus and rminus on random inputs.""" - T_wa = sample_transform(Group) - T_wb = sample_transform(Group) + T_wa = sample_transform(Group, batch_axes) + T_wb = sample_transform(Group, batch_axes) T_ab = T_wa.inverse() @ T_wb assert_transforms_close(jaxlie.manifold.rplus(T_wa, T_ab.log()), T_wb) @@ -30,14 +30,24 @@ def test_rplus_rminus(Group: Type[jaxlie.MatrixLieGroup]): @general_group_test -def test_rplus_jacobian(Group: Type[jaxlie.MatrixLieGroup]): +def test_rplus_jacobian( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check analytical rplus Jacobian..""" - T_wa = sample_transform(Group) + T_wa = sample_transform(Group, batch_axes) J_ours = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(T_wa) - J_jacfwd = _rplus_jacobian_parameters_wrt_delta(T_wa) - assert_arrays_close(J_ours, J_jacfwd) + if batch_axes == (): + J_jacfwd = _rplus_jacobian_parameters_wrt_delta(T_wa) + assert_arrays_close(J_ours, J_jacfwd) + else: + # Batch axes should match vmap. + jacfunc = jaxlie.manifold.rplus_jacobian_parameters_wrt_delta + for _ in batch_axes: + jacfunc = jax.vmap(jacfunc) + J_vmap = jacfunc(T_wa) + assert_arrays_close(J_ours, J_vmap) @jax.jit @@ -51,11 +61,11 @@ def _rplus_jacobian_parameters_wrt_delta( @general_group_test_faster -def test_sgd(Group: Type[jaxlie.MatrixLieGroup]): +def test_sgd(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): def loss(transform: jaxlie.MatrixLieGroup): return (transform.log() ** 2).sum() - transform = Group.exp(sample_transform(Group).log()) + transform = Group.exp(sample_transform(Group, batch_axes).log()) original_loss = loss(transform) @jax.jit diff --git a/tests/test_operations.py b/tests/test_operations.py index 07d5a6c..2456be1 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -1,11 +1,13 @@ """Tests for general operation definitions.""" -from typing import Type +from typing import Tuple, Type +import jaxlie import numpy as onp from hypothesis import given, settings from hypothesis import strategies as st from jax import numpy as jnp + from utils import ( assert_arrays_close, assert_transforms_close, @@ -13,14 +15,14 @@ sample_transform, ) -import jaxlie - @general_group_test -def test_sample_uniform_valid(Group: Type[jaxlie.MatrixLieGroup]): +def test_sample_uniform_valid( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check that sample_uniform() returns valid group members.""" - T = sample_transform(Group) # Calls sample_uniform under the hood. - assert_transforms_close(T, T.normalize()) + T = sample_transform(Group, batch_axes) # Calls sample_uniform under the hood. + # assert_transforms_close(T, T.normalize()) @settings(deadline=None) @@ -48,12 +50,14 @@ def test_so3_rpy_bijective(_random_module): @general_group_test -def test_log_exp_bijective(Group: Type[jaxlie.MatrixLieGroup]): +def test_log_exp_bijective( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check 1-to-1 mapping for log <=> exp operations.""" - transform = sample_transform(Group) + transform = sample_transform(Group, batch_axes) tangent = transform.log() - assert tangent.shape == (Group.tangent_dim,) + assert tangent.shape == (*batch_axes, Group.tangent_dim) exp_transform = Group.exp(tangent) assert_transforms_close(transform, exp_transform) @@ -61,48 +65,53 @@ def test_log_exp_bijective(Group: Type[jaxlie.MatrixLieGroup]): @general_group_test -def test_inverse_bijective(Group: Type[jaxlie.MatrixLieGroup]): +def test_inverse_bijective( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check inverse of inverse.""" - transform = sample_transform(Group) + transform = sample_transform(Group, batch_axes) assert_transforms_close(transform, transform.inverse().inverse()) @general_group_test -def test_matrix_bijective(Group: Type[jaxlie.MatrixLieGroup]): +def test_matrix_bijective( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check that we can convert to and from matrices.""" - transform = sample_transform(Group) + transform = sample_transform(Group, batch_axes) assert_transforms_close(transform, Group.from_matrix(transform.as_matrix())) @general_group_test -def test_adjoint(Group: Type[jaxlie.MatrixLieGroup]): +def test_adjoint(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check adjoint definition.""" - transform = sample_transform(Group) - omega = onp.random.randn(Group.tangent_dim) + transform = sample_transform(Group, batch_axes) + omega = onp.random.randn(*batch_axes, Group.tangent_dim) assert_transforms_close( transform @ Group.exp(omega), - Group.exp(transform.adjoint() @ omega) @ transform, + Group.exp(onp.einsum("...ij,...j->...i", transform.adjoint(), omega)) + @ transform, ) @general_group_test -def test_repr(Group: Type[jaxlie.MatrixLieGroup]): +def test_repr(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Smoke test for __repr__ implementations.""" - transform = sample_transform(Group) + transform = sample_transform(Group, batch_axes) print(transform) @general_group_test -def test_apply(Group: Type[jaxlie.MatrixLieGroup]): +def test_apply(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check group action interfaces.""" - T_w_b = sample_transform(Group) - p_b = onp.random.randn(Group.space_dim) + T_w_b = sample_transform(Group, batch_axes) + p_b = onp.random.randn(*batch_axes, Group.space_dim) if Group.matrix_dim == Group.space_dim: assert_arrays_close( T_w_b @ p_b, T_w_b.apply(p_b), - T_w_b.as_matrix() @ p_b, + onp.einsum("...ij,...j->...i", T_w_b.as_matrix(), p_b), ) else: # Homogeneous coordinates. @@ -110,19 +119,33 @@ def test_apply(Group: Type[jaxlie.MatrixLieGroup]): assert_arrays_close( T_w_b @ p_b, T_w_b.apply(p_b), - (T_w_b.as_matrix() @ onp.append(p_b, 1.0))[:-1], + onp.einsum( + "...ij,...j->...i", + T_w_b.as_matrix(), + onp.concatenate([p_b, onp.ones_like(p_b[..., :1])], axis=-1), + )[..., :-1], ) @general_group_test -def test_multiply(Group: Type[jaxlie.MatrixLieGroup]): +def test_multiply(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): """Check multiply interfaces.""" - T_w_b = sample_transform(Group) - T_b_a = sample_transform(Group) + T_w_b = sample_transform(Group, batch_axes) + T_b_a = sample_transform(Group, batch_axes) assert_arrays_close( - T_w_b.as_matrix() @ T_w_b.inverse().as_matrix(), onp.eye(Group.matrix_dim) + onp.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), T_w_b.inverse().as_matrix() + ), + onp.broadcast_to( + onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) + ), ) assert_arrays_close( - T_w_b.as_matrix() @ jnp.linalg.inv(T_w_b.as_matrix()), onp.eye(Group.matrix_dim) + onp.einsum( + "...ij,...jk->...ik", T_w_b.as_matrix(), jnp.linalg.inv(T_w_b.as_matrix()) + ), + onp.broadcast_to( + onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) + ), ) assert_transforms_close(T_w_b @ T_b_a, Group.multiply(T_w_b, T_b_a)) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 84018df..74a8fdf 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,18 +1,20 @@ """Test transform serialization, for things like saving calibrated transforms to disk.""" -from typing import Type - -import flax -from utils import assert_transforms_close, general_group_test, sample_transform +from typing import Tuple, Type +import flax.serialization import jaxlie +from utils import assert_transforms_close, general_group_test, sample_transform + @general_group_test -def test_serialization_state_dict_bijective(Group: Type[jaxlie.MatrixLieGroup]): +def test_serialization_state_dict_bijective( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check bijectivity of state dict representation conversions.""" - T = sample_transform(Group) + T = sample_transform(Group, batch_axes) T_recovered = flax.serialization.from_state_dict( T, flax.serialization.to_state_dict(T) ) @@ -20,8 +22,10 @@ def test_serialization_state_dict_bijective(Group: Type[jaxlie.MatrixLieGroup]): @general_group_test -def test_serialization_bytes_bijective(Group: Type[jaxlie.MatrixLieGroup]): +def test_serialization_bytes_bijective( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] +): """Check bijectivity of byte representation conversions.""" - T = sample_transform(Group) + T = sample_transform(Group, batch_axes) T_recovered = flax.serialization.from_bytes(T, flax.serialization.to_bytes(T)) assert_transforms_close(T, T_recovered) diff --git a/tests/utils.py b/tests/utils.py index 55d2780..98f77cc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,8 +1,9 @@ import functools import random -from typing import Any, Callable, List, Type, TypeVar, cast +from typing import Any, Callable, List, Tuple, Type, TypeVar, cast import jax +import jaxlie import numpy as onp import pytest import scipy.optimize @@ -10,40 +11,48 @@ from hypothesis import strategies as st from jax import numpy as jnp -import jaxlie - # Run all tests with double-precision. jax.config.update("jax_enable_x64", True) T = TypeVar("T", bound=jaxlie.MatrixLieGroup) -def sample_transform(Group: Type[T]) -> T: +def sample_transform(Group: Type[T], batch_axes: Tuple[int, ...] = ()) -> T: """Sample a random transform from a group.""" seed = random.getrandbits(32) strategy = random.randint(0, 2) if strategy == 0: # Uniform sampling. - return cast(T, Group.sample_uniform(key=jax.random.PRNGKey(seed=seed))) + return cast( + T, + Group.sample_uniform( + key=jax.random.PRNGKey(seed=seed), batch_axes=batch_axes + ), + ) elif strategy == 1: # Sample from normally-sampled tangent vector. - return cast(T, Group.exp(onp.random.randn(Group.tangent_dim))) + return cast(T, Group.exp(onp.random.randn(*batch_axes, Group.tangent_dim))) elif strategy == 2: # Sample near identity. - return cast(T, Group.exp(onp.random.randn(Group.tangent_dim) * 1e-7)) + return cast( + T, Group.exp(onp.random.randn(*batch_axes, Group.tangent_dim) * 1e-7) + ) else: assert False def general_group_test( - f: Callable[[Type[jaxlie.MatrixLieGroup]], None], max_examples: int = 30 -) -> Callable[[Type[jaxlie.MatrixLieGroup], Any], None]: + f: Callable[[Type[jaxlie.MatrixLieGroup], Tuple[int, ...]], None], + max_examples: int = 10, +) -> Callable[[Type[jaxlie.MatrixLieGroup], Tuple[int, ...], Any], None]: """Decorator for defining tests that run on all group types.""" # Disregard unused argument. - def f_wrapped(Group: Type[jaxlie.MatrixLieGroup], _random_module) -> None: - f(Group) + def f_wrapped( + Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...], _random_module + ) -> None: + f(Group, batch_axes) # Disable timing check (first run requires JIT tracing and will be slower). f_wrapped = settings(deadline=None, max_examples=max_examples)(f_wrapped) @@ -61,6 +70,15 @@ def f_wrapped(Group: Type[jaxlie.MatrixLieGroup], _random_module) -> None: jaxlie.SE3, ], )(f_wrapped) + + # Parametrize tests with each group type. + f_wrapped = pytest.mark.parametrize( + "batch_axes", + [ + (), + (1,), + ], + )(f_wrapped) return f_wrapped @@ -77,11 +95,11 @@ def assert_transforms_close(a: jaxlie.MatrixLieGroup, b: jaxlie.MatrixLieGroup): p1 = jnp.asarray(a.parameters()) p2 = jnp.asarray(b.parameters()) if isinstance(a, jaxlie.SO3): - p1 = p1 * jnp.sign(jnp.sum(p1)) - p2 = p2 * jnp.sign(jnp.sum(p2)) + p1 = p1 * jnp.sign(jnp.sum(p1, axis=-1)) + p2 = p2 * jnp.sign(jnp.sum(p2, axis=-1)) elif isinstance(a, jaxlie.SE3): - p1 = p1.at[:4].mul(jnp.sign(jnp.sum(p1[:4]))) - p2 = p2.at[:4].mul(jnp.sign(jnp.sum(p2[:4]))) + p1 = p1.at[..., :4].mul(jnp.sign(jnp.sum(p1[..., :4], axis=-1))) + p2 = p2.at[..., :4].mul(jnp.sign(jnp.sum(p2[..., :4], axis=-1))) # Make sure parameters are equal. assert_arrays_close(p1, p2)