Skip to content

Commit

Permalink
More test improvements, fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed May 6, 2024
1 parent a043c14 commit 24edd26
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 39 deletions.
3 changes: 2 additions & 1 deletion jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import _base, hints
from ._so2 import SO2
from .utils import get_epsilon, register_lie_group
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group


@register_lie_group(
Expand Down Expand Up @@ -55,6 +55,7 @@ def from_rotation_and_translation(
translation: hints.Array,
) -> "SE2":
assert translation.shape[-1:] == (2,)
rotation, translation = broadcast_leading_axes((rotation, translation))
return SE2(
unit_complex_xy=jnp.concatenate(
[rotation.unit_complex, translation], axis=-1
Expand Down
10 changes: 5 additions & 5 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from . import _base, hints
from ._so3 import SO3
from .utils import get_epsilon, register_lie_group
from .utils import broadcast_leading_axes, get_epsilon, register_lie_group


def _skew(omega: hints.Array) -> jax.Array:
Expand Down Expand Up @@ -58,6 +58,7 @@ def from_rotation_and_translation(
translation: hints.Array,
) -> SE3:
assert translation.shape[-1:] == (3,)
rotation, translation = broadcast_leading_axes((rotation, translation))
return SE3(wxyz_xyz=jnp.concatenate([rotation.wxyz, translation], axis=-1))

@override
Expand Down Expand Up @@ -93,16 +94,15 @@ def from_matrix(cls, matrix: hints.Array) -> SE3:

@override
def as_matrix(self) -> jax.Array:
out = jnp.zeros((*self.get_batch_axes(), 4, 4))
out = (
out.at[..., :3, :3]
return (
jnp.zeros((*self.get_batch_axes(), 4, 4))
.at[..., :3, :3]
.set(self.rotation().as_matrix())
.at[..., :3, 3]
.set(self.translation())
.at[..., 3, 3]
.set(1.0)
)
return out

@override
def parameters(self) -> jax.Array:
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing_extensions import override

from . import _base, hints
from .utils import register_lie_group, broadcast_leading_axes
from .utils import broadcast_leading_axes, register_lie_group


@register_lie_group(
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def case3(m):
jnp.where(cond2, case2_t, case3_t),
)
q = jnp.where(
cond0,
cond0[..., None],
jnp.where(cond1[..., None], case0_q, case1_q),
jnp.where(cond2[..., None], case2_q, case3_q),
)
Expand Down
31 changes: 11 additions & 20 deletions jaxlie/manifold/_deltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,31 @@ def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
Returns:
Jacobian. Shape should be `(Group.parameters_dim, Group.tangent_dim)`.
"""
if type(transform) is SO2:
if isinstance(transform, SO2):
# Jacobian row indices: cos, sin
# Jacobian col indices: theta

transform_so2 = cast(SO2, transform)
J = jnp.zeros((*transform.get_batch_axes(), 2, 1))
cos, sin = jnp.moveaxis(transform.unit_complex, -1, 0)
J = J.at[..., 0, 0].set(-sin).at[..., 1, 0].set(cos)

cos, sin = jnp.moveaxis(transform_so2.unit_complex, -1, 0)
J = J.at[..., 0, :].set(-sin).at[..., 1, :].set(cos)

elif type(transform) is SE2:
elif isinstance(transform, SE2):
# Jacobian row indices: cos, sin, x, y
# Jacobian col indices: vx, vy, omega

transform_se2 = cast(SE2, transform)
J = jnp.zeros((*transform.get_batch_axes(), 4, 3))

# Translation terms.
J = J.at[..., 2:, :2].set(transform_se2.rotation().as_matrix())
J = J.at[..., 2:, :2].set(transform.rotation().as_matrix())

# Rotation terms.
J = J.at[..., :2, 2:3].set(
rplus_jacobian_parameters_wrt_delta(transform_se2.rotation())
rplus_jacobian_parameters_wrt_delta(transform.rotation())
)

elif type(transform) is SO3:
elif isinstance(transform, SO3):
# Jacobian row indices: qw, qx, qy, qz
# Jacobian col indices: omega x, omega y, omega z

transform_so3 = cast(SO3, transform)

w, x, y, z = jnp.moveaxis(transform_so3.wxyz, -1, 0)
w, x, y, z = jnp.moveaxis(transform.wxyz, -1, 0)
neg_x = -x
neg_y = -y
neg_z = -z
Expand All @@ -159,19 +152,17 @@ def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
/ 2.0
)

elif type(transform) is SE3:
elif isinstance(transform, SE3):
# Jacobian row indices: qw, qx, qy, qz, x, y, z
# Jacobian col indices: vx, vy, vz, omega x, omega y, omega z

transform_se3 = cast(SE3, transform)
J = jnp.zeros((*transform.get_batch_axes(), 7, 6))

# Translation terms.
J = J.at[..., 4:, :3].set(transform_se3.rotation().as_matrix())
J = J.at[..., 4:, :3].set(transform.rotation().as_matrix())

# Rotation terms.
J = J.at[..., :4, 3:6].set(
rplus_jacobian_parameters_wrt_delta(transform_se3.rotation())
rplus_jacobian_parameters_wrt_delta(transform.rotation())
)

else:
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ._utils import get_epsilon, register_lie_group, broadcast_leading_axes
from ._utils import broadcast_leading_axes, get_epsilon, register_lie_group

__all__ = ["get_epsilon", "register_lie_group", "broadcast_leading_axes"]
16 changes: 9 additions & 7 deletions jaxlie/utils/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import TYPE_CHECKING, Callable, Tuple, Type, TypeVar, Union, cast

import jax_dataclasses as jdc
from jax import numpy as jnp

from jaxlie.hints import Array

if TYPE_CHECKING:
Expand Down Expand Up @@ -45,13 +47,13 @@ def _wrap(cls: Type[T]) -> Type[T]:
cls.space_dim = space_dim

# JIT all methods.
# for f in filter(
# lambda f: not f.startswith("_")
# and callable(getattr(cls, f))
# and f != "get_batch_axes", # Avoid returning tracers.
# dir(cls),
# ):
# setattr(cls, f, jdc.jit(getattr(cls, f)))
for f in filter(
lambda f: not f.startswith("_")
and callable(getattr(cls, f))
and f != "get_batch_axes", # Avoid returning tracers.
dir(cls),
):
setattr(cls, f, jdc.jit(getattr(cls, f)))

return cls

Expand Down
13 changes: 11 additions & 2 deletions tests/test_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

from typing import Tuple, Type

import jaxlie
import numpy as onp
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp

from utils import (
assert_arrays_close,
assert_transforms_close,
general_group_test,
sample_transform,
)

import jaxlie


@general_group_test
def test_broadcast_multiply(
Expand All @@ -29,6 +29,11 @@ def test_broadcast_multiply(
T = sample_transform(Group, batch_axes) @ sample_transform(Group, batch_axes=(1,))
assert T.get_batch_axes() == batch_axes

T = sample_transform(Group, batch_axes) @ sample_transform(
Group, batch_axes=(1,) * len(batch_axes)
)
assert T.get_batch_axes() == batch_axes

T = sample_transform(Group) @ sample_transform(Group, batch_axes)
assert T.get_batch_axes() == batch_axes

Expand All @@ -51,6 +56,10 @@ def test_broadcast_apply(
points = onp.random.randn(1, Group.space_dim)
assert (T @ points).shape == (*batch_axes, Group.space_dim)

T = sample_transform(Group, batch_axes)
points = onp.random.randn(*((1,) * len(batch_axes)), Group.space_dim)
assert (T @ points).shape == (*batch_axes, Group.space_dim)

T = sample_transform(Group)
points = onp.random.randn(*batch_axes, Group.space_dim)
assert (T @ points).shape == (*batch_axes, Group.space_dim)
Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from typing import Any, Callable, List, Tuple, Type, TypeVar, cast

import jax
import jaxlie
import numpy as onp
import pytest
import scipy.optimize
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp

import jaxlie

# Run all tests with double-precision.
jax.config.update("jax_enable_x64", True)

Expand Down

0 comments on commit 24edd26

Please sign in to comment.