Skip to content

Commit

Permalink
support has_aux for manifold gradient functions (#17)
Browse files Browse the repository at this point in the history
* support has_aux for manifold gradient functions

* Formatting, mypy

---------

Co-authored-by: Brent Yi <[email protected]>
  • Loading branch information
alvinsunyixiao and brentyi authored May 5, 2024
1 parent d284418 commit 6cf00ce
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 26 deletions.
6 changes: 2 additions & 4 deletions jaxlie/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,10 @@ def __init__(
# Shared implementations.

@overload
def __matmul__(self: GroupType, other: GroupType) -> GroupType:
...
def __matmul__(self: GroupType, other: GroupType) -> GroupType: ...

@overload
def __matmul__(self, other: hints.Array) -> jax.Array:
...
def __matmul__(self, other: hints.Array) -> jax.Array: ...

def __matmul__(
self: GroupType, other: Union[GroupType, hints.Array]
Expand Down
25 changes: 14 additions & 11 deletions jaxlie/manifold/_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def grad(
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, _tree_utils.TangentPytree]:
...
) -> Callable[P, _tree_utils.TangentPytree]: ...


@overload
Expand All @@ -49,8 +48,7 @@ def grad(
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, Tuple[_tree_utils.TangentPytree, ...]]:
...
) -> Callable[P, Tuple[_tree_utils.TangentPytree, ...]]: ...


def grad(
Expand All @@ -72,7 +70,15 @@ def grad(
allow_int=allow_int,
reduce_axes=reduce_axes,
)
return lambda *args, **kwargs: compute_value_and_grad(*args, **kwargs)[1] # type: ignore

def grad_fun(*args, **kwargs):
ret = compute_value_and_grad(*args, **kwargs)
if has_aux:
return ret[1], ret[0][1]
else:
return ret[1]

return grad_fun


@overload
Expand All @@ -83,8 +89,7 @@ def value_and_grad(
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, Tuple[Any, _tree_utils.TangentPytree]]:
...
) -> Callable[P, Tuple[Any, _tree_utils.TangentPytree]]: ...


@overload
Expand All @@ -95,8 +100,7 @@ def value_and_grad(
holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, Tuple[Any, Tuple[_tree_utils.TangentPytree, ...]]]:
...
) -> Callable[P, Tuple[Any, Tuple[_tree_utils.TangentPytree, ...]]]: ...


def value_and_grad(
Expand All @@ -121,14 +125,13 @@ def tangent_fun(*tangent_args, **tangent_kwargs):
tangent_args = map(zero_tangents, args)
tangent_kwargs = {k: zero_tangents(v) for k, v in kwargs.items()}

value, grad = jax.value_and_grad(
return jax.value_and_grad(
fun=tangent_fun,
argnums=argnums,
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)(*tangent_args, **tangent_kwargs)
return value, grad

return wrapped_grad # type: ignore
12 changes: 4 additions & 8 deletions jaxlie/manifold/_deltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,14 @@ def _rplus(transform: GroupType, delta: jax.Array) -> GroupType:
def rplus(
transform: GroupType,
delta: hints.Array,
) -> GroupType:
...
) -> GroupType: ...


@overload
def rplus(
transform: PytreeType,
delta: _tree_utils.TangentPytree,
) -> PytreeType:
...
) -> PytreeType: ...


# Using our typevars in the overloaded signature will cause errors.
Expand All @@ -81,13 +79,11 @@ def _rminus(a: GroupType, b: GroupType) -> jax.Array:


@overload
def rminus(a: GroupType, b: GroupType) -> jax.Array:
...
def rminus(a: GroupType, b: GroupType) -> jax.Array: ...


@overload
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree:
...
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ...


# Using our typevars in the overloaded signature will cause errors.
Expand Down
1 change: 1 addition & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests with explicit examples."""

import numpy as onp
import pytest
from hypothesis import given, settings
Expand Down
1 change: 1 addition & 0 deletions tests/test_manifold.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test manifold helpers."""

from typing import Type

import jax
Expand Down
5 changes: 2 additions & 3 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from hypothesis import given, settings
from hypothesis import strategies as st
from jax import numpy as jnp
from jax.config import config

import jaxlie

# Run all tests with double-precision.
config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)

T = TypeVar("T", bound=jaxlie.MatrixLieGroup)

Expand Down Expand Up @@ -101,7 +100,7 @@ def assert_arrays_close(


def jacnumerical(
f: Callable[[jaxlie.hints.Array], jax.Array]
f: Callable[[jaxlie.hints.Array], jax.Array],
) -> Callable[[jaxlie.hints.Array], jax.Array]:
"""Decorator for computing numerical Jacobians of vector->vector functions."""

Expand Down

0 comments on commit 6cf00ce

Please sign in to comment.