Skip to content

Commit

Permalink
Support batch axes across APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed May 5, 2024
1 parent 7374df2 commit 8960937
Show file tree
Hide file tree
Showing 10 changed files with 303 additions and 251 deletions.
24 changes: 15 additions & 9 deletions examples/se3_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def from_global(params: Parameters) -> ExponentialCoordinatesParameters:


def compute_loss(
params: Union[Parameters, ExponentialCoordinatesParameters]
params: Union[Parameters, ExponentialCoordinatesParameters],
) -> jax.Array:
"""As our loss, we enforce (a) priors on our transforms and (b) a consistency
constraint."""
Expand Down Expand Up @@ -128,11 +128,11 @@ def initialize(algorithm: Algorithm, learning_rate: float) -> State:
elif algorithm == "projected":
# Initialize gradient statistics directly in quaternion space.
params = global_params
optimizer_state = optimizer.init(params) # type: ignore
optimizer_state = optimizer.init(params)
elif algorithm == "exponential_coordinates":
# Switch to a log-space parameterization.
params = ExponentialCoordinatesParameters.from_global(global_params)
optimizer_state = optimizer.init(params) # type: ignore
optimizer_state = optimizer.init(params)
else:
assert_never(algorithm)

Expand All @@ -155,17 +155,21 @@ def step(self: State) -> Tuple[jax.Array, State]:
# the tangent space.
loss, grads = jaxlie.manifold.value_and_grad(compute_loss)(self.params)
updates, new_optimizer_state = self.optimizer.update(
grads, self.optimizer_state, self.params # type: ignore
grads,
self.optimizer_state,
self.params,
)
new_params = jaxlie.manifold.rplus(self.params, updates)

elif self.algorithm == "projected":
# Projection-based approach.
loss, grads = jax.value_and_grad(compute_loss)(self.params)
updates, new_optimizer_state = self.optimizer.update(
grads, self.optimizer_state, self.params # type: ignore
grads,
self.optimizer_state,
self.params,
)
new_params = optax.apply_updates(self.params, updates) # type: ignore
new_params = optax.apply_updates(self.params, updates)

# Project back to manifold.
new_params = jaxlie.manifold.normalize_all(new_params)
Expand All @@ -174,16 +178,18 @@ def step(self: State) -> Tuple[jax.Array, State]:
# If we parameterize with exponential coordinates, we can
loss, grads = jax.value_and_grad(compute_loss)(self.params)
updates, new_optimizer_state = self.optimizer.update(
grads, self.optimizer_state, self.params # type: ignore
grads,
self.optimizer_state,
self.params,
)
new_params = optax.apply_updates(self.params, updates) # type: ignore
new_params = optax.apply_updates(self.params, updates)

else:
assert assert_never(self.algorithm)

# Return updated structure.
with jdc.copy_and_mutate(self, validate=True) as new_state:
new_state.params = new_params # type: ignore
new_state.params = new_params
new_state.optimizer_state = new_optimizer_state

return loss, new_state
Expand Down
29 changes: 10 additions & 19 deletions jaxlie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,10 @@
from . import hints, manifold, utils
from ._base import MatrixLieGroup, SEBase, SOBase
from ._se2 import SE2
from ._se3 import SE3
from ._so2 import SO2
from ._so3 import SO3

__all__ = [
"hints",
"manifold",
"utils",
"MatrixLieGroup",
"SOBase",
"SEBase",
"SE2",
"SO2",
"SE3",
"SO3",
]
from . import hints as hints
from . import manifold as manifold
from . import utils as utils
from ._base import MatrixLieGroup as MatrixLieGroup
from ._base import SEBase as SEBase
from ._base import SOBase as SOBase
from ._se2 import SE2 as SE2
from ._se3 import SE3 as SE3
from ._so2 import SO2 as SO2
from ._so3 import SO3 as SO3
24 changes: 16 additions & 8 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import abc
from typing import ClassVar, Generic, Tuple, Type, TypeVar, Union, overload
from typing import ClassVar, Generic, Tuple, TypeVar, Union, overload

import jax
import numpy as onp
from jax import numpy as jnp
from typing_extensions import Self, final, override

from . import hints
Expand Down Expand Up @@ -64,9 +65,12 @@ def __matmul__(self, other: Union[Self, hints.Array]) -> Union[Self, jax.Array]:

@classmethod
@abc.abstractmethod
def identity(cls) -> Self:
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> Self:
"""Returns identity element.
Args:
batch_axes: Any leading batch axes for the output transform.
Returns:
Identity element.
"""
Expand Down Expand Up @@ -169,24 +173,25 @@ def normalize(self) -> Self:

@classmethod
@abc.abstractmethod
def sample_uniform(cls, key: jax.Array) -> Self:
def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> Self:
"""Draw a uniform sample from the group. Translations (if applicable) are in the
range [-1, 1].
Args:
key: PRNG key, as returned by `jax.random.PRNGKey()`.
batch_axes: Any leading batch axes for the output transforms. Each
sampled transform will be different.
Returns:
Sampled group member.
"""

@abc.abstractmethod
@final
def get_batch_axes(self) -> Tuple[int, ...]:
"""Return any leading batch axes in contained parameters. If an array of shape
`(100, 4)` is placed in the wxyz field of an SO3 object, for example, this will
return `(100,)`.
This should generally be implemented by `jdc.EnforcedAnnotationsMixin`."""
return `(100,)`."""
return self.parameters().shape[:-1]


class SOBase(MatrixLieGroup):
Expand Down Expand Up @@ -227,7 +232,10 @@ def from_rotation_and_translation(
def from_rotation(cls, rotation: ContainedSOType) -> Self:
return cls.from_rotation_and_translation(
rotation=rotation,
translation=onp.zeros(cls.space_dim, dtype=rotation.parameters().dtype),
translation=jnp.zeros(
(*rotation.get_batch_axes(), cls.space_dim),
dtype=rotation.parameters().dtype,
),
)

@abc.abstractmethod
Expand Down
101 changes: 58 additions & 43 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import cast
from typing import Tuple, cast

import jax
import jax_dataclasses as jdc
from jax import numpy as jnp
from typing_extensions import Annotated, override
from typing_extensions import override

from . import _base, hints
from ._so2 import SO2
Expand All @@ -17,7 +17,7 @@
space_dim=2,
)
@jdc.pytree_dataclass
class SE2(jdc.EnforcedAnnotationsMixin, _base.SEBase[SO2]):
class SE2(_base.SEBase[SO2]):
"""Special Euclidean group for proper rigid transforms in 2D.
Internal parameterization is `(cos, sin, x, y)`. Tangent parameterization is `(vx,
Expand All @@ -26,12 +26,8 @@ class SE2(jdc.EnforcedAnnotationsMixin, _base.SEBase[SO2]):

# SE2-specific.

unit_complex_xy: Annotated[
jax.Array,
(..., 4), # Shape.
jnp.floating, # Data-type.
]
"""Internal parameters. `(cos, sin, x, y)`."""
unit_complex_xy: jax.Array
"""Internal parameters. `(cos, sin, x, y)`. Shape should be `(*, 3)`."""

@override
def __repr__(self) -> str:
Expand All @@ -47,7 +43,7 @@ def from_xy_theta(x: hints.Scalar, y: hints.Scalar, theta: hints.Scalar) -> "SE2
"""
cos = jnp.cos(theta)
sin = jnp.sin(theta)
return SE2(unit_complex_xy=jnp.array([cos, sin, x, y]))
return SE2(unit_complex_xy=jnp.stack([cos, sin, x, y], axis=-1))

# SE-specific.

Expand All @@ -58,9 +54,11 @@ def from_rotation_and_translation(
rotation: SO2,
translation: hints.Array,
) -> "SE2":
assert translation.shape == (2,)
assert translation.shape[-1:] == (2,)
return SE2(
unit_complex_xy=jnp.concatenate([rotation.unit_complex, translation])
unit_complex_xy=jnp.concatenate(
[rotation.unit_complex, translation], axis=-1
)
)

@override
Expand All @@ -75,17 +73,21 @@ def translation(self) -> jax.Array:

@classmethod
@override
def identity(cls) -> "SE2":
return SE2(unit_complex_xy=jnp.array([1.0, 0.0, 0.0, 0.0]))
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> "SE2":
return SE2(
unit_complex_xy=jnp.broadcast_to(
jnp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)
)
)

@classmethod
@override
def from_matrix(cls, matrix: hints.Array) -> "SE2":
assert matrix.shape == (3, 3)
assert matrix.shape[-2:] == (3, 3)
# Currently assumes bottom row is [0, 0, 1].
return SE2.from_rotation_and_translation(
rotation=SO2.from_matrix(matrix[:2, :2]),
translation=matrix[:2, 2],
rotation=SO2.from_matrix(matrix[..., :2, :2]),
translation=matrix[..., :2, 2],
)

# Accessors.
Expand All @@ -96,13 +98,9 @@ def parameters(self) -> jax.Array:

@override
def as_matrix(self) -> jax.Array:
cos, sin, x, y = self.unit_complex_xy
return jnp.array(
[
[cos, -sin, x],
[sin, cos, y],
[0.0, 0.0, 1.0],
]
cos, sin, x, y = jnp.moveaxis(self.unit_complex_xy, -1, 0)
return jnp.stack([cos, -sin, x, sin, cos, y, 0.0, 0.0, 1.0], axis=-1).reshape(
(*self.get_batch_axes(), 3, 3)
)

# Operations.
Expand All @@ -115,9 +113,9 @@ def exp(cls, tangent: hints.Array) -> "SE2":
# Also see:
# > http://ethaneade.com/lie.pdf

assert tangent.shape == (3,)
assert tangent.shape[-1:] == (3,)

theta = tangent[2]
theta = tangent[..., 2]
use_taylor = jnp.abs(theta) < get_epsilon(tangent.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
Expand All @@ -126,7 +124,7 @@ def exp(cls, tangent: hints.Array) -> "SE2":
jax.Array,
jnp.where(
use_taylor,
1.0, # Any non-zero value should do here.
jnp.ones_like(theta), # Any non-zero value should do here.
theta,
),
)
Expand All @@ -149,15 +147,18 @@ def exp(cls, tangent: hints.Array) -> "SE2":
),
)

V = jnp.array(
V = jnp.stack(
[
[sin_over_theta, -one_minus_cos_over_theta],
[one_minus_cos_over_theta, sin_over_theta],
]
)
sin_over_theta,
-one_minus_cos_over_theta,
one_minus_cos_over_theta,
sin_over_theta,
],
axis=-1,
).reshape((*tangent.shape[:-1], 2, 2))
return SE2.from_rotation_and_translation(
rotation=SO2.from_radians(theta),
translation=V @ tangent[:2],
translation=jnp.einsum("...ij,...j->...i", V, tangent[..., :2]),
)

@override
Expand All @@ -167,7 +168,7 @@ def log(self) -> jax.Array:
# Also see:
# > http://ethaneade.com/lie.pdf

theta = self.rotation().log()[0]
theta = self.rotation().log()[..., 0]

cos = jnp.cos(theta)
cos_minus_one = cos - 1.0
Expand All @@ -178,7 +179,7 @@ def log(self) -> jax.Array:
# reverse-mode AD.
safe_cos_minus_one = jnp.where(
use_taylor,
1.0, # Any non-zero value should do here.
jnp.ones_like(cos_minus_one), # Any non-zero value should do here.
cos_minus_one,
)

Expand All @@ -190,14 +191,22 @@ def log(self) -> jax.Array:
-(half_theta * jnp.sin(theta)) / safe_cos_minus_one,
)

V_inv = jnp.array(
V_inv = jnp.stack(
[
[half_theta_over_tan_half_theta, half_theta],
[-half_theta, half_theta_over_tan_half_theta],
half_theta_over_tan_half_theta,
half_theta,
-half_theta,
half_theta_over_tan_half_theta,
],
axis=-1,
).reshape((*theta.shape, 2, 2))

tangent = jnp.concatenate(
[
jnp.einsum("...ij,...j->...i", V_inv, self.translation()),
theta[..., None],
]
)

tangent = jnp.concatenate([V_inv @ self.translation(), theta[None]])
return tangent

@override
Expand All @@ -213,11 +222,17 @@ def adjoint(self: "SE2") -> jax.Array:

@classmethod
@override
def sample_uniform(cls, key: jax.Array) -> "SE2":
def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> "SE2":
key0, key1 = jax.random.split(key)
return SE2.from_rotation_and_translation(
rotation=SO2.sample_uniform(key0),
rotation=SO2.sample_uniform(key0, batch_axes=batch_axes),
translation=jax.random.uniform(
key=key1, shape=(2,), minval=-1.0, maxval=1.0
key=key1,
shape=(
*batch_axes,
2,
),
minval=-1.0,
maxval=1.0,
),
)
Loading

0 comments on commit 8960937

Please sign in to comment.