I'm trying to use opt_einsum with jax, but it seems to fail to find the correct functions in the jax module.
This results either in a jax Array being converted to a numpy ndarray or, if jax is specified as the backend, in a traceback. Running main.py reproduces both behaviors.
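main.py itself is not included in the report. The script below is a hypothetical reconstruction, based only on the three version lines in the output and on line 14 quoted in the traceback; the input array `x` is an assumption, and the exact version strings it prints may be formatted differently from the report's.

```python
# Hypothetical reconstruction of main.py; the original script is not shown.
import sys
from importlib.metadata import version

import jax
import jax.numpy as jnp
import opt_einsum

# Version lines like the first three lines of the report's output.
print(sys.version)
print(jax.__version__)
print(version("opt_einsum"))

x = jnp.arange(3.0)  # assumed input; any 1-D jax array fits "i->i"

# With the default backend ("auto"), the report got <class 'numpy.ndarray'> back.
print(type(opt_einsum.contract("i->i", x)))

# Passing the jax.numpy module object as the backend raised AttributeError
# in the report (opt_einsum v3.3.0).
try:
    print(type(opt_einsum.contract("i->i", x, backend=jnp)))
except AttributeError as err:
    print("AttributeError:", err)
```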
issue_report$ python main.py
3.10.13 (main, Dec 15 2023, 19:01:59) [GCC 11.4.0]
0.4.24
v3.3.0
<class 'numpy.ndarray'>
Traceback (most recent call last):
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 65, in get_func
return _cached_funcs[func, backend]
KeyError: ('einsum', <module 'jax.numpy' from '/[...]/issue_report/venv/lib/python3.10/site-packages/jax/numpy/__init__.py'>)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 38, in _import_func
lib = importlib.import_module(_aliases.get(backend, backend))
File "/[...]/.pyenv/versions/3.10.13/lib/python3.10/importlib/__init__.py", line 117, in import_module
if name.startswith('.'):
AttributeError: module 'jax.numpy' has no attribute 'startswith'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/[...]/issue_report/main.py", line 14, in <module>
print(type(opt_einsum.contract("i->i", x, backend=jnp)))
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 507, in contract
return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 591, in _core_contract
new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/sharing.py", line 151, in cached_einsum
return einsum(*args, **kwargs)
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 337, in _einsum
fn = backends.get_func('einsum', kwargs.pop('backend', 'numpy'))
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 67, in get_func
fn = _import_func(func, backend, default)
File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 44, in _import_func
raise AttributeError(error_msg.format(backend, func))
AttributeError: <module 'jax.numpy' from '/[...]/issue_report/venv/lib/python3.10/site-packages/jax/numpy/__init__.py'> doesn't seem to provide the function einsum - see https://optimized-einsum.readthedocs.io/en/latest/backends.html for details on which functions are required for which contractions.
I've only replaced personal folder information with [...].
Thanks and best regards.
Looks like the backend dispatch aliases need updating with "jaxlib": "jax.numpy" for this to work automatically. Note that passing the module directly is not supported; if you call with backend="jax" instead, it should work.
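Why an alias would help can be illustrated with a small pure-Python sketch. It assumes the backend name is inferred from the top-level package that defines the array's type, and that jax 0.4.x arrays are defined under `jaxlib`. `ALIASES`, `infer_backend_name`, `resolve_backend` and `FakeJaxArray` are hypothetical names chosen for illustration, not opt_einsum's API.

```python
# Hypothetical sketch of alias-based backend resolution; these names are
# illustrative, not opt_einsum's API.
ALIASES = {
    "jax": "jax.numpy",
    "jaxlib": "jax.numpy",  # the entry suggested in this thread
}


def infer_backend_name(x):
    """Return the top-level package that defines type(x), e.g. 'numpy'."""
    return type(x).__module__.split(".")[0]


def resolve_backend(backend):
    """Map a backend name to the module name that would be imported."""
    if not isinstance(backend, str):
        # importlib.import_module expects a string name; passing a module
        # object is what triggered the AttributeError in the traceback.
        raise TypeError(f"backend must be a name, got {type(backend).__name__}")
    return ALIASES.get(backend, backend)


class FakeJaxArray:
    """Stand-in for a jax Array; only its defining module matters here."""

    __module__ = "jaxlib.xla_extension"  # assumed module path for jax 0.4.x


name = infer_backend_name(FakeJaxArray())
print(name)                      # jaxlib
print(resolve_backend(name))     # jax.numpy
print(resolve_backend("numpy"))  # numpy
```

With the "jaxlib" entry, an inferred name of "jaxlib" resolves to "jax.numpy". Without it, the lookup falls through to the jaxlib package itself, which does not provide the array functions; that would be consistent with the fallback to numpy.ndarray seen in the report.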