A possible side-effect from calling jax.Array.__str__ method? #19123
Comments
Ah, it appears that [...] However, this example illustrates a possibly unsafe use of the dlpack protocol: it breaks the immutability property of the jax array, and there is no way of detecting such misuse of the protocol.
Thanks for the question. Indeed, the issue is that the string representation of a JAX array is essentially cached the first time it is requested, so if you somehow change the values in the underlying buffer, the representation will no longer match the internal values. This is not a problem during normal usage, because JAX arrays are immutable. The dlpack mutation issue is a problem, and unfortunately not one with an easy solution beyond removing the ability to export via [...]
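As a rough illustration of the caching effect described in this comment, here is a toy model in plain Python. It is not JAX's implementation: ToyDeviceArray, export_zero_copy, and mutate_then_print are hypothetical names invented for this sketch, and a standard-library memoryview stands in for a zero-copy dlpack export.

```python
from array import array


class ToyDeviceArray:
    """Toy stand-in for an 'immutable' array. NOT JAX's implementation."""

    def __init__(self, values):
        # The buffer that actually owns the data.
        self._buffer = array("d", values)
        # Host-side snapshot, created lazily on first use and then reused.
        self._host_copy = None

    def _host_values(self):
        if self._host_copy is None:
            # Copy the buffer once; later calls return this cached snapshot.
            self._host_copy = self._buffer.tolist()
        return self._host_copy

    def __str__(self):
        return str(self._host_values())

    def export_zero_copy(self):
        # Hypothetical analogue of a dlpack export: a writable view that
        # shares memory with the buffer instead of copying it.
        return memoryview(self._buffer)


def mutate_then_print(print_first):
    a = ToyDeviceArray([1.0, 2.0, 3.0])
    if print_first:
        str(a)  # populates the cached snapshot before the mutation
    view = a.export_zero_copy()
    view[0] = 100.0  # in-place write through the shared view
    return str(a)


print(mutate_then_print(print_first=False))
print(mutate_then_print(print_first=True))
```

In this toy model the first call prints [100.0, 2.0, 3.0] and the second prints [1.0, 2.0, 3.0]: once the snapshot exists, writes through the shared view no longer show up in the string form, which mirrors the difference reported in the issue.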
Yes, this makes sense. The warnings should be in the docs of both [...]
Resolved by #19183
Description
Consider the following two examples:
and
(Leaving aside the problematic in-place change of a jax.Array instance via a torch.Tensor wrapper,) I am curious why these examples give different results depending on whether the str(a) expression is used or not?
What jax/jaxlib version are you using?
jax 0.4.23, jaxlib 0.4.23.dev20231223
Which accelerator(s) are you using?
GPU
Additional system info?
1.26.2 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:43:09) [GCC 12.3.0] uname_result(system='Linux', node='ex', release='5.4.0-153-generic', version='#170-Ubuntu SMP Fri Jun 16 13:43:31 UTC 2023', machine='x86_64')
NVIDIA GPU info