Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JVP/VJP type checking in Catalyst frontend #1031

Merged
merged 6 commits into from
Aug 19, 2024

Conversation

joeycarter
Copy link
Contributor

Context: Improve type checking and error messaging in the Catalyst JVP and VJP functions.

Description of the Change: Adds type checking to ensure the following for JVP:

  • Number of tangent operands and number of differentiable parameters are equal.
  • Data types of function params and tangents arguments are equal.
  • Function params and tangent arguments are the same shape.

and for VJP:

  • Number of cotangent operands and number of function output parameters are equal.
  • Data types of function output params and cotangents arguments are equal.
  • Function output params and cotangent arguments are the same shape.

Note that the equivalent type checking is also performed at the MLIR level.

Benefits: Checking functional parameters earlier in the frontend improves error messaging and usability rather than falling back to the MLIR error messages.

Possible Drawbacks: There's a small risk that stricter type checking in the frontend might break backward compatibility if users are calling these functions in unconventional or non-standard ways that are not captured by our unit/integration tests.

Related GitHub Issues:

Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job! I would just split the single test into multiple ones

frontend/catalyst/api_extensions/differentiation.py Outdated Show resolved Hide resolved
frontend/catalyst/api_extensions/differentiation.py Outdated Show resolved Hide resolved
frontend/catalyst/api_extensions/differentiation.py Outdated Show resolved Hide resolved
frontend/test/pytest/test_jvpvjp.py Show resolved Hide resolved
frontend/test/pytest/test_jvpvjp.py Outdated Show resolved Hide resolved
Copy link

codecov bot commented Aug 16, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.88%. Comparing base (5b589ec) to head (e96afba).
Report is 207 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1031   +/-   ##
=======================================
  Coverage   97.88%   97.88%           
=======================================
  Files          75       75           
  Lines       10685    10702   +17     
  Branches     1226     1235    +9     
=======================================
+ Hits        10459    10476   +17     
  Misses        177      177           
  Partials       49       49           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@rmoyard
Copy link
Contributor

rmoyard commented Aug 16, 2024

@joeycarter Also for the the code factor issue, you can apply black -l 100 on the test file

@joeycarter joeycarter force-pushed the joeycarter/jvp-vjp-frontend-type-checking branch from 3a1be58 to 739dd4a Compare August 19, 2024 15:28
Copy link
Contributor

@rmoyard rmoyard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, great job 💯

@joeycarter joeycarter merged commit 730de7d into main Aug 19, 2024
42 of 43 checks passed
@joeycarter joeycarter deleted the joeycarter/jvp-vjp-frontend-type-checking branch August 19, 2024 17:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants