I tried to run OQuPy on the GPU, not with PyTorch/TensorFlow but with JAX. Since TensorNetwork's …

Updating the Modules to Support JAX

The …

```python
import numpy as default_numpy
import scipy as default_scipy
import scipy.linalg  # noqa: F401 (makes sure the linalg submodule is loaded)

NUMERICAL_BACKEND_NUMPY = default_numpy
NUMERICAL_BACKEND_LINALG = default_scipy.linalg
NUMPY_DTYPE_COMPLEX = default_numpy.complex128  # earlier NpDtype
NUMPY_DTYPE_REAL = default_numpy.float64  # earlier NpDtypeReal
```

The specific choice of …

```python
import numpy as default_numpy
import oqupy.config

class NumPy:
    @property
    def backend(self) -> default_numpy:
        return oqupy.config.NUMERICAL_BACKEND_NUMPY

    def __getattr__(self, name):
        # only reached for attributes not found normally (i.e. not ``backend``)
        backend = object.__getattribute__(self, 'backend')
        return getattr(backend, name)

    # additional overridden methods

class SciPyLinAlg:
    # same as above for ``default_scipy.linalg``
    ...

np = NumPy()
sl = SciPyLinAlg()
```

The calls for … were changed from:

```python
import numpy as np
import scipy.linalg as linalg
from oqupy.config import NpDtype, NpDtypeReal
```

to:

```python
from oqupy.backends.numerical_backend import np, sl
# change corresponding occurrences of NpDtype and NpDtypeReal
# (now NUMPY_DTYPE_COMPLEX and NUMPY_DTYPE_REAL in oqupy.config)
```

With a few more changes to override in-place updates and random number generation in …

Running an Example

Any script using the JAX backend would require the following lines:

```python
# packages
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl
import oqupy.config as oc
import tensornetwork as tn
from oqupy.backends.numerical_backend import np, sl

# configuration
jax.config.update('jax_enable_x64', True)  # double precision; enable before creating any JAX arrays
oc.NUMERICAL_BACKEND_NUMPY = jnp
oc.NUMPY_DTYPE_COMPLEX = jnp.complex128
oc.NUMPY_DTYPE_REAL = jnp.float64
oc.NUMERICAL_BACKEND_LINALG = jsl
tn.set_default_backend('jax')
```

The runtimes (without jitted functions) corresponding to …
Further optimizations (for-looped updates, converting lists to arrays) only reduced the runtimes by a couple of seconds.

Issues

The tests involving the NumPy backend went smoothly! However, the JAX-backend tests had multiple issues. Before attempting to resolve them, I wanted to know whether any such change would break the flow of the package. A couple of these issues are mentioned below:

Immutability of …
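You can try the forwarding trick used in the `NumPy` wrapper class in isolation. Below is a minimal, self-contained sketch, not OQuPy's actual code. A `types.SimpleNamespace` named `config` stands in for `oqupy.config`, which is an assumption made so the snippet runs without OQuPy installed. The proxy reads the configured module on every attribute access, so switching the backend after import redirects every later call.

```python
import types

import numpy as default_numpy

# Hypothetical stand-in for ``oqupy.config``: a mutable namespace holding
# the currently selected NumPy-like module.
config = types.SimpleNamespace(NUMERICAL_BACKEND_NUMPY=default_numpy)


class NumPy:
    """Proxy that forwards attribute lookups to the configured backend."""

    @property
    def backend(self):
        # Read the config on every access, so a later change takes effect
        # even in modules that imported the proxy earlier.
        return config.NUMERICAL_BACKEND_NUMPY

    def __getattr__(self, name):
        # Only reached when normal lookup fails (i.e. not for ``backend``).
        backend = object.__getattribute__(self, 'backend')
        return getattr(backend, name)


np = NumPy()
```

This late binding is what lets a script set `oc.NUMERICAL_BACKEND_NUMPY = jnp` after OQuPy has already been imported. The trade-off is one extra attribute lookup per call.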
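On the immutability of JAX arrays: an assignment such as `a[i] = v` raises an error for a JAX array. JAX offers functional updates through `a.at[i].set(v)` instead. One possible way to override such updates is a small helper that returns a new array on either backend. `set_item` is a hypothetical name, not part of OQuPy or JAX.

```python
import numpy as default_numpy


def set_item(array, index, value):
    """Hypothetical helper: return a copy of ``array`` with ``array[index] = value``.

    JAX arrays expose functional updates through the ``.at`` property;
    NumPy arrays have no ``.at`` attribute, so they are copied and then
    assigned, giving both backends the same return-a-new-array semantics.
    """
    if hasattr(array, 'at'):
        # JAX-style functional update: returns a new array
        return array.at[index].set(value)
    # NumPy: leave the input untouched and update a copy
    updated = array.copy()
    updated[index] = value
    return updated
```

With JAX, `set_item(jnp.zeros(3), 1, 5.0)` goes through `.at[...].set(...)`. On NumPy, the copy costs an extra allocation, so code that only ever runs on NumPy may prefer true in-place updates.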
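NumPy and JAX also handle random number generation differently. NumPy keeps the generator state inside a `Generator` object. JAX instead requires an explicit PRNG key that is passed to every sampling function. Below is a hedged sketch of a backend-aware sampler. `random_normal` is a hypothetical helper, and the code assumes that any backend other than NumPy is `jax.numpy`. JAX is imported only when that branch is taken.

```python
import numpy as default_numpy


def random_normal(backend, shape, seed):
    """Hypothetical helper: standard-normal samples in the RNG style of ``backend``."""
    if backend is default_numpy:
        # NumPy: state lives in a Generator object
        rng = default_numpy.random.default_rng(seed)
        return rng.standard_normal(shape)
    # Assumed JAX backend: an explicit, immutable key drives the sampler
    import jax
    key = jax.random.PRNGKey(seed)
    return jax.random.normal(key, shape)
```

In a longer JAX computation, you would normally split the key with `jax.random.split` rather than recreate it from the same seed. Otherwise repeated calls return identical samples.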
-
Hi @Sampreet, as per our email discussion, thanks for the excellent work on this and for laying out the relevant points of discussion so clearly. Your contribution certainly has enough structure for us to move it to the Issues section! Continued in #142.
-
We currently use TensorNetwork as a backend in OQuPy, which in turn can use NumPy, PyTorch or TensorFlow as a backend. Has anybody tried using OQuPy with TensorFlow on a GPU?