Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lossy type promotion in JAX primitive lowering #1008

Merged
merged 10 commits into from
Aug 29, 2024
Merged

Conversation

dime10
Copy link
Contributor

@dime10 dime10 commented Aug 9, 2024

PennyLaneAI/pennylane#6082 unearthed that in an attempt to be more lenient with user supplied types, Catalyst eagerly converts any type to the required float64 type for gate parameters, including when this results in a loss of data (like converting complex numbers to floats).

This fixes the issue as well as providing some minor code cleanup, and fixing a long-standing issue of potentially undefined variables in the Python code.

In order to help users with the proposed fix in #pennylane/6082 for the decomposition of Exp, the error message for complex gate parameters mentions potential non-unitary operators like the exponential with real exponent.

[sc-71066]

This fixes two issues:
- Silently downcasting from complex arguments to floats.
- Undefined variable for certian inputs (`baseType`).
.. illegal (non-unitary) operators.
@dime10 dime10 added bug Something isn't working frontend Pull requests that update the frontend labels Aug 9, 2024
Copy link

codecov bot commented Aug 9, 2024

Codecov Report

Attention: Patch coverage is 69.23077% with 12 lines in your changes missing coverage. Please review.

Please upload report for BASE (v0.8.0-rc@e2577f2). Learn more about missing BASE report.

Files with missing lines Patch % Lines
frontend/catalyst/jax_primitives.py 69.23% 6 Missing and 6 partials ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##             v0.8.0-rc    #1008   +/-   ##
============================================
  Coverage             ?   97.65%           
============================================
  Files                ?       75           
  Lines                ?    10748           
  Branches             ?     1243           
============================================
  Hits                 ?    10496           
  Misses               ?      203           
  Partials             ?       49           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@erick-xanadu erick-xanadu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is ok.

I wouldn't do it at this level of abstraction though. When tracing gphase and extract you can use jax.numpy.promote_types to find the least upper bound of the type promotion lattice using the element type from the user and the type you want as the two operands. If the least upper bound is different from the type you want to promote to, you know that an error will be raised. If the least upper bound is the same, then you can safely promote with jax.lax.convert_element_type or jax.numpy.astype. This would lead to the operands in gphase and extract to already be tensors of the type you want in their lowering.

The advantage of using jax.numpy.promote_types is that it is more general. Instead of specializing the function to f64, you just have safe_cast(value, type) and that's it.

@dime10
Copy link
Contributor Author

dime10 commented Aug 9, 2024

I think it is ok.

I wouldn't do it at this level of abstraction though. When tracing gphase and extract you can use jax.numpy.promote_types to find the least upper bound of the type promotion lattice using the element type from the user and the type you want as the two operands. If the least upper bound is different from the type you want to promote to, you know that an error will be raised. If the least upper bound is the same, then you can safely promote with jax.lax.convert_element_type or jax.numpy.astype. This would lead to the operands in gphase and extract to already be tensors of the type you want in their lowering.

That's a good point! I think I just applied the fix at this level because we were already doing it here, but I'm happy if someone wants to move it in the future :)

Now there is actually a benefit of doing it here, which is that here we are specifically targeting the type system defined in MLIR, and doing it further up could be premature. For instance, other frontends that target integration by converting to our jaxpr will have to repeat this casting on their side.

@dime10 dime10 added this to the v0.8.0 milestone Aug 22, 2024
@rauletorresc rauletorresc changed the base branch from main to v0.8.0-rc August 26, 2024 15:59
@dime10
Copy link
Contributor Author

dime10 commented Aug 29, 2024

Unfortunately it looks like the CI is getting test crashes that I don't have locally (maybe linux vs macos 🤔):

FAILED frontend/test/pytest/test_jax_primitives.py::TestHelpers::test_float_casting[0]
FAILED frontend/test/pytest/test_jax_primitives.py::TestHelpers::test_float_casting[test_input5]
FAILED frontend/test/pytest/test_jax_primitives.py::TestHelpers::test_scalar_extraction[test_input1]

@erick-xanadu
Copy link
Contributor

I'll take a look first time tomorrow 🤔

@dime10
Copy link
Contributor Author

dime10 commented Aug 29, 2024

I'll take a look first time tomorrow 🤔

Thanks Erick! I've tried debugging this for a while on the AWS machine, but I could not make sense of it. Something related to the IR contexts that I'm generating in the tests, but this seems to be the standard pattern in mlir (and there are instances in jax as well) for unit testing functions using the Python bindings. One avenue I explored was maybe it doesn't like having multiple contexts alive at the same time, but even force garbage collection in between each test didn't seem to work.

For now I've disabled the tests on linux, I don't think the functional code has any issues, only the tests themselves.

@dime10 dime10 merged commit 666577f into v0.8.0-rc Aug 29, 2024
38 of 39 checks passed
@dime10 dime10 deleted the fix-type-promotion branch August 29, 2024 21:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working frontend Pull requests that update the frontend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants