
Add examples for custom derivative definitions in pytorch #1

Open
elcorto opened this issue Jun 9, 2021 · 3 comments
Labels
help wanted Extra attention is needed

Comments

@elcorto
Member

elcorto commented Jun 9, 2021

No description provided.

@elcorto elcorto added the help wanted Extra attention is needed label Jun 9, 2021
@elcorto
Member Author

elcorto commented Jun 22, 2021

@elcorto
Member Author

elcorto commented Jan 16, 2024

The example code should follow test_jax.py and, where possible, implement the same operations, so that the libraries can be compared easily.

@elcorto
Member Author

elcorto commented Jul 12, 2024

As of torch 2.0, there is torch.func (formerly functorch), which implements a subset of the jax API (e.g. torch.func.grad behaves like jax.grad). The torch.func transforms also support custom derivatives defined with torch.autograd.Function, although this seems more complex to set up (e.g. tensors needed by the backward pass have to be saved explicitly with ctx.save_for_backward).
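A minimal sketch of such a custom derivative, assuming torch >= 2.0. It borrows the log1pexp example from the JAX tutorial on custom derivative rules, where plain autodiff returns nan at x = 100 because exp(x) overflows. Whether test_jax.py uses this function is not known here, and the names log1pexp_naive, Log1pExp and log1pexp_custom are hypothetical. Note that backward() in torch.autograd.Function defines a VJP rule, so it corresponds more closely to jax.custom_vjp than to the custom_jvp used in that tutorial.

```python
import torch
from torch.func import grad

# Hypothetical example names, chosen for illustration (not from test_jax.py).


def log1pexp_naive(x):
    # log(1 + exp(x)), differentiated by plain autograd
    return torch.log(1.0 + torch.exp(x))


class Log1pExp(torch.autograd.Function):
    # torch.func-compatible style (torch >= 2.0): forward() takes no ctx;
    # saving tensors for the backward pass happens in setup_context().

    @staticmethod
    def forward(x):
        return torch.log(1.0 + torch.exp(x))

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Custom VJP: d/dx log(1 + exp(x)) = 1 - 1 / (1 + exp(x)),
        # which stays finite when exp(x) overflows to inf.
        (x,) = ctx.saved_tensors
        return grad_output * (1.0 - 1.0 / (1.0 + torch.exp(x)))


def log1pexp_custom(x):
    return Log1pExp.apply(x)


x = torch.tensor(100.0)  # float32: exp(100.0) overflows to inf

# torch.func.grad transforms a scalar-valued function, like jax.grad
print(grad(log1pexp_naive)(x))   # nan: 1/inf * inf = 0 * inf
print(grad(log1pexp_custom)(x))  # 1.0 from the custom backward()
```

The same class also works with the classic API (Log1pExp.apply(x) on a tensor created with requires_grad=True, followed by .backward()). Supporting torch.func.vmap as well would additionally need a vmap rule, e.g. setting generate_vmap_rule = True on the class.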
