tree-math makes it easy to implement numerical algorithms that work on
JAX pytrees, such as
iterative methods for optimization and equation solving. It does so by providing
a wrapper class tree_math.Vector
that defines array operations such as
infix arithmetic and dot-products on pytrees as if they were vectors.
In a library like SciPy, numerical algorithms are typically written to handle
fixed-rank arrays, e.g., scipy.integrate.solve_ivp
requires inputs of shape (n,)
. This is convenient for implementors of
numerical methods, but not for users, because 1d arrays are typically not the
best way to keep track of state for non-trivial functions (e.g., neural networks
or PDE solvers).
tree-math provides an alternative to flattening and unflattening these more complex data structures ("pytrees") for use in numerical algorithms. Instead, the numerical algorithm itself can be written in way to handle arbitrary collections of arrays stored in pytrees. This avoids unnecessary memory copies, and gives the user more control over the memory layouts used in computation. In practice, this can often makes a big difference for computational efficiency as well, which is why support for flexible data structures is so prevalent inside libraries that use JAX.
tree-math is implemented in pure Python, and only depends upon JAX.
You can install it from PyPI: pip install tree-math
.
tree-math is simple to use. Just pass arbitrary pytree objects into
tree_math.Vector
to create an a object that arithmetic as if all leaves of
the pytree were flattened and concatenated together:
>>> import tree_math as tm
>>> import jax.numpy as jnp
>>> v = tm.Vector({'x': 1, 'y': jnp.arange(2, 4)})
>>> v
tree_math.Vector({'x': 1, 'y': DeviceArray([2, 3], dtype=int32)})
>>> v + 1
tree_math.Vector({'x': 2, 'y': DeviceArray([3, 4], dtype=int32)})
>>> v.sum()
DeviceArray(6, dtype=int32)
You can also find a few functions defined on vectors in tree_math.numpy
, which
implements a very restricted subset of jax.numpy
. If you're interested in more
functionality, please open an issue to discuss before sending a pull request.
(In the long term, this separate module might disappear if we can support
Vector
objects directly inside jax.numpy
.)
Vector objects are pytrees themselves, which means the are compatible with JAX
transformations like jit
, vmap
and grad
, and control flow like
while_loop
and cond
.
When you're done manipulating vectors, you can pull out the underlying pytrees
from the .tree
property:
>>> v.tree
{'x': 1, 'y': DeviceArray([2, 3], dtype=int32)}
As an alternative to manipulating Vector
objects directly, you can also use
the functional transformations wrap
and unwrap
(see the "Writing an
algorithm" below). Or you can create your own Vector
-like objects from a pytree with
VectorMixin
or tree_math.struct
(see "Custom vector classes" below).
One important difference between tree_math
and jax.numpy
is that dot
products in tree_math
default to full precision on all platforms, rather
than defaulting to bfloat16 precision on TPUs. This is useful for writing most
numerical algorithms, and will likely be JAX's default behavior
in the future.
It would be nice to have a Matrix
class to make it possible to use tree-math
for numerical algorithms such as
L-BFGS which use matrices
to represent stacks of vectors. If you're interesting in contributing this
feature, please comment on this GitHub issue.
Here is how we could write the preconditioned conjugate gradient method. Notice how similar the implementation is to the pseudocode from Wikipedia, unlike the implementation in JAX. Both versions support arbitrary pytrees as input:
import functools
from jax import lax
import tree_math as tm
import tree_math.numpy as tnp
@functools.partial(tm.wrap, vector_argnames=['b', 'x0'])
def cg(A, b, x0, M=lambda x: x, maxiter=5, tol=1e-5, atol=0.0):
"""jax.scipy.sparse.linalg.cg, written with tree_math."""
A = tm.unwrap(A)
M = tm.unwrap(M)
atol2 = tnp.maximum(tol**2 * (b @ b), atol**2)
def cond_fun(value):
x, r, gamma, p, k = value
return (r @ r > atol2) & (k < maxiter)
def body_fun(value):
x, r, gamma, p, k = value
Ap = A(p)
alpha = gamma / (p.conj() @ Ap)
x_ = x + alpha * p
r_ = r - alpha * Ap
z_ = M(r_)
gamma_ = r_.conj() @ z_
beta_ = gamma_ / gamma
p_ = z_ + beta_ * p
return x_, r_, gamma_, p_, k + 1
r0 = b - A(x0)
p0 = z0 = M(r0)
gamma0 = r0 @ z0
initial_value = (x0, r0, gamma0, p0, 0)
x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)
return x_final
You can also make your own classes directly support math like Vector
. To do
so, either inherit from tree_math.VectorMixin
on your pytree class, or use
tree_math.struct
(similar to
flax.struct
)
to create pytree and tree math supporting
dataclass:
import jax
import tree_math
@tree_math.struct
class Point:
x: float | jax.Array
y: float | jax.Array
a = Point(0.0, 1.0)
b = Point(2.0, 3.0)
a + 3 * b # Point(6.0, 10.0)
jax.grad(lambda x, y: x @ y)(a, b) # Point(2.0, 3.0)