From 53edbf91eed66ecf7e85cf5728206047077a8c2e Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Sun, 5 May 2024 00:23:41 -0700 Subject: [PATCH] Formatting --- jaxlie/_so2.py | 4 +++- jaxlie/manifold/_deltas.py | 12 ++++-------- jaxlie/utils/_utils.py | 1 - tests/test_autodiff.py | 4 ++-- tests/test_group_axioms.py | 13 +++++++++---- tests/test_manifold.py | 4 ++-- tests/test_operations.py | 4 ++-- tests/test_serialization.py | 4 ++-- tests/utils.py | 3 ++- 9 files changed, 26 insertions(+), 23 deletions(-) diff --git a/jaxlie/_so2.py b/jaxlie/_so2.py index 17a0d96..4100cb7 100644 --- a/jaxlie/_so2.py +++ b/jaxlie/_so2.py @@ -130,7 +130,9 @@ def normalize(self) -> SO2: @classmethod @override - def sample_uniform(cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO2: + def sample_uniform( + cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = () + ) -> SO2: out = SO2.from_radians( jax.random.uniform( key=key, shape=batch_axes, minval=0.0, maxval=2.0 * jnp.pi diff --git a/jaxlie/manifold/_deltas.py b/jaxlie/manifold/_deltas.py index 75ba1ca..f15181c 100644 --- a/jaxlie/manifold/_deltas.py +++ b/jaxlie/manifold/_deltas.py @@ -49,16 +49,14 @@ def _rplus(transform: GroupType, delta: jax.Array) -> GroupType: def rplus( transform: GroupType, delta: hints.Array, -) -> GroupType: - ... +) -> GroupType: ... @overload def rplus( transform: PytreeType, delta: _tree_utils.TangentPytree, -) -> PytreeType: - ... +) -> PytreeType: ... # Using our typevars in the overloaded signature will cause errors. @@ -81,13 +79,11 @@ def _rminus(a: GroupType, b: GroupType) -> jax.Array: @overload -def rminus(a: GroupType, b: GroupType) -> jax.Array: - ... +def rminus(a: GroupType, b: GroupType) -> jax.Array: ... @overload -def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: - ... +def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ... # Using our typevars in the overloaded signature will cause errors. diff --git a/jaxlie/utils/_utils.py b/jaxlie/utils/_utils.py index 1f45b00..f67f92f 100644 --- a/jaxlie/utils/_utils.py +++ b/jaxlie/utils/_utils.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING, Callable, Type, TypeVar -import jax_dataclasses as jdc from jax import numpy as jnp if TYPE_CHECKING: diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index a7ed04f..4f81f7d 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -4,12 +4,12 @@ from typing import Callable, Tuple, Type, cast import jax -import jaxlie import numpy as onp from jax import numpy as jnp - from utils import assert_arrays_close, general_group_test, jacnumerical +import jaxlie + # We cache JITed Jacobians to improve runtime. cached_jacfwd = lru_cache(maxsize=None)( lambda f: jax.jit(jax.jacfwd(f, argnums=1), static_argnums=0) diff --git a/tests/test_group_axioms.py b/tests/test_group_axioms.py index 7b4e70d..3a43c4f 100644 --- a/tests/test_group_axioms.py +++ b/tests/test_group_axioms.py @@ -2,11 +2,10 @@ https://proofwiki.org/wiki/Definition:Group_Axioms """ + from typing import Tuple, Type -import jaxlie import numpy as onp - from utils import ( assert_arrays_close, assert_transforms_close, @@ -14,6 +13,8 @@ sample_transform, ) +import jaxlie + @general_group_test def test_closure(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): @@ -58,7 +59,9 @@ def test_inverse(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] assert_transforms_close(identity, Group.multiply(transform, transform.inverse())) assert_transforms_close(identity, Group.multiply(transform.inverse(), transform)) assert_arrays_close( - onp.broadcast_to(onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim)), + onp.broadcast_to( + onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) + ), onp.einsum( "...ij,...jk->...ik", transform.as_matrix(), @@ -66,7 +69,9 @@ def test_inverse(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...] ), ) assert_arrays_close( - onp.broadcast_to(onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim)), + onp.broadcast_to( + onp.eye(Group.matrix_dim), (*batch_axes, Group.matrix_dim, Group.matrix_dim) + ), onp.einsum( "...ij,...jk->...ik", transform.inverse().as_matrix(), diff --git a/tests/test_manifold.py b/tests/test_manifold.py index 6b31356..d930f62 100644 --- a/tests/test_manifold.py +++ b/tests/test_manifold.py @@ -3,12 +3,10 @@ from typing import Tuple, Type import jax -import jaxlie import numpy as onp import pytest from jax import numpy as jnp from jax import tree_util - from utils import ( assert_arrays_close, assert_transforms_close, @@ -17,6 +15,8 @@ sample_transform, ) +import jaxlie + @general_group_test def test_rplus_rminus(Group: Type[jaxlie.MatrixLieGroup], batch_axes: Tuple[int, ...]): diff --git a/tests/test_operations.py b/tests/test_operations.py index 2456be1..3582a3f 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -2,12 +2,10 @@ from typing import Tuple, Type -import jaxlie import numpy as onp from hypothesis import given, settings from hypothesis import strategies as st from jax import numpy as jnp - from utils import ( assert_arrays_close, assert_transforms_close, @@ -15,6 +13,8 @@ sample_transform, ) +import jaxlie + @general_group_test def test_sample_uniform_valid( diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 74a8fdf..5f979d8 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -4,10 +4,10 @@ from typing import Tuple, Type import flax.serialization -import jaxlie - from utils import assert_transforms_close, general_group_test, sample_transform +import jaxlie + @general_group_test def test_serialization_state_dict_bijective( diff --git a/tests/utils.py b/tests/utils.py index 98f77cc..70780fd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -3,7 +3,6 @@ from typing import Any, Callable, List, Tuple, Type, TypeVar, cast import jax -import jaxlie import numpy as onp import pytest import scipy.optimize @@ -11,6 +10,8 @@ from hypothesis import strategies as st from jax import numpy as jnp +import jaxlie + # Run all tests with double-precision. jax.config.update("jax_enable_x64", True)