
A possible side-effect from calling jax.Array.__str__ method? #19123

Closed
pearu opened this issue Dec 25, 2023 · 4 comments
Labels
bug Something isn't working

Comments

Collaborator

pearu commented Dec 25, 2023

Description

Consider the following two examples:

>>> import jax
>>> import torch
>>> a = jax.numpy.array([[0, 1], [2, 3]])
>>> t = torch.from_dlpack(a)
>>> t[0, 0] = 99
>>> a
Array([[99,  1],
       [ 2,  3]], dtype=int32)

and

>>> a = jax.numpy.array([[0, 1], [2, 3]])
>>> a_str = str(a)
>>> t = torch.from_dlpack(a)
>>> t[0, 0] = 99
>>> a
Array([[0, 1],
       [2, 3]], dtype=int32)

Leaving aside the problematic in-place modification of a jax.Array instance through a torch.Tensor wrapper, I am curious why these two examples give different results depending on whether or not the str(a) expression is evaluated.

What jax/jaxlib version are you using?

jax 0.4.23, jaxlib 0.4.23.dev20231223

Which accelerator(s) are you using?

GPU

Additional system info?

1.26.2 3.11.7 | packaged by conda-forge | (main, Dec 23 2023, 14:43:09) [GCC 12.3.0] uname_result(system='Linux', node='ex', release='5.4.0-153-generic', version='#170-Ubuntu SMP Fri Jun 16 13:43:31 UTC 2023', machine='x86_64')

NVIDIA GPU info

Tue Dec 26 00:11:32 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 2060 S...    Off| 00000000:17:00.0 Off |                  N/A |
| 17%   35C    P0               31W / 175W|      0MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 2060 S...    Off| 00000000:65:00.0 Off |                  N/A |
| 17%   42C    P0               17W / 175W|      0MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
@pearu pearu added the bug Something isn't working label Dec 25, 2023
Collaborator Author

pearu commented Dec 25, 2023

Ah, it appears that the result of jax.Array.__str__ is cached: a.__cuda_array_interface__ and t.__cuda_array_interface__ report the same data pointer, so the buffer is shared in both examples and only the printed value differs. Sorry for the noise.
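On GPU, the check described above amounts to comparing `a.__cuda_array_interface__["data"][0]` with `t.__cuda_array_interface__["data"][0]` (in the CUDA Array Interface, `"data"` is a `(pointer, read_only)` tuple). As a hedged, CPU-only sketch, NumPy can stand in for both JAX and PyTorch, with `__array_interface__` playing the same role. This assumes a recent NumPy that provides `np.from_dlpack`; it does not touch JAX or torch at all.

```python
import numpy as np

# NumPy stands in for the producer (JAX) and the consumer (PyTorch).
x = np.array([[0, 1], [2, 3]], dtype=np.int32)
y = np.from_dlpack(x)  # zero-copy import through the DLPack protocol

# __array_interface__["data"] is a (pointer, read_only) tuple.
ptr_x = x.__array_interface__["data"][0]
ptr_y = y.__array_interface__["data"][0]
print(ptr_x == ptr_y)          # True: both objects point at the same buffer
print(np.shares_memory(x, y))  # True
```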

However, this example does illustrate a potentially unsafe use of the dlpack protocol: it breaks the immutability of JAX arrays, and there is no way to detect such misuse of the protocol.

Collaborator

jakevdp commented Dec 26, 2023

Thanks for the question. Indeed the issue is that the string representation of a JAX array is essentially cached the first time it is called, so if you somehow change the values in the underlying buffer, the representation will no longer match the internal values. This is not a problem during normal usage, because JAX arrays are immutable.
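The caching behavior can be modeled with a toy pure-Python class. This is a hypothetical sketch, not JAX's implementation: `ToyArray`, `export` and `demo` are made-up names, an `array.array` plays the role of the device buffer, and a `memoryview` plays the role of the zero-copy dlpack export. My understanding is that in JAX the cached object is a host (NumPy) copy of the array created on first print, but treat that detail as an assumption.

```python
import array


class ToyArray:
    """Hypothetical 'immutable' array whose printed values are computed
    once and then cached (a toy model, not JAX's real class)."""

    def __init__(self, buffer):
        self._buffer = buffer       # storage that exported views may alias
        self._cached_values = None  # filled lazily on first str()/repr()

    def _values(self):
        # Snapshot the buffer on first use and reuse it afterwards; this is
        # only correct if nobody ever writes to the buffer.
        if self._cached_values is None:
            self._cached_values = list(self._buffer)
        return self._cached_values

    def __str__(self):
        return str(self._values())

    __repr__ = __str__

    def export(self):
        # Zero-copy, writable view of the storage (loosely analogous to dlpack).
        return memoryview(self._buffer)


def demo(print_first):
    a = ToyArray(array.array("i", [0, 1, 2, 3]))
    if print_first:
        str(a)      # populates the cache with [0, 1, 2, 3]
    t = a.export()  # shares memory with a._buffer
    t[0] = 99       # in-place write through the exported view
    return repr(a)


print(demo(print_first=False))  # [99, 1, 2, 3]
print(demo(print_first=True))   # [0, 1, 2, 3]
```

As in the two REPL sessions of the original report, the write reaches the shared buffer in both cases; only the cached printed form goes stale.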

The dlpack mutation issue is a problem, and unfortunately not one with an easy solution beyond removing the ability to export via dlpack (see some relevant discussion here: data-apis/array-api#191). The dlpack project lists a read-only flag on its roadmap; when that is realized we could utilize it in JAX to prevent this issue. For now, probably the best we can do is add some warnings to the docs that if you mutate an exported buffer, the JAX-side behavior is undefined.
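Until such a read-only flag exists, one defensive pattern (not proposed in this thread) is for the consumer to copy right after importing whenever it intends to write. In the original example this would be something like `torch.from_dlpack(a).clone()`. The CPU-only sketch below uses NumPy as a stand-in for both libraries, and `import_for_mutation` is a hypothetical helper name.

```python
import numpy as np


def import_for_mutation(producer):
    """Hypothetical helper: import a dlpack-capable array, then copy it so
    later in-place writes cannot reach the producer's buffer."""
    shared = np.from_dlpack(producer)  # zero-copy view of the producer's memory
    return shared.copy()               # private buffer owned by the consumer


x = np.array([[0, 1], [2, 3]], dtype=np.int32)
y = import_for_mutation(x)
y[0, 0] = 99  # only the private copy changes

print(np.shares_memory(x, y))  # False
print(x[0, 0], y[0, 0])        # 0 99
```

The cost is one extra copy, which gives up the zero-copy benefit of dlpack; that is the trade-off the thread is weighing.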

@jakevdp jakevdp self-assigned this Dec 26, 2023
Collaborator Author

pearu commented Dec 26, 2023

> For now, probably the best we can do is add some warnings to the docs that if you mutate an exported buffer, the JAX-side behavior is undefined.

Yes, this makes sense. The warnings should go in the docs of both from_dlpack and to_dlpack, since the same undefined behavior can be triggered in either direction of buffer sharing.
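The other direction can be illustrated with the same hedged NumPy stand-in. If JAX imports a buffer (for example via `jax.dlpack.from_dlpack` on a torch tensor) and the exporting library later writes to its own buffer, the imported JAX array changes as well. The names `producer` and `consumer` are illustrative only.

```python
import numpy as np

# Stand-in for the reverse direction (e.g. a torch.Tensor exported into JAX):
# the producer is an ordinary, writable NumPy array.
producer = np.array([[0, 1], [2, 3]], dtype=np.int32)
consumer = np.from_dlpack(producer)  # zero-copy import

producer[0, 0] = 99    # the producer mutates its own, writable buffer
print(consumer[0, 0])  # 99: the imported array changed underneath the consumer
```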

Collaborator Author

pearu commented Jan 4, 2024

Resolved by #19183

@pearu pearu closed this as completed Jan 4, 2024