Skip to content

Commit

Permalink
Support batch axes / broadcasting (#19)
Browse files Browse the repository at this point in the history
* Support batch axes across APIs

* Add batch axes to tests + bug fixes

* Fix index error for analytical SO(2) Jacobians

* Formatting

* Broadcasting tests, bug fixes

* More fixes

* More test improvements, fixes

* Update vmap example with broadcasting notes

* Formatting

* More comments

* Add bullet to docs
  • Loading branch information
brentyi authored May 6, 2024
1 parent 7374df2 commit bacbd19
Show file tree
Hide file tree
Showing 23 changed files with 756 additions and 445 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Where each group supports:
<code>jaxlie.<strong>manifold.\*</strong></code>).
- Compatibility with standard JAX function transformations. (see
[./examples/vmap_example.py](./examples/vmap_example.py))
- Broadcasting for leading axes.
- (Un)flattening as pytree nodes.
- Serialization using [flax](https://github.com/google/flax).

Expand Down
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ Current functionality:

- Pytree registration for all dataclasses.

- Broadcasting for leading axes.

- Helpers + analytical Jacobians for tangent-space optimization
(:code:`jaxlie.manifold`).

Expand Down
24 changes: 15 additions & 9 deletions examples/se3_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def from_global(params: Parameters) -> ExponentialCoordinatesParameters:


def compute_loss(
params: Union[Parameters, ExponentialCoordinatesParameters]
params: Union[Parameters, ExponentialCoordinatesParameters],
) -> jax.Array:
"""As our loss, we enforce (a) priors on our transforms and (b) a consistency
constraint."""
Expand Down Expand Up @@ -128,11 +128,11 @@ def initialize(algorithm: Algorithm, learning_rate: float) -> State:
elif algorithm == "projected":
# Initialize gradient statistics directly in quaternion space.
params = global_params
optimizer_state = optimizer.init(params) # type: ignore
optimizer_state = optimizer.init(params)
elif algorithm == "exponential_coordinates":
# Switch to a log-space parameterization.
params = ExponentialCoordinatesParameters.from_global(global_params)
optimizer_state = optimizer.init(params) # type: ignore
optimizer_state = optimizer.init(params)
else:
assert_never(algorithm)

Expand All @@ -155,17 +155,21 @@ def step(self: State) -> Tuple[jax.Array, State]:
# the tangent space.
loss, grads = jaxlie.manifold.value_and_grad(compute_loss)(self.params)
updates, new_optimizer_state = self.optimizer.update(
grads, self.optimizer_state, self.params # type: ignore
grads,
self.optimizer_state,
self.params,
)
new_params = jaxlie.manifold.rplus(self.params, updates)

elif self.algorithm == "projected":
# Projection-based approach.
loss, grads = jax.value_and_grad(compute_loss)(self.params)
updates, new_optimizer_state = self.optimizer.update(
grads, self.optimizer_state, self.params # type: ignore
grads,
self.optimizer_state,
self.params,
)
new_params = optax.apply_updates(self.params, updates) # type: ignore
new_params = optax.apply_updates(self.params, updates)

# Project back to manifold.
new_params = jaxlie.manifold.normalize_all(new_params)
Expand All @@ -174,16 +178,18 @@ def step(self: State) -> Tuple[jax.Array, State]:
# If we parameterize with exponential coordinates, we can
loss, grads = jax.value_and_grad(compute_loss)(self.params)
updates, new_optimizer_state = self.optimizer.update(
grads, self.optimizer_state, self.params # type: ignore
grads,
self.optimizer_state,
self.params,
)
new_params = optax.apply_updates(self.params, updates) # type: ignore
new_params = optax.apply_updates(self.params, updates)

else:
assert assert_never(self.algorithm)

# Return updated structure.
with jdc.copy_and_mutate(self, validate=True) as new_state:
new_state.params = new_params # type: ignore
new_state.params = new_params
new_state.optimizer_state = new_optimizer_state

return loss, new_state
Expand Down
23 changes: 20 additions & 3 deletions examples/vmap_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Examples of vectorizing transformations via vmap.
"""jaxlie implements numpy-style broadcasting for all operations. For more
explicit vectorization, we can also use vmap function transformations.
Omitted for brevity here, but note that in practice we usually want to JIT after
vmapping!"""
Omitted for brevity here, but in practice we usually want to JIT after
vmapping."""

import jax
import numpy as onp
Expand Down Expand Up @@ -60,6 +61,10 @@
p_transformed_stacked = jax.vmap(lambda p: SO3.apply(R_single, p))(p_stacked)
assert p_transformed_stacked.shape == (N, 3)

# We can also just rely on broadcasting.
p_transformed_stacked = R_single @ p_stacked
assert p_transformed_stacked.shape == (N, 3)

#############################
# (4) Applying N transformations to N points.
#############################
Expand All @@ -69,13 +74,21 @@
p_transformed_stacked = jax.vmap(SO3.apply)(R_stacked, p_stacked)
assert p_transformed_stacked.shape == (N, 3)

# We can also just rely on broadcasting.
p_transformed_stacked = R_stacked @ p_stacked
assert p_transformed_stacked.shape == (N, 3)

#############################
# (5) Applying N transformations to 1 point.
#############################

p_transformed_stacked = jax.vmap(lambda R: SO3.apply(R, p_single))(R_stacked)
assert p_transformed_stacked.shape == (N, 3)

# We can also just rely on broadcasting.
p_transformed_stacked = R_stacked @ p_single[None, :]
assert p_transformed_stacked.shape == (N, 3)

#############################
# (6) Multiplying transformations.
#############################
Expand All @@ -95,3 +108,7 @@

# Or N x 1 multiplication:
assert (jax.vmap(lambda R: SO3.multiply(R, R_single))(R_stacked)).wxyz.shape == (N, 4)

# Again, broadcasting also works.
assert (R_stacked @ R_stacked).wxyz.shape == (N, 4)
assert (R_stacked @ SO3(R_single.wxyz[None, :])).wxyz.shape == (N, 4)
29 changes: 10 additions & 19 deletions jaxlie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
from . import hints, manifold, utils
from ._base import MatrixLieGroup, SEBase, SOBase
from ._se2 import SE2
from ._se3 import SE3
from ._so2 import SO2
from ._so3 import SO3

__all__ = [
"hints",
"manifold",
"utils",
"MatrixLieGroup",
"SOBase",
"SEBase",
"SE2",
"SO2",
"SE3",
"SO3",
]
from . import hints as hints
from . import manifold as manifold
from . import utils as utils
from ._base import MatrixLieGroup as MatrixLieGroup
from ._base import SEBase as SEBase
from ._base import SOBase as SOBase
from ._se2 import SE2 as SE2
from ._se3 import SE3 as SE3
from ._so2 import SO2 as SO2
from ._so3 import SO3 as SO3
24 changes: 16 additions & 8 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import abc
from typing import ClassVar, Generic, Tuple, Type, TypeVar, Union, overload
from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload

import jax
import numpy as onp
from jax import numpy as jnp
from typing_extensions import Self, final, override

from . import hints
Expand Down Expand Up @@ -64,9 +65,12 @@ def __matmul__(self, other: Union[Self, hints.Array]) -> Union[Self, jax.Array]:

@classmethod
@abc.abstractmethod
def identity(cls) -> Self:
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self:
"""Returns identity element.
Args:
batch_axes: Any leading batch axes for the output transform.
Returns:
Identity element.
"""
Expand Down Expand Up @@ -169,24 +173,25 @@ def normalize(self) -> Self:

@classmethod
@abc.abstractmethod
def sample_uniform(cls, key: jax.Array) -> Self:
def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> Self:
"""Draw a uniform sample from the group. Translations (if applicable) are in the
range [-1, 1].
Args:
key: PRNG key, as returned by `jax.random.PRNGKey()`.
batch_axes: Any leading batch axes for the output transforms. Each
sampled transform will be different.
Returns:
Sampled group member.
"""

@abc.abstractmethod
@final
def get_batch_axes(self) -> Tuple[int, ...]:
"""Return any leading batch axes in contained parameters. If an array of shape
`(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will
return `(100,)`.
This should generally be implemented by `jdc.EnforcedAnnotationsMixin`."""
return `(100,)`."""
return self.parameters().shape[:-1]


class SOBase(MatrixLieGroup):
Expand Down Expand Up @@ -227,7 +232,10 @@ def from_rotation_and_translation(
def from_rotation(cls, rotation: ContainedSOType) -> Self:
return cls.from_rotation_and_translation(
rotation=rotation,
translation=onp.zeros(cls.space_dim, dtype=rotation.parameters().dtype),
translation=jnp.zeros(
(*rotation.get_batch_axes(), cls.space_dim),
dtype=rotation.parameters().dtype,
),
)

@abc.abstractmethod
Expand Down
Loading

0 comments on commit bacbd19

Please sign in to comment.