Skip to content

Commit

Permalink
Add batch axes to tests + bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed May 5, 2024
1 parent 8960937 commit 2c7cf89
Show file tree
Hide file tree
Showing 12 changed files with 315 additions and 172 deletions.
48 changes: 35 additions & 13 deletions jaxlie/_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def translation(self) -> jax.Array:

@classmethod
@override
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> "SE2":
def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> "SE2":
return SE2(
unit_complex_xy=jnp.broadcast_to(
jnp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4)
Expand All @@ -99,9 +99,21 @@ def parameters(self) -> jax.Array:
@override
def as_matrix(self) -> jax.Array:
cos, sin, x, y = jnp.moveaxis(self.unit_complex_xy, -1, 0)
return jnp.stack([cos, -sin, x, sin, cos, y, 0.0, 0.0, 1.0], axis=-1).reshape(
(*self.get_batch_axes(), 3, 3)
)
out = jnp.stack(
[
cos,
-sin,
x,
sin,
cos,
y,
jnp.zeros_like(x),
jnp.zeros_like(x),
jnp.ones_like(x),
],
axis=-1,
).reshape((*self.get_batch_axes(), 3, 3))
return out

# Operations.

Expand Down Expand Up @@ -205,24 +217,34 @@ def log(self) -> jax.Array:
[
jnp.einsum("...ij,...j->...i", V_inv, self.translation()),
theta[..., None],
]
],
axis=-1,
)
return tangent

@override
def adjoint(self: "SE2") -> jax.Array:
cos, sin, x, y = self.unit_complex_xy
return jnp.array(
cos, sin, x, y = jnp.moveaxis(self.unit_complex_xy, -1, 0)
return jnp.stack(
[
[cos, -sin, y],
[sin, cos, -x],
[0.0, 0.0, 1.0],
]
)
cos,
-sin,
y,
sin,
cos,
-x,
jnp.zeros_like(x),
jnp.zeros_like(x),
jnp.ones_like(x),
],
axis=-1,
).reshape((*self.get_batch_axes(), 3, 3))

@classmethod
@override
def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> "SE2":
def sample_uniform(
cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = ()
) -> "SE2":
key0, key1 = jax.random.split(key)
return SE2.from_rotation_and_translation(
rotation=SO2.sample_uniform(key0, batch_axes=batch_axes),
Expand Down
24 changes: 15 additions & 9 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ def _skew(omega: hints.Array) -> jax.Array:
"""Returns the skew-symmetric form of a length-3 vector."""

wx, wy, wz = jnp.moveaxis(omega, -1, 0)
zeros = jnp.zeros_like(wx)
return jnp.stack(
[0.0, -wz, wy, wz, 0.0, -wx, -wy, wx, 0.0],
[zeros, -wz, wy, wz, zeros, -wx, -wy, wx, zeros],
axis=-1,
).reshape((*omega.shape[:-1], 3, 3))

Expand Down Expand Up @@ -71,7 +72,7 @@ def translation(self) -> jax.Array:

@classmethod
@override
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SE3:
def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SE3:
return SE3(
wxyz_xyz=jnp.broadcast_to(
jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), (*batch_axes, 7)
Expand All @@ -92,13 +93,16 @@ def from_matrix(cls, matrix: hints.Array) -> SE3:

@override
def as_matrix(self) -> jax.Array:
return (
jnp.eye(4)
.at[..., :3, :3]
out = jnp.zeros((*self.get_batch_axes(), 4, 4))
out = (
out.at[..., :3, :3]
.set(self.rotation().as_matrix())
.at[..., :3, 3]
.set(self.translation())
.at[..., 3, 3]
.set(1.0)
)
return out

@override
def parameters(self) -> jax.Array:
Expand All @@ -117,7 +121,7 @@ def exp(cls, tangent: hints.Array) -> SE3:

rotation = SO3.exp(tangent[..., 3:])

theta_squared = jnp.sum(jnp.square(tangent[3:]), axis=-1)
theta_squared = jnp.sum(jnp.square(tangent[..., 3:]), axis=-1)
use_taylor = theta_squared < get_epsilon(theta_squared.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
Expand Down Expand Up @@ -191,7 +195,7 @@ def log(self) -> jax.Array:
),
)
return jnp.concatenate(
[jnp.einsum("...ij,...j->...i", V_inv, self.translation()), omega]
[jnp.einsum("...ij,...j->...i", V_inv, self.translation()), omega], axis=-1
)

@override
Expand All @@ -207,12 +211,14 @@ def adjoint(self) -> jax.Array:
[jnp.zeros((*self.get_batch_axes(), 3, 3)), R], axis=-1
),
],
axis=-1,
axis=-2,
)

@classmethod
@override
def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> SE3:
def sample_uniform(
cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = ()
) -> SE3:
key0, key1 = jax.random.split(key)
return SE3.from_rotation_and_translation(
rotation=SO3.sample_uniform(key0, batch_axes=batch_axes),
Expand Down
6 changes: 3 additions & 3 deletions jaxlie/_so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def as_radians(self) -> jax.Array:

@classmethod
@override
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2:
def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO2:
return SO2(
unit_complex=jnp.stack(
[jnp.ones(batch_axes), jnp.zeros(batch_axes)], axis=-1
Expand All @@ -60,7 +60,7 @@ def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO2:
@classmethod
@override
def from_matrix(cls, matrix: hints.Array) -> SO2:
assert matrix.shape == (2, 2)
assert matrix.shape[-2:] == (2, 2)
return SO2(unit_complex=jnp.asarray(matrix[..., :, 0]))

# Accessors.
Expand Down Expand Up @@ -130,7 +130,7 @@ def normalize(self) -> SO2:

@classmethod
@override
def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> SO2:
def sample_uniform(cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO2:
out = SO2.from_radians(
jax.random.uniform(
key=key, shape=batch_axes, minval=0.0, maxval=2.0 * jnp.pi
Expand Down
35 changes: 21 additions & 14 deletions jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,17 @@ def from_quaternion_xyzw(xyzw: hints.Array) -> SO3:
constructor.
Args:
xyzw: xyzw quaternion. Shape should be (4,).
xyzw: xyzw quaternion. Shape should be (*, 4).
Returns:
Output.
"""
assert xyzw.shape == (4,)
return SO3(jnp.roll(xyzw, shift=1))
assert xyzw.shape[-1:] == (4,)
return SO3(jnp.roll(xyzw, axis=-1, shift=1))

def as_quaternion_xyzw(self) -> jax.Array:
"""Grab parameters as xyzw quaternion."""
return jnp.roll(self.wxyz, shift=-1)
return jnp.roll(self.wxyz, axis=-1, shift=-1)

def as_rpy_radians(self) -> hints.RollPitchYaw:
"""Computes roll, pitch, and yaw angles. Uses the ZYX mobile robot convention.
Expand Down Expand Up @@ -161,7 +161,7 @@ def compute_yaw_radians(self) -> jax.Array:

@classmethod
@override
def identity(cls, batch_axes: Tuple[int, ...] = ()) -> SO3:
def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO3:
return SO3(
wxyz=jnp.broadcast_to(jnp.array([1.0, 0.0, 0.0, 0.0]), (*batch_axes, 4))
)
Expand Down Expand Up @@ -363,7 +363,8 @@ def exp(cls, tangent: hints.Array) -> SO3:
[
real_factor[..., None],
imaginary_factor[..., None] * tangent,
]
],
axis=-1,
)
)

Expand All @@ -373,7 +374,7 @@ def log(self) -> jax.Array:
# > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247

w = self.wxyz[..., 0]
norm_sq = self.wxyz[..., 1:] @ self.wxyz[..., 1:]
norm_sq = jnp.sum(jnp.square(self.wxyz[..., 1:]), axis=-1)
use_taylor = norm_sq < get_epsilon(norm_sq.dtype)

# Shim to avoid NaNs in jnp.where branches, which cause failures for
Expand All @@ -400,7 +401,7 @@ def log(self) -> jax.Array:
),
)

return atan_factor * self.wxyz[1:]
return atan_factor * self.wxyz[..., 1:]

@override
def adjoint(self) -> jax.Array:
Expand All @@ -417,14 +418,20 @@ def normalize(self) -> SO3:

@classmethod
@override
def sample_uniform(cls, key: jax.Array, batch_axes: Tuple[int, ...] = ()) -> SO3:
def sample_uniform(
cls, key: jax.Array, batch_axes: jdc.Static[Tuple[int, ...]] = ()
) -> SO3:
# Uniformly sample over S^3.
# > Reference: http://planning.cs.uiuc.edu/node198.html
u1, u2, u3 = jax.random.uniform(
key=key,
shape=(3, *batch_axes),
minval=jnp.zeros(3),
maxval=jnp.array([1.0, 2.0 * jnp.pi, 2.0 * jnp.pi]),
u1, u2, u3 = jnp.moveaxis(
jax.random.uniform(
key=key,
shape=(*batch_axes, 3),
minval=jnp.zeros(3),
maxval=jnp.array([1.0, 2.0 * jnp.pi, 2.0 * jnp.pi]),
),
-1,
0,
)
a = jnp.sqrt(1.0 - u1)
b = jnp.sqrt(u1)
Expand Down
65 changes: 42 additions & 23 deletions jaxlie/manifold/_deltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,16 @@ def _rplus(transform: GroupType, delta: jax.Array) -> GroupType:
def rplus(
transform: GroupType,
delta: hints.Array,
) -> GroupType: ...
) -> GroupType:
...


@overload
def rplus(
transform: PytreeType,
delta: _tree_utils.TangentPytree,
) -> PytreeType: ...
) -> PytreeType:
...


