Cryptic error message from jax integration #5607

Open
1 task done
quanvuong opened this issue Aug 15, 2024 · 2 comments

Comments

@quanvuong

Version

1.40.0

Describe the bug.

When I fetch a batch from the DALI JAX data iterator, I get the error below. I don't know why it happens or how to debug it myself.

  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/nvidia/dali/plugin/jax/iterator.py", line 189, in __next__
    return self._next_impl()
           ^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/nvidia/dali/plugin/jax/iterator.py", line 170, in _next_impl
    category_outputs = self._gather_outputs_for_category(pipelines_outputs, category_id)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/nvidia/dali/plugin/jax/iterator.py", line 196, in _gather_outputs_for_category
    _to_jax_array(pipelines_outputs[pipeline_id][category_id].as_tensor())
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/nvidia/dali/plugin/jax/integration.py", line 43, in _to_jax_array
    return jax_array.copy()
           ^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2689, in copy
    return array(a, copy=True, order=order)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2595, in array
    out = _array_copy(object) if copy else object
          ^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 4650, in _array_copy
    return copy_p.bind(arr)
           ^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/core.py", line 387, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/core.py", line 391, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/core.py", line 879, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 4691, in _copy_impl
    return dispatch.apply_primitive(prim, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/quan/micromamba/envs/monopi/lib/python3.11/site-packages/jax/_src/dispatch.py", line 86, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Buffer passed to Execute() as argument 0 to replica 0 is on device cuda:1, but replica is assigned to device cuda:0.

Minimum reproducible example

No response

Relevant log output

No response

Other/Misc.

No response

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report
@quanvuong quanvuong added the bug Something isn't working label Aug 15, 2024
@quanvuong
Author

If I reduce the batch size, the problem does not seem to occur. Sorry, I can't provide a minimal reproducible snippet.

@awolant
Contributor

awolant commented Aug 19, 2024

Hello @quanvuong, thanks for reporting the issue.
Could you tell us more about your setup? It looks like you are using multiple GPUs here, is that right? What are the other parameters? Which batch sizes did you mention?
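
The setup details requested above could be gathered with a short script along these lines. This is only a sketch: `describe_setup` is a hypothetical helper name, not part of DALI or JAX, and it assumes the script runs in the same environment and with the same `CUDA_VISIBLE_DEVICES` as the failing job. It does not reproduce or fix the error; it only reports which devices JAX sees, which may help when comparing them against the `cuda:0` / `cuda:1` devices named in the error message.

```python
import os


def describe_setup():
    """Collect environment details that are useful in a multi-GPUs bug report.

    Hypothetical helper for illustration; not part of DALI or JAX.
    """
    # Which GPUs the process is allowed to see (None if the variable is unset).
    report = {"CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES")}

    try:
        import jax
    except ImportError:
        report["jax"] = None
    else:
        report["jax"] = jax.__version__
        report["backend"] = jax.default_backend()
        # All devices JAX knows about, and how many belong to this process.
        report["devices"] = [str(d) for d in jax.devices()]
        report["local_device_count"] = jax.local_device_count()
        report["process_count"] = jax.process_count()

    try:
        import nvidia.dali as dali
    except ImportError:
        report["dali"] = None
    else:
        report["dali"] = dali.__version__

    return report


if __name__ == "__main__":
    for key, value in describe_setup().items():
        print(f"{key}: {value}")
```

Posting this output together with the batch sizes that fail and the ones that work, plus the arguments passed to the data iterator, would cover most of the questions above.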
