Backend Enhancements for GPU/TPU Support #144

Merged

Conversation

Sampreet
Contributor

@Sampreet Sampreet commented Sep 29, 2024

Summary

This PR aims to include experimental features for GPU/TPU support using a dynamical numerical backend and addresses issues #128, #140, and #142 following the detailed discussion with @piperfw and @gefux over #142 and email.

Note: The JAX-backend tests are still very slow since no explicit JAX-related imports or primitives are utilized for speedup.

Pull Request Check List

  • The contribution has been discussed and agreed on in the Issue section.
  • Code contributions do their best to follow the Zen of Python.
  • The automated tests all pass:
    • tox -e py36 (to run pytest) the code tests.
    • tox -e style (to run pylint) the code style tests.
  • Added tests for changed/added code.
  • The documentation has been updated:
    • docstring for all new functions/methods/classes/modules.
    • consistent style of all docstrings.
    • for new modules: /docs/pages/modules.rst has been updated.
    • for new functionality: /docs/pages/api.rst has been updated.
    • for new functionality: tutorials and examples have been updated.

Notable Changes

  • Support for JAX backend with dynamical switching via oqupy.config and oqupy.backends.numerical_backend.
  • A generalized and faster implementation of oqupy.util.create_delta with support for vectorization.
  • Safer tensor shape handling with reshape and vectorization signatures.
  • Unit tests with the dynamical backend.
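
To illustrate the idea of dynamical backend switching listed above, here is a minimal sketch. It is not the code from this PR: the class `NumericalBackend`, its `update` method, and the helper `outer_sum` are hypothetical names chosen for illustration, and the actual interface of oqupy.backends.numerical_backend may differ. The sketch assumes that numerical code looks up the active array module at call time, so that switching from NumPy to `jax.numpy` requires no changes at the call sites.

```python
import importlib

import numpy as np


class NumericalBackend:
    """Hypothetical holder for the active array module (illustration only)."""

    def __init__(self):
        self.name = "numpy"
        self.np = np

    def update(self, name):
        """Switch the array module; fall back to NumPy if JAX is not installed."""
        if name == "jax":
            try:
                self.np = importlib.import_module("jax.numpy")
                self.name = "jax"
                return
            except ImportError:
                pass  # JAX unavailable: keep using NumPy below
        self.np = np
        self.name = "numpy"


# A single shared instance that numerical routines consult at call time.
backend = NumericalBackend()


def outer_sum(a, b):
    """Return the matrix a[i] + b[j] using whichever backend is active."""
    xp = backend.np
    column = xp.reshape(xp.asarray(a), (-1, 1))
    return column + xp.asarray(b)


result = outer_sum([1, 2], [10, 20])
```

Because `outer_sum` reads `backend.np` on every call, calling `backend.update("jax")` changes the array type it returns without touching the function itself.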

Changes Awaiting Decision

  • JAX-free or explicit-JAX backends (with environment variables).
  • Updating/adding documentation for experimental features.
  • Separate development branch for JAX-related PRs.

Thank you.

@piperfw piperfw added the enhancement New feature or request label Sep 29, 2024
@piperfw
Collaborator

piperfw commented Oct 14, 2024

Really good progress @Sampreet. To summarise my last post in the issue thread regarding the changes awaiting decision (subject to me misunderstanding anything there):

  • I think a 'JAX-free' backend using an environment variable could work, @gefux please may you comment on this when possible
  • Adding your points to CONTRIBUTING.md can be done now. A separate docs page is also a good idea; if you have time you can add one under 'DEVELOPMENT', or I can at a later date. It is up to you how brief to make it and whether to include extra info about jax etc. I will read and suggest changes if you do
  • I'm happy with a separate jax branch and think it is in line with how we have used branches in the past for large developments

@Sampreet
Contributor Author

Thanks a lot, @piperfw! I have added a "DEVELOPMENT.md" page which contains a section on the current backend development as discussed. In addition to this, I have added a small comment on "CONTRIBUTING.md" so that contributions to features under active development (which are not on the "main" branch) can be maintained via "DEVELOPMENT.md". You may take a look and suggest changes on the same.

I have also added a comment on #142 mentioning a smoother implementation of the experimental features for backend switching. Once an approach is finalized, I shall update the "api.rst" file ("modules.rst" doesn't have any backend-related section).

Kindly let me know if a tutorial/example would also be required for the current implementation.

@piperfw
Collaborator

piperfw commented Nov 3, 2024

DEVELOPMENT.md (and CONTRIBUTING.md) looks good @Sampreet, thanks! I tweaked the second and final bullet points slightly, and made some other edits; let me know if that's all OK.

I'm going to create a new "jax" branch now (if non-jax development is required we can change it), and I think we can complete the pull request into that branch once we've finalised the dynamical switching and you've had a chance to update api.rst (agree modules.rst can be left). I suggest calling it enable_gpu_features() to be more explicit.

I've created a new .rst page on the docs under the Development section that mirrors DEVELOPMENT.md. We can put this on the main branch, which should increase visibility both for those looking to see whether OQuPy has GPU/TPU support and for those intending to contribute other features (directly related to GPU/TPU support or otherwise).

Edit: I changed the branch name to dev/jax. Regarding a tutorial/example: no tutorial Jupyter notebook is needed, but we can add a minimal working example under examples/simple_dynamics_with_jax.py or similar that basically copies simple_dynamics.py with the jax backend enabled (we may want to wait until we have finalised that).

@piperfw piperfw self-assigned this Nov 3, 2024
@piperfw piperfw changed the base branch from main to dev/jax November 3, 2024 22:51
@piperfw
Collaborator

piperfw commented Nov 5, 2024

I realise enable_experimental_features or perhaps enable_jax_features may be a more appropriate name after all, since the jax backend may offer enhancements (or at least performance differences) for CPU usage as well.

@Sampreet Sampreet marked this pull request as ready for review November 6, 2024 15:01
@Sampreet
Contributor Author

Sampreet commented Nov 6, 2024

Thanks a lot for the feedback, @piperfw. The changes you have made also look good. I have implemented the suggestions (with a few additional changes to docs and DEVELOPMENT.md), and have marked the PR ready for review.

@piperfw piperfw merged commit 6ccbd2e into tempoCollaboration:dev/jax Nov 10, 2024
@piperfw
Collaborator

piperfw commented Nov 10, 2024

That's merged @Sampreet, congratulations!

I checked over your last changes and the use of the environment variable to set the backend. All worked nicely, and the example is helpful. I added a small note on using this variable in DEVELOPMENT.md/gpu_features.rst.
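
For readers unfamiliar with selecting a backend through an environment variable, the approach mentioned above can be sketched roughly as follows. This is an illustration under stated assumptions, not the merged implementation: the variable name `EXAMPLE_NUMERICAL_BACKEND` and the function `select_backend` are hypothetical placeholders, and the actual variable name is the one documented in DEVELOPMENT.md/gpu_features.rst.

```python
import os

# Hypothetical variable name, for illustration only.
ENV_VAR = "EXAMPLE_NUMERICAL_BACKEND"
SUPPORTED = ("numpy", "jax")


def select_backend(environ=None):
    """Return the backend name requested by the environment (default: numpy)."""
    environ = os.environ if environ is None else environ
    value = environ.get(ENV_VAR, "numpy").strip().lower()
    if value not in SUPPORTED:
        raise ValueError(f"Unsupported backend {value!r}; expected one of {SUPPORTED}")
    return value


# With the variable unset, the default NumPy backend is selected.
default_choice = select_backend({})
# Setting the variable (here simulated with a plain dict) selects JAX.
jax_choice = select_backend({ENV_VAR: "JAX"})
```

Reading the variable once at import time keeps the default NumPy path untouched for users who never set it.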

I think we should add the docs page to the main branch (and then update readthedocs) for visibility of the support. Would you like to do this as a PR for authorship, or shall I go ahead? (Just copy gpu_features.rst to a new fork of main.)

I'll also discuss with Gerald when he's next available whether we want to merge this into main in the near future. Since you have written it in a flexible way that does not interfere with the numpy backend, this seems possible.

Really good work and thank you; you have been very efficient at making changes and responding to feedback which makes a maintainer's role so much easier.

@piperfw
Collaborator

piperfw commented Nov 10, 2024

[Also DEVELOPMENT.md can go to the main branch]

@Sampreet
Contributor Author

Hi @piperfw, thanks a lot for the final changes and your positive feedback on the work. Likewise, it has been a pleasure implementing the changes through the wonderful discussions with you and Gerald.

Sure, I would be happy to make a PR adding DEVELOPMENT.md and gpu_features.rst (will probably need to update the corresponding lines in CONTRIBUTING.md and index.rst too) to the main branch. Thanks again!
