Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed May 5, 2024
1 parent 03aa2f4 commit 53edbf9
Show file tree
Hide file tree
Showing 9 changed files with 26 additions and 23 deletions.
4 changes: 3 additions & 1 deletion jaxlie/_so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def normalize(self) -> SO2:

@classmethod
@override
def sample_uniform(cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO2:
def sample_uniform(
cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = ()
) -> SO2:
out = SO2.from_radians(
jax.random.uniform(
key=key, shape=batch_axes, minval=0.0, maxval=2.0 * jnp.pi
Expand Down
12 changes: 4 additions & 8 deletions jaxlie/manifold/_deltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,14 @@ def _rplus(transform: GroupType, delta: jax.Array) -> GroupType:
def rplus(
transform: GroupType,
delta: hints.Array,
) -> GroupType:
...
) -> GroupType: ...


@overload
def rplus(
transform: PytreeType,
delta: _tree_utils.TangentPytree,
) -> PytreeType:
...
) -> PytreeType: ...


# Using our typevars in the overloaded signature will cause errors.
Expand All @@ -81,13 +79,11 @@ def _rminus(a: GroupType, b: GroupType) -> jax.Array:


@overload
def rminus(a: GroupType, b: GroupType) -> jax.Array:
...
def rminus(a: GroupType, b: GroupType) -> jax.Array: ...


@overload
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree:
...
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ...


# Using our typevars in the overloaded signature will cause errors.
Expand Down
1 change: 0 additions & 1 deletion jaxlie/utils/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Callable, Type, TypeVar

import jax_dataclasses as jdc
from jax import numpy as jnp

if TYPE_CHECKING:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Callable, Tuple, Type, cast

import jax
import jaxlie
import numpy as onp
from jax import numpy as jnp

from utils import assert_arrays_close, general_group_test, jacnumerical

import jaxlie

# We cache JITed Jacobians to improve runtime.
cached_jacfwd = lru_cache(maxsize=None)(
lambda f: jax.jit(jax.jacfwd(f, argnums=1), static_argnums=0)
Expand Down
13 changes: 9 additions & 4 deletions tests/test_group_axioms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@
https://proofwiki.org/wiki/Definition:Group_Axioms
"""

from typing import Tuple, Type

import jaxlie
import numpy as onp

from utils import (
assert_arrays_close,
assert_transforms_close,
general_group_test,
sample_transform,
)

import jaxlie


@general_group_test
def test_closure(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]):
Expand Down Expand Up @@ -58,15 +59,19 @@ def test_inverse(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]
assert_transforms_close(identity, Group.multiply(transform, transform.inverse()))
assert_transforms_close(identity, Group.multiply(transform.inverse(), transform))
assert_arrays_close(
onp.broadcast_to(onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim)),
onp.broadcast_to(
onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim)
),
onp.einsum(
"...ij,...jk->...ik",
transform.as_matrix(),
transform.inverse().as_matrix(),
),
)
assert_arrays_close(
onp.broadcast_to(onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim)),
onp.broadcast_to(
onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim)
),
onp.einsum(
"...ij,...jk->...ik",
transform.inverse().as_matrix(),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from typing import Tuple, Type

import jax
import jaxlie
import numpy as onp
import pytest
from jax import numpy as jnp
from jax import tree_util

from utils import (
assert_arrays_close,
assert_transforms_close,
Expand All @@ -17,6 +15,8 @@
sample_transform,
)

import jaxlie


@general_group_test
def test_rplus_rminus(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@

from typing import Tuple, Type

import jaxlie
import numpy as onp
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp

from utils import (
assert_arrays_close,
assert_transforms_close,
general_group_test,
sample_transform,
)

import jaxlie


@general_group_test
def test_sample_uniform_valid(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from typing import Tuple, Type

import flax.serialization
import jaxlie

from utils import assert_transforms_close, general_group_test, sample_transform

import jaxlie


@general_group_test
def test_serialization_state_dict_bijective(
Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from typing import Any, Callable, List, Tuple, Type, TypeVar, cast

import jax
import jaxlie
import numpy as onp
import pytest
import scipy.optimize
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp

import jaxlie

# Run all tests with double-precision.
jax.config.update("jax_enable_x64", True)

Expand Down

0 comments on commit 53edbf9

Please sign in to comment.