Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend Enhancements for GPU/TPU Support #144

Merged
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ The current setup uses:
* [tox](https://tox.readthedocs.io) ... for testing with different environments.
* [travis](https://travis-ci.com) ... for continuous integration.

We are actively incorporating additional features to OQuPy,
details of which can be found in [DEVELOPMENT.md](./DEVELOPMENT.md).

## How to contribute to the code or documentation
Please use the
[Issues](https://github.com/tempoCollaboration/OQuPy/issues) and
Expand Down
43 changes: 43 additions & 0 deletions DEVELOPMENT.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Development

The current development branch "dev/jax" implements

* [Experimental Support for GPUs/TPUs](#experimental-support-for-gpustpus)

## Experimental Support for GPUs/TPUs

Although OQuPy is built on top of the backend-agnostic
[TensorNetwork](https://github.com/google/TensorNetwork) library,
OQuPy uses vanilla NumPy and SciPy throughout its implementation.

The "dev/jax" branch adds supports for GPUs/TPUs via the
[JAX](https://jax.readthedocs.io/en/latest/) library.
A new `oqupy.backends.numerical_backend.py` module handles the
[breaking changes in JAX NumPy](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html),
while the rest of the modules utilizes `numpy` and `scipy.linalg` instances from there
without explicitly importing JAX-based libraries.

### Enabling Experimental Features

To enable experimental features switch to the `dev/jax` branch and use
```python
from oqupy.backends import enable_jax_features
enable_jax_features()
```
Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
initialize the jax backend by default.

### Contributing Guidelines

To contribute features compatible with the JAX backend,
please adhere to the following set of guidelines:

* avoid wildcard imports of NumPy and SciPy.
* use `from oqupy.backends.numerical_backend import np` instead of `import numpy as np` and use the alias `default_np` in cases vanilla NumPy is explicitly required.
* use `from oqupy.backends.numerical_backend import la` instead of `import scipy.linalg as la`, except that for non-symmetric eigen-decomposition, `scipy.linalg.eig` should be used.
* use one of `np.dtype_complex` (`np.dtype_float`) or `oqupy.config.NumPyDtypeComplex` (`oqupy.config.NumPyDtypeFloat`) instead of `np.complex_` (`np.float_`).
* convert lists or tuples to arrays when passing them as arguments inside functions.
* use `array = np.update(array, indices, values)` instead of `array[indices] = values`.
* use `np.get_random_floats(seed, shape)` instead of `np.random.default_rng(seed).random(shape)`.
* declare signatures for `np.vectorize` explicitly.
* avoid directly changing the `shape` attribute of an array (use `.reshape` instead)
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ Furthermore, OQuPy implements methods to ...
:caption: Development

pages/contributing
pages/gpu_features
pages/authors
pages/how_to_cite
pages/sharing
Expand Down
4 changes: 3 additions & 1 deletion docs/pages/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ class :class:`oqupy.pt_tebd.PtTebd`
dictionary.



Results
-------

Expand Down Expand Up @@ -207,3 +206,6 @@ module :mod:`oqupy.operators`
function :func:`oqupy.helpers.plot_correlations_with_parameters`
A helper function to plot an auto-correlation function and the sampling
points given by a set of parameters for a TEMPO computation.

function :func:`oqupy.backends.enable_jax_features`
Option to use JAX to support multiple device backends (CPUs/GPUs/TPUs).
55 changes: 55 additions & 0 deletions docs/pages/gpu_features.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
Experimental Support for GPUs/TPUs
==================================
The current development branch "dev/jax" implements experimental support
for GPUs/TPUs.

Although OQuPy is built on top of the backend-agnostic
`TensorNetwork <https://github.com/google/TensorNetwork>`__ library,
OQuPy uses vanilla NumPy and SciPy throughout its implementation.

The "dev/jax" branch adds supports for GPUs/TPUs via the
`JAX <https://jax.readthedocs.io/en/latest/>`__ library. A new
``oqupy.backends.numerical_backend.py`` module handles the
`breaking changes in JAX
NumPy <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html>`__,
while the rest of the modules utilizes ``numpy`` and ``scipy.linalg``
instances from there without explicitly importing JAX-based libraries.

Enabling Experimental Features
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To enable experimental features, switch to the ``dev/jax`` branch and use

.. code:: python

from oqupy.backends import enable_jax_features
enable_jax_features()

Alternatively, the `OQUPY_BACKEND` environmental variable may be set to `jax` to
initialize the jax backend by default.

Contributing Guidelines
~~~~~~~~~~~~~~~~~~~~~~~

To contribute features compatible with the JAX backend,
please adhere to the following set of guidelines:

- avoid wildcard imports of NumPy and SciPy.
- use ``from oqupy.backends.numerical_backend import np`` instead of
``import numpy as np`` and use the alias ``default_np`` in cases
vanilla NumPy is explicitly required.
- use ``from oqupy.backends.numerical_backend import la`` instead of
``import scipy.linalg as la``, except that for non-symmetric
eigen-decomposition, ``scipy.linalg.eig`` should be used.
- use one of ``np.dtype_complex`` (``np.dtype_float``) or
``oqupy.config.NumPyDtypeComplex`` (``oqupy.config.NumPyDtypeFloat``)
instead of ``np.complex_`` (``np.float_``).
- convert lists or tuples to arrays when passing them as arguments
inside functions.
- use ``array = np.update(array, indices, values)`` instead of
``array[indices] = values``.
- use ``np.get_random_floats(seed, shape)`` instead of
``np.random.default_rng(seed).random(shape)``.
- declare signatures for ``np.vectorize`` explicitly.
- avoid directly changing the ``shape`` attribute of an array (use
``.reshape`` instead)
43 changes: 43 additions & 0 deletions examples/simple_dynamics_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env python

import sys
sys.path.insert(0, '.')
# set the 'OQUPY_BACKEND' environment variable
# to 'jax' to initialize JAX backend by default
# or switch to JAX backend using oqupy.backends
import oqupy
from oqupy.backends import enable_jax_features
# import NumPy from numerical_backend
#from oqupy.backends.numerical_backend import np
#enable_jax_features()

import matplotlib.pyplot as plt
sigma_x = oqupy.operators.sigma("x")
sigma_z = oqupy.operators.sigma("z")
up_density_matrix = oqupy.operators.spin_dm("z+")
Omega = 1.0
omega_cutoff = 5.0
alpha = 0.3

system = oqupy.System(0.5 * Omega * sigma_x)
correlations = oqupy.PowerLawSD(alpha=alpha,
zeta=1,
cutoff=omega_cutoff,
cutoff_type='exponential')
bath = oqupy.Bath(0.5 * sigma_z, correlations)
tempo_parameters = oqupy.TempoParameters(dt=0.1, tcut=3.0, epsrel=10**(-4))

dynamics = oqupy.tempo_compute(system=system,
bath=bath,
initial_state=up_density_matrix,
start_time=0.0,
end_time=2.0,
parameters=tempo_parameters,
unique=True)
t, s_z = dynamics.expectations(0.5*sigma_z, real=True)
print(s_z)
plt.plot(t, s_z, label=r'$\alpha=0.3$')
plt.xlabel(r'$t\,\Omega$')
plt.ylabel(r'$\langle\sigma_z\rangle$')
plt.savefig('simple_dynamics_jax.png')
#plt.show()
9 changes: 9 additions & 0 deletions oqupy/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Module to initialize OQuPy's backends."""

from oqupy.backends.numerical_backend import set_numerical_backends

def enable_jax_features():
"""Function to enable experimental features."""

# set numerical backend to JAX
set_numerical_backends('jax')
2 changes: 1 addition & 1 deletion oqupy/backends/node_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

from typing import Any, List, Optional, Text, Tuple, Union

import numpy as np
import tensornetwork as tn
from tensornetwork import Node
from tensornetwork.backends.base_backend import BaseBackend

from oqupy.backends.numerical_backend import np

class NodeArray:
"""NodeArray class. """
Expand Down
166 changes: 166 additions & 0 deletions oqupy/backends/numerical_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Module containing NumPy-like and SciPy-like numerical backends.
"""

import os

import numpy as default_np
import scipy.linalg as default_la

from tensornetwork.backend_contextmanager import \
set_default_backend

import oqupy.config as oc

# store instances of the initialized backends
# this way, `oqupy.config` remains unchanged
# and `ocupy.config.DEFAULT_BACKEND` is used
# when NumPy and LinAlg are initialized
NUMERICAL_BACKEND_INSTANCES = {}

def get_numerical_backends(
backend_name: str,
):
"""Function to get numerical backend.

Parameters
----------
backend_name: str
Name of the backend. Options are `'jax'` and `'numpy'`.

Returns
-------
backends: list
NumPy and LinAlg backends.
"""

_bn = backend_name.lower()
if _bn in NUMERICAL_BACKEND_INSTANCES:
set_default_backend(_bn)
return NUMERICAL_BACKEND_INSTANCES[_bn]
assert _bn in ['jax', 'numpy'], \
"currently supported backends are `'jax'` and `'numpy'`"

if 'jax' in _bn:
try:
# explicitly import and configure jax
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jla
jax.config.update('jax_enable_x64', True)

# # TODO: GPU memory allocation (default is 0.75)
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.5'

# set TensorNetwork backend
set_default_backend('jax')

NUMERICAL_BACKEND_INSTANCES['jax'] = [jnp, jla]
return NUMERICAL_BACKEND_INSTANCES['jax']
except ImportError:
print("JAX not installed, defaulting to NumPy")

# set TensorNetwork backend
set_default_backend('numpy')

NUMERICAL_BACKEND_INSTANCES['numpy'] = [default_np, default_la]
return NUMERICAL_BACKEND_INSTANCES['numpy']

class NumPy:
"""
The NumPy backend employing
dynamic switching through `oqupy.config`.
"""
def __init__(self,
backend_name=oc.DEFAULT_BACKEND,
):
"""Getter for the backend."""
self.backend = get_numerical_backends(backend_name)[0]

@property
def dtype_complex(self) -> default_np.dtype:
"""Getter for the complex datatype."""
return oc.NumPyDtypeComplex

@property
def dtype_float(self) -> default_np.dtype:
"""Getter for the float datatype."""
return oc.NumPyDtypeFloat

def __getattr__(self,
name: str,
):
"""Return the backend's default attribute."""
return getattr(self.backend, name)

def update(self,
array,
indices: tuple,
values,
) -> default_np.ndarray:
"""Option to update select indices of an array with given values."""
if not isinstance(array, default_np.ndarray):
return array.at[indices].set(values)
array[indices] = values
return array

def get_random_floats(self,
seed,
shape,
):
"""Method to obtain random floats with a given seed and shape."""
random_floats = default_np.random.default_rng(seed).random(shape, \
dtype=default_np.float64)
return self.backend.array(random_floats, dtype=self.dtype_float)

class LinAlg:
"""
The Linear Algebra backend employing
dynamic switching through `oqupy.config`.
"""
def __init__(self,
backend_name=oc.DEFAULT_BACKEND,
):
"""Getter for the backend."""
self.backend = get_numerical_backends(backend_name)[1]

def __getattr__(self,
name: str,
):
"""Return the backend's default attribute."""
return getattr(self.backend, name)

# setup libraries using environment variable
# fall back to oqupy.config.DEFAULT_BACKEND
try:
BACKEND_NAME = os.environ[oc.BACKEND_ENV_VAR]
except KeyError:
BACKEND_NAME = oc.DEFAULT_BACKEND
np = NumPy(backend_name=BACKEND_NAME)
la = LinAlg(backend_name=BACKEND_NAME)

def set_numerical_backends(
backend_name: str
):
"""Function to set numerical backend.

Parameters
----------
backend_name: str
Name of the backend. Options are `'jax'` and `'numpy'`.
"""
backends = get_numerical_backends(backend_name)
np.backend = backends[0]
la.backend = backends[1]
Loading