Fix lossy type promotion in JAX primitive lowering #1008
Conversation
This fixes two issues:
- Silently downcasting from complex arguments to floats.
- Undefined variable for certain inputs (`baseType`).
Codecov Report

@@ Coverage Diff @@
##           v0.8.0-rc   #1008   +/-   ##
============================================
  Coverage           ?   97.65%
============================================
  Files              ?       75
  Lines              ?    10748
  Branches           ?     1243
============================================
  Hits               ?    10496
  Misses             ?      203
  Partials           ?       49

☔ View full report in Codecov by Sentry.
I think it is ok.

I wouldn't do it at this level of abstraction though. When tracing gphase and extract, you can use `jax.numpy.promote_types` to find the least upper bound on the type promotion lattice, using the element type from the user and the type you want as the two operands. If the least upper bound is different from the type you want to promote to, you know that an error will be raised. If the least upper bound is the same, then you can safely promote with `jax.lax.convert_element_type` or `jax.numpy.astype`. This would lead to the operands of gphase and extract already being tensors of the desired type in their lowering.

The advantage of using `jax.numpy.promote_types` is that it is more general: instead of specializing the function to f64, you just have `safe_cast(value, type)` and that's it.
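The suggestion above can be sketched as follows. This is a minimal illustration, assuming a hypothetical `safe_cast` helper; neither the name nor the function exists in Catalyst or JAX:

```python
# Illustrative sketch of the reviewer's suggestion; `safe_cast` is a
# hypothetical name, not an existing Catalyst or JAX function.
import jax.numpy as jnp
import numpy as np

def safe_cast(value, target_dtype):
    """Cast `value` to `target_dtype`, raising if the cast would lose data."""
    arr = jnp.asarray(value)
    # Least upper bound of the two dtypes on JAX's type promotion lattice.
    lub = jnp.promote_types(arr.dtype, target_dtype)
    if lub != np.dtype(target_dtype):
        # The promotion lattice says target_dtype cannot represent arr's
        # values exactly (e.g. complex64 -> float32), so fail loudly.
        raise TypeError(
            f"cannot safely cast {arr.dtype} to {np.dtype(target_dtype)}"
        )
    return arr.astype(target_dtype)
```

For example, `safe_cast(3, np.float32)` succeeds (int promotes to float on the lattice), while casting a complex value to `float32` raises, because their least upper bound is a complex type rather than the requested float type.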
That's a good point! I think I just applied the fix at this level because we were already doing it here, but I'm happy if someone wants to move it in the future :) There is actually a benefit of doing it here, though: here we are specifically targeting the type system defined in MLIR, and doing the conversion further up could be premature. For instance, other frontends that target integration by converting to our jaxpr would have to repeat this casting on their side.
Force-pushed from eea5fc3 to 9434c02 (Compare)
Force-pushed from 9434c02 to 21397a0 (Compare)
Unfortunately it looks like the CI is getting test crashes that I can't reproduce locally (maybe Linux vs. macOS 🤔).
I'll take a look first thing tomorrow 🤔
Thanks Erick! I've tried debugging this for a while on the AWS machine, but I could not make sense of it. Something related to the IR contexts that I'm generating in the tests, but this seems to be the standard pattern in MLIR (and there are instances in JAX as well) for unit testing functions via the Python bindings. One avenue I explored was that maybe it doesn't like having multiple contexts alive at the same time, but even forcing garbage collection between each test didn't seem to work. For now I've disabled the tests on Linux; I don't think the functional code has any issues, only the tests themselves.
PennyLaneAI/pennylane#6082 unearthed that, in an attempt to be more lenient with user-supplied types, Catalyst eagerly converts any type to the required float64 type for gate parameters, including when this results in a loss of data (such as converting complex numbers to floats).
This fixes the issue, provides some minor code cleanup, and resolves a long-standing issue of potentially undefined variables in the Python code.
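For illustration, the silent data loss described above looks like this with plain NumPy casting semantics (a standalone sketch of the symptom, not Catalyst's actual lowering code):

```python
import warnings
import numpy as np

z = np.array(1.0 + 2.0j, dtype=np.complex128)

# An unchecked cast from complex128 to float64 silently drops the
# imaginary part; NumPy only emits a ComplexWarning (suppressed here).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    lossy = z.astype(np.float64)

print(lossy)  # 1.0 — the imaginary component of 1+2j is gone
```

Raising an error instead of performing this cast is exactly what the fix enforces for complex-valued gate parameters.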
In order to help users with the proposed fix in PennyLaneAI/pennylane#6082 for the decomposition of `Exp`, the error message for complex gate parameters mentions potentially non-unitary operators, like the exponential with a real exponent.

[sc-71066]