# Using our typevars in the overloaded signature will cause errors.
Expand All @@ -79,11 +81,13 @@ def _rminus(a: GroupType, b: GroupType) -> jax.Array:


@overload
def rminus(a: GroupType, b: GroupType) -> jax.Array: ...
def rminus(a: GroupType, b: GroupType) -> jax.Array:
...


@overload
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ...
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree:
...


# Using our typevars in the overloaded signature will cause errors.
Expand Down Expand Up @@ -129,23 +133,23 @@ def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
# Jacobian col indices: theta

transform_so2 = cast(SO2, transform)
J = jnp.zeros((2, 1))
J = jnp.zeros((*transform.get_batch_axes(), 2, 1))

cos, sin = transform_so2.unit_complex
J = J.at[0].set(-sin).at[1].set(cos)
cos, sin = jnp.moveaxis(transform_so2.unit_complex, -1, 0)
J = J.at[..., 0].set(-sin).at[..., 1].set(cos)

elif type(transform) is SE2:
# Jacobian row indices: cos, sin, x, y
# Jacobian col indices: vx, vy, omega

transform_se2 = cast(SE2, transform)
J = jnp.zeros((4, 3))
J = jnp.zeros((*transform.get_batch_axes(), 4, 3))

# Translation terms.
J = J.at[2:, :2].set(transform_se2.rotation().as_matrix())
J = J.at[..., 2:, :2].set(transform_se2.rotation().as_matrix())

# Rotation terms.
J = J.at[:2, 2:3].set(
J = J.at[..., :2, 2:3].set(
rplus_jacobian_parameters_wrt_delta(transform_se2.rotation())
)

Expand All @@ -155,18 +159,29 @@ def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:

transform_so3 = cast(SO3, transform)

w, x, y, z = transform_so3.wxyz
_unused_neg_w, neg_x, neg_y, neg_z = -transform_so3.wxyz
w, x, y, z = jnp.moveaxis(transform_so3.wxyz, -1, 0)
neg_x = -x
neg_y = -y
neg_z = -z

J = (
jnp.array(
jnp.stack(
[
[neg_x, neg_y, neg_z],
[w, neg_z, y],
[z, w, neg_x],
[neg_y, x, w],
]
)
neg_x,
neg_y,
neg_z,
w,
neg_z,
y,
z,
w,
neg_x,
neg_y,
x,
w,
],
axis=-1,
).reshape((*transform.get_batch_axes(), 4, 3))
/ 2.0
)

Expand All @@ -175,18 +190,22 @@ def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
# Jacobian col indices: vx, vy, vz, omega x, omega y, omega z

transform_se3 = cast(SE3, transform)
J = jnp.zeros((7, 6))
J = jnp.zeros((*transform.get_batch_axes(), 7, 6))

# Translation terms.
J = J.at[4:, :3].set(transform_se3.rotation().as_matrix())
J = J.at[..., 4:, :3].set(transform_se3.rotation().as_matrix())

# Rotation terms.
J = J.at[:4, 3:6].set(
J = J.at[..., :4, 3:6].set(
rplus_jacobian_parameters_wrt_delta(transform_se3.rotation())
)

else:
assert False, f"Unsupported type: {type(transform)}"

assert J.shape == (transform.parameters_dim, transform.tangent_dim)
assert J.shape == (
*transform.get_batch_axes(),
transform.parameters_dim,
transform.tangent_dim,
)
return J
Loading

0 comments on commit 2c7cf89

Please sign in to comment.