Skip to content

Commit

Permalink
Update vmap example with broadcasting notes
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed May 6, 2024
1 parent 24edd26 commit 6a65fb7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ Where each group supports:
<code>jaxlie.<strong>manifold.\*</strong></code>).
- Compatibility with standard JAX function transformations. (see
[./examples/vmap_example.py](./examples/vmap_example.py))
- Broadcasting.
- (Un)flattening as pytree nodes.
- Serialization using [flax](https://github.com/google/flax).

Expand Down
24 changes: 20 additions & 4 deletions examples/vmap_example.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Examples of vectorizing transformations via vmap.
"""jaxlie implements numpy-style broadcasting for all operations. For more
explicit vectorization, we can also use vmap function transformations.
Omitted for brevity here, but note that in practice we usually want to JIT after
vmapping!"""
Omitted for brevity here, but in practice we usually want to JIT after
vmapping."""

import jax
import numpy as onp

from jaxlie import SO3

N = 100
Expand Down Expand Up @@ -60,6 +60,10 @@
p_transformed_stacked = jax.vmap(lambda p: SO3.apply(R_single, p))(p_stacked)
assert p_transformed_stacked.shape == (N, 3)

# We can also just rely on broadcasting.
p_transformed_stacked = R_single @ p_stacked
assert p_transformed_stacked.shape == (N, 3)

#############################
# (4) Applying N transformations to N points.
#############################
Expand All @@ -69,13 +73,21 @@
p_transformed_stacked = jax.vmap(SO3.apply)(R_stacked, p_stacked)
assert p_transformed_stacked.shape == (N, 3)

# We can also just rely on broadcasting.
p_transformed_stacked = R_stacked @ p_stacked
assert p_transformed_stacked.shape == (N, 3)

#############################
# (5) Applying N transformations to 1 point.
#############################

p_transformed_stacked = jax.vmap(lambda R: SO3.apply(R, p_single))(R_stacked)
assert p_transformed_stacked.shape == (N, 3)

# We can also just rely on broadcasting.
p_transformed_stacked = R_stacked @ p_single[None, :]
assert p_transformed_stacked.shape == (N, 3)

#############################
# (6) Multiplying transformations.
#############################
Expand All @@ -95,3 +107,7 @@

# Or N x 1 multiplication:
assert (jax.vmap(lambda R: SO3.multiply(R, R_single))(R_stacked)).wxyz.shape == (N, 4)

# Again, broadcasting also works.
assert (R_stacked @ R_stacked).wxyz.shape == (N, 4)
assert (R_stacked @ SO3(R_single.wxyz[None, :])).wxyz.shape == (N, 4)

0 comments on commit 6a65fb7

Please sign in to comment.