Support batch axes / broadcasting #19

Merged · 11 commits · May 6, 2024
1 change: 1 addition & 0 deletions README.md
@@ -59,6 +59,7 @@ Where each group supports:
<code>jaxlie.<strong>manifold.\*</strong></code>).
- Compatibility with standard JAX function transformations. (see
[./examples/vmap_example.py](./examples/vmap_example.py))
+- Broadcasting for leading axes.
- (Un)flattening as pytree nodes.
- Serialization using [flax](https://github.com/google/flax).

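For readers skimming the diff, a minimal sketch of what the new README bullet refers to (post-PR behavior; variable names are hypothetical, and the pattern mirrors example (3) in vmap_example.py below):

import jax.numpy as jnp
from jaxlie import SO3

# One rotation, many points: the leading (100,) axis of the points
# broadcasts against the single transform, no vmap required.
R = SO3.identity()
points = jnp.zeros((100, 3))
assert (R @ points).shape == (100, 3)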
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -22,6 +22,8 @@ Current functionality:

- Pytree registration for all dataclasses.

+- Broadcasting for leading axes.

- Helpers + analytical Jacobians for tangent-space optimization
(:code:`jaxlie.manifold`).

24 changes: 15 additions & 9 deletions examples/se3_optimization.py
@@ -80,7 +80,7 @@ def from_global(params: Parameters) -> ExponentialCoordinatesParameters:


def compute_loss(
-params: Union[Parameters, ExponentialCoordinatesParameters]
+params: Union[Parameters, ExponentialCoordinatesParameters],
) -> jax.Array:
"""As our loss, we enforce (a) priors on our transforms and (b) a consistency
constraint."""
@@ -128,11 +128,11 @@ def initialize(algorithm: Algorithm, learning_rate: float) -> State:
elif algorithm == "projected":
# Initialize gradient statistics directly in quaternion space.
params = global_params
-optimizer_state = optimizer.init(params) # type: ignore
+optimizer_state = optimizer.init(params)
elif algorithm == "exponential_coordinates":
# Switch to a log-space parameterization.
params = ExponentialCoordinatesParameters.from_global(global_params)
-optimizer_state = optimizer.init(params) # type: ignore
+optimizer_state = optimizer.init(params)
else:
assert_never(algorithm)

@@ -155,17 +155,21 @@ def step(self: State) -> Tuple[jax.Array, State]:
# the tangent space.
loss, grads = jaxlie.manifold.value_and_grad(compute_loss)(self.params)
updates, new_optimizer_state = self.optimizer.update(
-grads, self.optimizer_state, self.params # type: ignore
+grads,
+self.optimizer_state,
+self.params,
)
new_params = jaxlie.manifold.rplus(self.params, updates)

elif self.algorithm == "projected":
# Projection-based approach.
loss, grads = jax.value_and_grad(compute_loss)(self.params)
updates, new_optimizer_state = self.optimizer.update(
-grads, self.optimizer_state, self.params # type: ignore
+grads,
+self.optimizer_state,
+self.params,
)
-new_params = optax.apply_updates(self.params, updates) # type: ignore
+new_params = optax.apply_updates(self.params, updates)

# Project back to manifold.
new_params = jaxlie.manifold.normalize_all(new_params)
@@ -174,16 +178,18 @@ def step(self: State) -> Tuple[jax.Array, State]:
# If we parameterize with exponential coordinates, we can apply updates directly.
loss, grads = jax.value_and_grad(compute_loss)(self.params)
updates, new_optimizer_state = self.optimizer.update(
-grads, self.optimizer_state, self.params # type: ignore
+grads,
+self.optimizer_state,
+self.params,
)
-new_params = optax.apply_updates(self.params, updates) # type: ignore
+new_params = optax.apply_updates(self.params, updates)

else:
assert assert_never(self.algorithm)

# Return updated structure.
with jdc.copy_and_mutate(self, validate=True) as new_state:
-new_state.params = new_params # type: ignore
+new_state.params = new_params
new_state.optimizer_state = new_optimizer_state

return loss, new_state
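Since the hunks above lean on jaxlie's tangent-space helpers, here is a minimal sketch of the pattern in isolation (a scalar toy loss; the 0.1 step size is arbitrary):

import jax
import jaxlie

def loss(T: jaxlie.SE3) -> jax.Array:
    # Squared norm of the twist that maps the identity to T.
    return (T.log() ** 2).sum()

T = jaxlie.SE3.sample_uniform(jax.random.PRNGKey(0))
# Gradients are computed in the local tangent space...
loss_val, grads = jaxlie.manifold.value_and_grad(loss)(T)
# ...and applied with a retraction, keeping T on the manifold.
T = jaxlie.manifold.rplus(T, jax.tree_util.tree_map(lambda g: -0.1 * g, grads))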
23 changes: 20 additions & 3 deletions examples/vmap_example.py
@@ -1,7 +1,8 @@
"""Examples of vectorizing transformations via vmap.
"""jaxlie implements numpy-style broadcasting for all operations. For more
explicit vectorization, we can also use vmap function transformations.

Omitted for brevity here, but note that in practice we usually want to JIT after
vmapping!"""
Omitted for brevity here, but in practice we usually want to JIT after
vmapping."""

import jax
import numpy as onp
@@ -60,6 +61,10 @@
p_transformed_stacked = jax.vmap(lambda p: SO3.apply(R_single, p))(p_stacked)
assert p_transformed_stacked.shape == (N, 3)

+# We can also just rely on broadcasting.
+p_transformed_stacked = R_single @ p_stacked
+assert p_transformed_stacked.shape == (N, 3)

#############################
# (4) Applying N transformations to N points.
#############################
@@ -69,13 +74,21 @@
p_transformed_stacked = jax.vmap(SO3.apply)(R_stacked, p_stacked)
assert p_transformed_stacked.shape == (N, 3)

+# We can also just rely on broadcasting.
+p_transformed_stacked = R_stacked @ p_stacked
+assert p_transformed_stacked.shape == (N, 3)

#############################
# (5) Applying N transformations to 1 point.
#############################

p_transformed_stacked = jax.vmap(lambda R: SO3.apply(R, p_single))(R_stacked)
assert p_transformed_stacked.shape == (N, 3)

+# We can also just rely on broadcasting.
+p_transformed_stacked = R_stacked @ p_single[None, :]
+assert p_transformed_stacked.shape == (N, 3)

#############################
# (6) Multiplying transformations.
#############################
@@ -95,3 +108,7 @@

# Or N x 1 multiplication:
assert (jax.vmap(lambda R: SO3.multiply(R, R_single))(R_stacked)).wxyz.shape == (N, 4)

+# Again, broadcasting also works.
+assert (R_stacked @ R_stacked).wxyz.shape == (N, 4)
+assert (R_stacked @ SO3(R_single.wxyz[None, :])).wxyz.shape == (N, 4)
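To close the loop on the docstring's note about JIT, a sketch (not part of the diff) combining the PR's batched sampling with jit-after-vmap:

import jax
import numpy as onp
from jaxlie import SO3

N = 100
# With this PR, a batch of random rotations is a single call.
R_stacked = SO3.sample_uniform(jax.random.PRNGKey(0), batch_axes=(N,))
p_stacked = onp.random.normal(size=(N, 3))

# Compile once, reuse for any batch of the same shape.
apply_batched = jax.jit(jax.vmap(SO3.apply))
assert apply_batched(R_stacked, p_stacked).shape == (N, 3)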
29 changes: 10 additions & 19 deletions jaxlie/__init__.py
@@ -1,19 +1,10 @@
-from . import hints, manifold, utils
-from ._base import MatrixLieGroup, SEBase, SOBase
-from ._se2 import SE2
-from ._se3 import SE3
-from ._so2 import SO2
-from ._so3 import SO3
-
-__all__ = [
-"hints",
-"manifold",
-"utils",
-"MatrixLieGroup",
-"SOBase",
-"SEBase",
-"SE2",
-"SO2",
-"SE3",
-"SO3",
-]
+from . import hints as hints
+from . import manifold as manifold
+from . import utils as utils
+from ._base import MatrixLieGroup as MatrixLieGroup
+from ._base import SEBase as SEBase
+from ._base import SOBase as SOBase
+from ._se2 import SE2 as SE2
+from ._se3 import SE3 as SE3
+from ._so2 import SO2 as SO2
+from ._so3 import SO3 as SO3
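A brief note on the idiom above, since it looks redundant at first glance: the `import X as X` spelling is the convention strict type checkers treat as an explicit re-export, so it replaces the hand-maintained `__all__` list without changing the public API.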
24 changes: 16 additions & 8 deletions jaxlie/_base.py
@@ -1,8 +1,9 @@
import abc
-from typing import ClassVar, Generic, Tuple, Type, TypeVar, Union, overload
+from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload

import jax
import numpy as onp
+from jax import numpy as jnp
from typing_extensions import Self, final, override

from . import hints
@@ -64,9 +65,12 @@ def __matmul__(self, other: Union[Self, hints.Array]) -> Union[Self, jax.Array]:

@classmethod
@abc.abstractmethod
-def identity(cls) -> Self:
+def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self:
"""Returns identity element.

+Args:
+batch_axes: Any leading batch axes for the output transform.

Returns:
Identity element.
"""
@@ -169,24 +173,25 @@ def normalize(self) -> Self:

@classmethod
@abc.abstractmethod
-def sample_uniform(cls, key: jax.Array) -> Self:
+def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> Self:
"""Draw a uniform sample from the group. Translations (if applicable) are in the
range [-1, 1].

Args:
key: PRNG key, as returned by `jax.random.PRNGKey()`.
+batch_axes: Any leading batch axes for the output transforms. Each
+sampled transform will be different.

Returns:
Sampled group member.
"""

-@abc.abstractmethod
+@final
def get_batch_axes(self) -> Tuple[int, ...]:
"""Return any leading batch axes in contained parameters. If an array of shape
`(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will
-return `(100,)`.
-
-This should generally be implemented by `jdc.EnforcedAnnotationsMixin`."""
+return `(100,)`."""
+return self.parameters().shape[:-1]


class SOBase(MatrixLieGroup):
@@ -227,7 +232,10 @@ def from_rotation_and_translation(
def from_rotation(cls, rotation: ContainedSOType) -> Self:
return cls.from_rotation_and_translation(
rotation=rotation,
-translation=onp.zeros(cls.space_dim, dtype=rotation.parameters().dtype),
+translation=jnp.zeros(
+(*rotation.get_batch_axes(), cls.space_dim),
+dtype=rotation.parameters().dtype,
+),
)

@abc.abstractmethod
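The switch from `onp.zeros(cls.space_dim, ...)` to batched `jnp.zeros` matters once rotations carry leading axes; a sketch (hypothetical batch size):

import jax
from jaxlie import SE3, SO3

R = SO3.sample_uniform(jax.random.PRNGKey(1), batch_axes=(8,))
# The zero translation is now created with matching leading axes: (8, 3).
T = SE3.from_rotation(R)
assert T.get_batch_axes() == (8,)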