Skip to content

Commit

Permalink
More fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed May 5, 2024
1 parent 847409d commit a043c14
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 34 deletions.
18 changes: 10 additions & 8 deletions jaxlie/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,21 +180,23 @@ def log(self) -> jax.Array:
half_theta_safe = theta_safe / 2.0

V_inv = jnp.where(
use_taylor,
use_taylor[..., None, None],
jnp.eye(3)
- 0.5 * skew_omega
+ jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega) / 12.0,
(
jnp.eye(3)
- 0.5 * skew_omega
+ (
1.0
- theta_safe
* jnp.cos(half_theta_safe)
/ (2.0 * jnp.sin(half_theta_safe))
)
/ theta_squared_safe
* (skew_omega @ skew_omega)
(
1.0
- theta_safe
* jnp.cos(half_theta_safe)
/ (2.0 * jnp.sin(half_theta_safe))
)
/ theta_squared_safe
)[..., None, None]
* jnp.einsum("...ij,...jk->...ik", skew_omega, skew_omega)
),
)
return jnp.concatenate(
Expand Down
2 changes: 1 addition & 1 deletion jaxlie/_so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def log(self) -> jax.Array:
),
)

return atan_factor * self.wxyz[..., 1:]
return atan_factor[..., None] * self.wxyz[..., 1:]

@override
def adjoint(self) -> jax.Array:
Expand Down
22 changes: 0 additions & 22 deletions jaxlie/manifold/_deltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,8 @@

PytreeType = TypeVar("PytreeType")
GroupType = TypeVar("GroupType", bound=MatrixLieGroup)
CallableType = TypeVar("CallableType", bound=Callable)


def _naive_auto_vmap(f: CallableType) -> CallableType:
def inner(*args, **kwargs):
batch_axes = None
for arg in args + tuple(kwargs.values()):
if isinstance(arg, MatrixLieGroup):
if batch_axes is None:
batch_axes = arg.get_batch_axes()
else:
assert arg.get_batch_axes() == batch_axes
assert batch_axes is not None

f_vmapped: Callable = f
for i in range(len(batch_axes)):
f_vmapped = jax.vmap(f_vmapped)
return f_vmapped(*args, **kwargs)

return inner # type: ignore


@_naive_auto_vmap
def _rplus(transform: GroupType, delta: jax.Array) -> GroupType:
assert isinstance(transform, MatrixLieGroup)
assert isinstance(delta, (jax.Array, onp.ndarray))
Expand Down Expand Up @@ -72,7 +51,6 @@ def rplus(
return _tree_utils._map_group_trees(_rplus, jnp.add, transform, delta)


@_naive_auto_vmap
def _rminus(a: GroupType, b: GroupType) -> jax.Array:
assert isinstance(a, MatrixLieGroup) and isinstance(b, MatrixLieGroup)
return (a.inverse() @ b).log()
Expand Down
8 changes: 5 additions & 3 deletions jaxlie/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ def broadcast_leading_axes(inputs: TupleOfBroadcastable) -> TupleOfBroadcastable
from .._base import MatrixLieGroup

array_inputs = [
(x.parameters(), (x.parameters_dim,))
if isinstance(x, MatrixLieGroup)
else (x, x.shape[-1:])
(
(x.parameters(), (x.parameters_dim,))
if isinstance(x, MatrixLieGroup)
else (x, x.shape[-1:])
)
for x in inputs
]
for array, shape_suffix in array_inputs:
Expand Down

0 comments on commit a043c14

Please sign in to comment.