Skip to content

Commit

Permalink
support has_aux for manifold gradient functions
Browse files Browse the repository at this point in the history
  • Loading branch information
alvinsunyixiao committed Apr 23, 2024
1 parent d284418 commit f68e681
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions jaxlie/manifold/_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,15 @@ def grad(
allow_int=allow_int,
reduce_axes=reduce_axes,
)
return lambda *args, **kwargs: compute_value_and_grad(*args, **kwargs)[1] # type: ignore

def grad_fun(*args, **kwargs):
ret = compute_value_and_grad(*args, **kwargs)
if has_aux:
return ret[1], ret[0][1]
else:
return ret[1]

return grad_fun


@overload
Expand Down Expand Up @@ -121,14 +129,13 @@ def tangent_fun(*tangent_args, **tangent_kwargs):
tangent_args = map(zero_tangents, args)
tangent_kwargs = {k: zero_tangents(v) for k, v in kwargs.items()}

value, grad = jax.value_and_grad(
return jax.value_and_grad(
fun=tangent_fun,
argnums=argnums,
has_aux=has_aux,
holomorphic=holomorphic,
allow_int=allow_int,
reduce_axes=reduce_axes,
)(*tangent_args, **tangent_kwargs)
return value, grad

return wrapped_grad # type: ignore

0 comments on commit f68e681

Please sign in to comment.