From 94c62a05d5ebcedd30f59c90b9926de967ed10b5 Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Fri, 5 Jul 2024 10:54:06 -0400 Subject: [PATCH] Remove NumPy as a hard dependency (#204) * Removes old tensor blas * Reworks tests to remove numpy * Removes numpy from helpers code * Starts fixing up tests * Better functional names * Absolute paths and finishes merge conflicts * misc * 44 tests passing * 44 tests passing * 44 tests passing * Removes numpy from core code base * Updates MyPy * Adds ellipses test case, closes #235, 236 * Adds a pure torch environment and potentially fixes codecov * Adds a raw torch test * Overhauls test_backends, allows torch to be tested without NumPy * Removes generic shape issue * Adds a get_shapes functionality to replace NumPy casting * Adds edge case tests with new shape kinds * Removes debug flag * Attempts to fix the torch only env * Adds a CI check to ensure there is no NumPy in the torch only env --- .github/workflows/Linting.yml | 2 +- .github/workflows/Tests.yml | 11 +- devtools/ci_scripts/check_no_numpy.py | 5 + devtools/conda-envs/min-deps-environment.yaml | 4 +- .../conda-envs/torch-only-environment.yaml | 19 ++ opt_einsum/__init__.py | 15 ++ opt_einsum/backends/__init__.py | 17 +- opt_einsum/backends/cupy.py | 7 +- opt_einsum/backends/dispatch.py | 36 +-- opt_einsum/backends/jax.py | 6 +- opt_einsum/backends/object_arrays.py | 3 +- opt_einsum/backends/tensorflow.py | 9 +- opt_einsum/backends/theano.py | 7 +- opt_einsum/backends/torch.py | 9 +- opt_einsum/blas.py | 148 +---------- opt_einsum/contract.py | 13 +- opt_einsum/helpers.py | 181 +------------- opt_einsum/parser.py | 95 +++++--- opt_einsum/path_random.py | 4 +- opt_einsum/paths.py | 58 +++-- opt_einsum/sharing.py | 4 +- opt_einsum/testing.py | 229 ++++++++++++++++++ opt_einsum/tests/test_backends.py | 176 +++++++------- opt_einsum/tests/test_blas.py | 27 +-- opt_einsum/tests/test_contract.py | 33 +-- opt_einsum/tests/test_edge_cases.py | 23 +- opt_einsum/tests/test_input.py | 5 +- opt_einsum/tests/test_parser.py | 45 +++- opt_einsum/tests/test_paths.py | 89 ++++--- opt_einsum/tests/test_sharing.py | 34 ++- opt_einsum/typing.py | 5 +- setup.cfg | 3 + 32 files changed, 694 insertions(+), 628 deletions(-) create mode 100644 devtools/ci_scripts/check_no_numpy.py create mode 100644 devtools/conda-envs/torch-only-environment.yaml create mode 100644 opt_einsum/testing.py diff --git a/.github/workflows/Linting.yml b/.github/workflows/Linting.yml index 72b8d637..e590ee3a 100644 --- a/.github/workflows/Linting.yml +++ b/.github/workflows/Linting.yml @@ -47,7 +47,7 @@ jobs: mypy opt_einsum black: - name: black + name: Black runs-on: ubuntu-latest steps: diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 035b59c5..18ab6d21 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -8,7 +8,7 @@ on: jobs: miniconda-setup: - name: Env (${{ matrix.environment }}) - Py${{ matrix.python-version }} + name: Env strategy: fail-fast: false matrix: @@ -21,6 +21,8 @@ jobs: environment: "min-ver" - python-version: 3.11 environment: "full" + - python-version: 3.12 + environment: "torch-only" runs-on: ubuntu-latest @@ -43,6 +45,11 @@ jobs: conda config --show-sources conda config --show + - name: Check no NumPy for torch-only environment + if: matrix.environment == 'torch-only' + run: | + python devtools/ci_scripts/check_no_numpy.py + - name: Install shell: bash -l {0} run: | @@ -58,7 +65,7 @@ jobs: run: | coverage report - - uses: codecov/codecov-action@v1 + - uses: 
codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} files: ./coverage.xml diff --git a/devtools/ci_scripts/check_no_numpy.py b/devtools/ci_scripts/check_no_numpy.py new file mode 100644 index 00000000..7cd526f9 --- /dev/null +++ b/devtools/ci_scripts/check_no_numpy.py @@ -0,0 +1,5 @@ +try: + import numpy + exit(1) +except ModuleNotFoundError: + exit(0) \ No newline at end of file diff --git a/devtools/conda-envs/min-deps-environment.yaml b/devtools/conda-envs/min-deps-environment.yaml index e6ea3825..d439666a 100644 --- a/devtools/conda-envs/min-deps-environment.yaml +++ b/devtools/conda-envs/min-deps-environment.yaml @@ -4,12 +4,10 @@ channels: dependencies: # Base depends - python >=3.9 - - numpy >=1.23 - - nomkl # Testing - autoflake - - black + - black - codecov - isort - mypy diff --git a/devtools/conda-envs/torch-only-environment.yaml b/devtools/conda-envs/torch-only-environment.yaml new file mode 100644 index 00000000..fa469fb0 --- /dev/null +++ b/devtools/conda-envs/torch-only-environment.yaml @@ -0,0 +1,19 @@ +name: test +channels: + - pytorch + - conda-forge +dependencies: + # Base depends + - python >=3.9 + - pytorch::pytorch >=2.0,<3.0.0a + - pytorch::cpuonly + - mkl + + # Testing + - autoflake + - black + - codecov + - isort + - mypy + - pytest + - pytest-cov diff --git a/opt_einsum/__init__.py b/opt_einsum/__init__.py index 828fc529..853587ec 100644 --- a/opt_einsum/__init__.py +++ b/opt_einsum/__init__.py @@ -9,6 +9,21 @@ from opt_einsum.paths import BranchBound, DynamicProgramming from opt_einsum.sharing import shared_intermediates +__all__ = [ + "blas", + "helpers", + "path_random", + "paths", + "contract", + "contract_expression", + "contract_path", + "get_symbol", + "RandomGreedy", + "BranchBound", + "DynamicProgramming", + "shared_intermediates", +] + # Handle versioneer from opt_einsum._version import get_versions # isort:skip diff --git a/opt_einsum/backends/__init__.py b/opt_einsum/backends/__init__.py index a9b85795..99839d20 100644 --- a/opt_einsum/backends/__init__.py +++ b/opt_einsum/backends/__init__.py @@ -3,11 +3,18 @@ """ # Backends -from .cupy import to_cupy -from .dispatch import build_expression, evaluate_constants, get_func, has_backend, has_einsum, has_tensordot -from .tensorflow import to_tensorflow -from .theano import to_theano -from .torch import to_torch +from opt_einsum.backends.cupy import to_cupy +from opt_einsum.backends.dispatch import ( + build_expression, + evaluate_constants, + get_func, + has_backend, + has_einsum, + has_tensordot, +) +from opt_einsum.backends.tensorflow import to_tensorflow +from opt_einsum.backends.theano import to_theano +from opt_einsum.backends.torch import to_torch __all__ = [ "get_func", diff --git a/opt_einsum/backends/cupy.py b/opt_einsum/backends/cupy.py index f17ff80a..9fd25bac 100644 --- a/opt_einsum/backends/cupy.py +++ b/opt_einsum/backends/cupy.py @@ -2,9 +2,8 @@ Required functions for optimized contractions of numpy arrays using cupy. 
""" -import numpy as np - -from ..sharing import to_backend_cache_wrap +from opt_einsum.helpers import has_array_interface +from opt_einsum.sharing import to_backend_cache_wrap __all__ = ["to_cupy", "build_expression", "evaluate_constants"] @@ -13,7 +12,7 @@ def to_cupy(array): # pragma: no cover import cupy - if isinstance(array, np.ndarray): + if has_array_interface(array): return cupy.asarray(array) return array diff --git a/opt_einsum/backends/dispatch.py b/opt_einsum/backends/dispatch.py index 0af11de3..0abad459 100644 --- a/opt_einsum/backends/dispatch.py +++ b/opt_einsum/backends/dispatch.py @@ -5,16 +5,14 @@ """ import importlib -from typing import Any, Dict +from typing import Any, Dict, Tuple -import numpy - -from . import cupy as _cupy -from . import jax as _jax -from . import object_arrays -from . import tensorflow as _tensorflow -from . import theano as _theano -from . import torch as _torch +from opt_einsum.backends import cupy as _cupy +from opt_einsum.backends import jax as _jax +from opt_einsum.backends import object_arrays +from opt_einsum.backends import tensorflow as _tensorflow +from opt_einsum.backends import theano as _theano +from opt_einsum.backends import torch as _torch __all__ = [ "get_func", @@ -57,16 +55,22 @@ def _import_func(func: str, backend: str, default: Any = None) -> Any: # manually cache functions as python2 doesn't support functools.lru_cache # other libs will be added to this if needed, but pre-populate with numpy -_cached_funcs = { - ("tensordot", "numpy"): numpy.tensordot, - ("transpose", "numpy"): numpy.transpose, - ("einsum", "numpy"): numpy.einsum, - # also pre-populate with the arbitrary object backend - ("tensordot", "object"): numpy.tensordot, - ("transpose", "object"): numpy.transpose, +_cached_funcs: Dict[Tuple[str, str], Any] = { ("einsum", "object"): object_arrays.object_einsum, } +try: + import numpy as np + + _cached_funcs[("tensordot", "numpy")] = np.tensordot + _cached_funcs[("transpose", "numpy")] = np.transpose + _cached_funcs[("einsum", "numpy")] = np.einsum + # also pre-populate with the arbitrary object backend + _cached_funcs[("tensordot", "object")] = np.tensordot + _cached_funcs[("transpose", "object")] = np.transpose +except ModuleNotFoundError: + pass + def get_func(func: str, backend: str = "numpy", default: Any = None) -> Any: """Return ``{backend}.{func}``, e.g. ``numpy.einsum``, diff --git a/opt_einsum/backends/jax.py b/opt_einsum/backends/jax.py index a9e22df0..d2346932 100644 --- a/opt_einsum/backends/jax.py +++ b/opt_einsum/backends/jax.py @@ -2,9 +2,7 @@ Required functions for optimized contractions of numpy arrays using jax. """ -import numpy as np - -from ..sharing import to_backend_cache_wrap +from opt_einsum.sharing import to_backend_cache_wrap __all__ = ["build_expression", "evaluate_constants"] @@ -33,6 +31,8 @@ def build_expression(_, expr): # pragma: no cover jax_expr = jax.jit(expr._contract) def jax_contract(*arrays): + import numpy as np + return np.asarray(jax_expr(arrays)) return jax_contract diff --git a/opt_einsum/backends/object_arrays.py b/opt_einsum/backends/object_arrays.py index eae0e92f..d870beb7 100644 --- a/opt_einsum/backends/object_arrays.py +++ b/opt_einsum/backends/object_arrays.py @@ -5,8 +5,6 @@ import functools import operator -import numpy as np - from opt_einsum.typing import ArrayType @@ -31,6 +29,7 @@ def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType: out : numpy.ndarray The output tensor, with ``dtype=object``. 
""" + import numpy as np # when called by ``opt_einsum`` we will always be given a full eq lhs, output = eq.split("->") diff --git a/opt_einsum/backends/tensorflow.py b/opt_einsum/backends/tensorflow.py index ad7a3bbb..ef2a23cb 100644 --- a/opt_einsum/backends/tensorflow.py +++ b/opt_einsum/backends/tensorflow.py @@ -2,9 +2,8 @@ Required functions for optimized contractions of numpy arrays using tensorflow. """ -import numpy as np - -from ..sharing import to_backend_cache_wrap +from opt_einsum.helpers import has_array_interface +from opt_einsum.sharing import to_backend_cache_wrap __all__ = ["to_tensorflow", "build_expression", "evaluate_constants"] @@ -40,13 +39,13 @@ def to_tensorflow(array, constant=False): tf, device, eager = _get_tensorflow_and_device() if eager: - if isinstance(array, np.ndarray): + if has_array_interface(array): with tf.device(device): return tf.convert_to_tensor(array) return array - if isinstance(array, np.ndarray): + if has_array_interface(array): if constant: return tf.convert_to_tensor(array) diff --git a/opt_einsum/backends/theano.py b/opt_einsum/backends/theano.py index 0abdf0df..c7ab2abb 100644 --- a/opt_einsum/backends/theano.py +++ b/opt_einsum/backends/theano.py @@ -2,9 +2,8 @@ Required functions for optimized contractions of numpy arrays using theano. """ -import numpy as np - -from ..sharing import to_backend_cache_wrap +from opt_einsum.helpers import has_array_interface +from opt_einsum.sharing import to_backend_cache_wrap __all__ = ["to_theano", "build_expression", "evaluate_constants"] @@ -14,7 +13,7 @@ def to_theano(array, constant=False): """Convert a numpy array to ``theano.tensor.TensorType`` instance.""" import theano - if isinstance(array, np.ndarray): + if has_array_interface(array): if constant: return theano.tensor.constant(array) diff --git a/opt_einsum/backends/torch.py b/opt_einsum/backends/torch.py index c3ae9b5e..561a0a07 100644 --- a/opt_einsum/backends/torch.py +++ b/opt_einsum/backends/torch.py @@ -2,10 +2,9 @@ Required functions for optimized contractions of numpy arrays using pytorch. """ -import numpy as np - -from ..parser import convert_to_valid_einsum_chars -from ..sharing import to_backend_cache_wrap +from opt_einsum.helpers import has_array_interface +from opt_einsum.parser import convert_to_valid_einsum_chars +from opt_einsum.sharing import to_backend_cache_wrap __all__ = [ "transpose", @@ -104,7 +103,7 @@ def tensordot(x, y, axes=2): def to_torch(array): torch, device = _get_torch_and_device() - if isinstance(array, np.ndarray): + if has_array_interface(array): return torch.from_numpy(array).to(device) return array diff --git a/opt_einsum/blas.py b/opt_einsum/blas.py index 4912cafb..487ace2a 100644 --- a/opt_einsum/blas.py +++ b/opt_einsum/blas.py @@ -4,10 +4,7 @@ from typing import List, Sequence, Tuple, Union -import numpy as np - -from . import helpers -from .typing import ArrayIndexType +from opt_einsum.typing import ArrayIndexType __all__ = ["can_blas", "tensor_blas"] @@ -126,146 +123,3 @@ def can_blas( # Conventional tensordot else: return "TDOT" - - -def tensor_blas( - view_left: np.ndarray, - input_left: str, - view_right: np.ndarray, - input_right: str, - index_result: str, - idx_removed: ArrayIndexType, -) -> np.ndarray: - """ - Computes the dot product between two tensors, attempts to use np.dot and - then tensordot if that fails. 
- - Parameters - ---------- - view_left : array_like - The left hand view - input_left : str - Indices of the left view - view_right : array_like - The right hand view - input_right : str - Indices of the right view - index_result : str - The resulting indices - idx_removed : set - Indices removed in the contraction - - Returns - ------- - type : array - The resulting BLAS operation. - - Notes - ----- - Interior function for tensor BLAS. - - This function will attempt to use `np.dot` by the iterating through the - four possible transpose cases. If this fails all inner and matrix-vector - operations will be handed off to einsum while all matrix-matrix operations will - first copy the data, perform the DGEMM, and then copy the data to the required - order. - - Examples - -------- - - >>> a = np.random.rand(4, 4) - >>> b = np.random.rand(4, 4) - >>> tmp = tensor_blas(a, 'ij', b, 'jk', 'ik', set('j')) - >>> np.allclose(tmp, np.dot(a, b)) - - """ - - idx_removed = frozenset(idx_removed) - keep_left = frozenset(input_left) - idx_removed - keep_right = frozenset(input_right) - idx_removed - - # We trust this must be called correctly - dimension_dict = {} - for i, s in zip(input_left, view_left.shape): - dimension_dict[i] = s - for i, s in zip(input_right, view_right.shape): - dimension_dict[i] = s - - # Do we want to be able to do this? - - # Check for duplicate indices, cannot do einsum('iij,jkk->ik') operations here - # if (len(set(input_left)) != len(input_left)): - # new_inds = ''.join(keep_left) + ''.join(idx_removed) - # view_left = np.einsum(input_left + '->' + new_inds, view_left, order='C') - # input_left = new_inds - - # if (len(set(input_right)) != len(input_right)): - # new_inds = ''.join(idx_removed) + ''.join(keep_right) - # view_right = np.einsum(input_right + '->' + new_inds, view_right, order='C') - # input_right = new_inds - - # Tensordot guarantees a copy for ndim > 2, should avoid skip if possible - rs = len(idx_removed) - dim_left = helpers.compute_size_by_dict(keep_left, dimension_dict) - dim_right = helpers.compute_size_by_dict(keep_right, dimension_dict) - dim_removed = helpers.compute_size_by_dict(idx_removed, dimension_dict) - tensor_result = input_left + input_right - for sidx in idx_removed: - tensor_result = tensor_result.replace(sidx, "") - - # This is ugly, but can vastly speed up certain operations - # Vectordot - if input_left == input_right: - new_view = np.dot(view_left.ravel(), view_right.ravel()) - - # Matrix multiply - # No transpose needed - elif input_left[-rs:] == input_right[:rs]: - new_view = np.dot( - view_left.reshape(dim_left, dim_removed), - view_right.reshape(dim_removed, dim_right), - ) - - # Transpose both - elif input_left[:rs] == input_right[-rs:]: - new_view = np.dot( - view_left.reshape(dim_removed, dim_left).T, - view_right.reshape(dim_right, dim_removed).T, - ) - - # Transpose right - elif input_left[-rs:] == input_right[-rs:]: - new_view = np.dot( - view_left.reshape(dim_left, dim_removed), - view_right.reshape(dim_right, dim_removed).T, - ) - - # Transpose left - elif input_left[:rs] == input_right[:rs]: - new_view = np.dot( - view_left.reshape(dim_removed, dim_left).T, - view_right.reshape(dim_removed, dim_right), - ) - - # Conventional tensordot - else: - # Find indices to contract over - left_pos: Tuple[int, ...] = () - right_pos: Tuple[int, ...] 
= () - for fidx in idx_removed: - left_pos += (input_left.find(fidx),) - right_pos += (input_right.find(fidx),) - new_view = np.tensordot(view_left, view_right, axes=(left_pos, right_pos)) - - # Make sure the resulting shape is correct - tensor_shape = tuple(dimension_dict[x] for x in tensor_result) - if new_view.shape != tensor_shape: - if len(tensor_result) > 0: - new_view.shape = tensor_shape - else: - new_view = np.squeeze(new_view) - - if tensor_result != index_result: - new_view = np.einsum(tensor_result + "->" + index_result, new_view) - - return new_view diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index e44d12da..ee3c2dbd 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -2,7 +2,6 @@ Contains the primary optimization and contraction routines. """ -from collections import namedtuple from decimal import Decimal from functools import lru_cache from typing import Any, Collection, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union, overload @@ -10,6 +9,7 @@ from opt_einsum import backends, blas, helpers, parser, paths, sharing from opt_einsum.typing import ( ArrayIndexType, + ArrayShaped, ArrayType, BackendType, ContractionListType, @@ -305,7 +305,7 @@ def contract_path( if shapes: input_shapes = operands_prepped else: - input_shapes = [x.shape for x in operands_prepped] + input_shapes = [parser.get_shape(x) for x in operands_prepped] output_set = frozenset(output_subscript) indices = frozenset(input_subscripts.replace(",", "")) @@ -957,14 +957,11 @@ def __str__(self) -> str: return "".join(s) -Shaped = namedtuple("Shaped", ["shape"]) - - -def shape_only(shape: TensorShapeType) -> Shaped: +def shape_only(shape: TensorShapeType) -> ArrayShaped: """Dummy ``numpy.ndarray`` which has a shape only - for generating contract expressions. """ - return Shaped(shape) + return ArrayShaped(shape) # Overlaod for contract(einsum_string, *operands) @@ -1069,7 +1066,7 @@ def contract_expression( ) if not isinstance(subscripts, str): - subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes) # type: ignore + subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes) kwargs["_gen_expression"] = True diff --git a/opt_einsum/helpers.py b/opt_einsum/helpers.py index 594212b3..57752590 100644 --- a/opt_einsum/helpers.py +++ b/opt_einsum/helpers.py @@ -2,55 +2,17 @@ Contains helper functions for opt_einsum testing scripts """ -from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Literal, Optional, Tuple, Union, overload +from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Tuple, overload -import numpy as np - -from opt_einsum.parser import get_symbol -from opt_einsum.typing import ArrayIndexType, PathType +from opt_einsum.typing import ArrayIndexType, ArrayType __all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count"] _valid_chars = "abcdefghijklmopqABC" -_sizes = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]) +_sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4] _default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)} -def build_views(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> List[np.ndarray]: - """ - Builds random numpy arrays for testing. - - Parameters - ---------- - string : str - List of tensor strings to build - dimension_dict : dictionary - Dictionary of index _sizes - - Returns - ------- - ret : list of np.ndarry's - The resulting views. 
- - Examples - -------- - >>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5}) - >>> view[0].shape - (2, 3, 3, 5) - - """ - - if dimension_dict is None: - dimension_dict = _default_dim_dict - - views = [] - terms = string.split("->")[0].split(",") - for term in terms: - dims = [dimension_dict[x] for x in term] - views.append(np.random.rand(*dims)) - return views - - @overload def compute_size_by_dict(indices: Iterable[int], idx_dict: List[int]) -> int: ... @@ -191,137 +153,8 @@ def flop_count( return overall_size * op_factor -@overload -def rand_equation( - n: int, - regularity: int, - n_out: int = ..., - d_min: int = ..., - d_max: int = ..., - seed: Optional[int] = ..., - global_dim: bool = ..., - *, - return_size_dict: Literal[True], -) -> Tuple[str, PathType, Dict[str, int]]: ... - - -@overload -def rand_equation( - n: int, - regularity: int, - n_out: int = ..., - d_min: int = ..., - d_max: int = ..., - seed: Optional[int] = ..., - global_dim: bool = ..., - return_size_dict: Literal[False] = ..., -) -> Tuple[str, PathType]: ... - - -def rand_equation( - n: int, - regularity: int, - n_out: int = 0, - d_min: int = 2, - d_max: int = 9, - seed: Optional[int] = None, - global_dim: bool = False, - return_size_dict: bool = False, -) -> Union[Tuple[str, PathType, Dict[str, int]], Tuple[str, PathType]]: - """Generate a random contraction and shapes. - - Parameters: - n: Number of array arguments. - regularity: 'Regularity' of the contraction graph. This essentially determines how - many indices each tensor shares with others on average. - n_out: Number of output indices (i.e. the number of non-contracted indices). - Defaults to 0, i.e., a contraction resulting in a scalar. - d_min: Minimum dimension size. - d_max: Maximum dimension size. - seed: If not None, seed numpy's random generator with this. - global_dim: Add a global, 'broadcast', dimension to every operand. - return_size_dict: Return the mapping of indices to sizes. - - Returns: - eq: The equation string. - shapes: The array shapes. - size_dict: The dict of index sizes, only returned if ``return_size_dict=True``. 
- - Examples: - ```python - >>> eq, shapes = rand_equation(n=10, regularity=4, n_out=5, seed=42) - >>> eq - 'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda' - - >>> shapes - [(9, 5, 4, 5, 4), - (4, 4, 8, 5), - (9, 4, 6, 9), - (6, 6), - (6, 9, 7, 8), - (4,), - (9, 3, 9, 4, 9), - (6, 8, 4, 6, 8, 6, 3), - (4, 7, 8, 8, 6, 9, 6), - (9, 5, 3, 3, 9, 5)] - ``` - """ - - if seed is not None: - np.random.seed(seed) - - # total number of indices - num_inds = n * regularity // 2 + n_out - inputs = ["" for _ in range(n)] - output = [] - - size_dict = {get_symbol(i): np.random.randint(d_min, d_max + 1) for i in range(num_inds)} - - # generate a list of indices to place either once or twice - def gen(): - for i, ix in enumerate(size_dict): - # generate an outer index - if i < n_out: - output.append(ix) - yield ix - # generate a bond - else: - yield ix - yield ix - - # add the indices randomly to the inputs - for i, ix in enumerate(np.random.permutation(list(gen()))): - # make sure all inputs have at least one index - if i < n: - inputs[i] += ix - else: - # don't add any traces on same op - where = np.random.randint(0, n) - while ix in inputs[where]: - where = np.random.randint(0, n) - - inputs[where] += ix - - # possibly add the same global dim to every arg - if global_dim: - gdim = get_symbol(num_inds) - size_dict[gdim] = np.random.randint(d_min, d_max + 1) - for i in range(n): - inputs[i] += gdim - output += gdim - - # randomly transpose the output indices and form equation - output = "".join(np.random.permutation(output)) # type: ignore - eq = "{}->{}".format(",".join(inputs), output) - - # make the shapes - shapes = [tuple(size_dict[ix] for ix in op) for op in inputs] - - if return_size_dict: - return ( - eq, - shapes, - size_dict, - ) +def has_array_interface(array: ArrayType) -> ArrayType: + if hasattr(array, "__array_interface__"): + return True else: - return (eq, shapes) + return False diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 47567ae5..15f7181c 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -3,16 +3,16 @@ """ import itertools -from typing import Any, Dict, Iterator, List, Tuple, Union +from collections.abc import Sequence +from typing import Any, Dict, Iterator, List, Tuple -import numpy as np - -from .typing import ArrayType, TensorShapeType +from opt_einsum.typing import ArrayType, TensorShapeType __all__ = [ "is_valid_einsum_char", "has_valid_einsum_chars_only", "get_symbol", + "get_shape", "gen_unused_symbols", "convert_to_valid_einsum_chars", "alpha_canonicalize", @@ -173,6 +173,30 @@ def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output: return tuple(max(shape[loc] for shape, loc in zip(shapes, [x.find(c) for x in inputs]) if loc >= 0) for c in output) +_BaseTypes = (bool, int, float, complex, str, bytes) + + +def get_shape(x: Any) -> TensorShapeType: + """Get the shape of the array-like object `x`. If `x` is not array-like, raise an error. + + Array-like objects are those that have a `shape` attribute, are sequences of BaseTypes, or are BaseTypes. + BaseTypes are defined as `bool`, `int`, `float`, `complex`, `str`, and `bytes`. 
+ """ + + if hasattr(x, "shape"): + return x.shape + elif isinstance(x, _BaseTypes): + return tuple() + elif isinstance(x, Sequence): + shape = [] + while isinstance(x, Sequence) and not isinstance(x, _BaseTypes): + shape.append(len(x)) + x = x[0] + return tuple(shape) + else: + raise ValueError(f"Cannot determine the shape of {x}, can only determine the shape of array-like objects.") + + def possibly_convert_to_numpy(x: Any) -> Any: """Convert things without a 'shape' to ndarrays, but leave everything else. @@ -199,6 +223,13 @@ def possibly_convert_to_numpy(x: Any) -> Any: """ if not hasattr(x, "shape"): + try: + import numpy as np + except ModuleNotFoundError: + raise ModuleNotFoundError( + "numpy is required to convert non-array objects to arrays. This function will be deprecated in the future." + ) + return np.asanyarray(x) else: return x @@ -224,17 +255,16 @@ def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str: return new_sub -def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[str, Tuple[ArrayType, ...]]: +def convert_interleaved_input(operands: Sequence[Any]) -> Tuple[str, Tuple[Any, ...]]: """Convert 'interleaved' input to standard einsum input.""" tmp_operands = list(operands) operand_list = [] subscript_list = [] - for p in range(len(operands) // 2): + for _ in range(len(operands) // 2): operand_list.append(tmp_operands.pop(0)) subscript_list.append(tmp_operands.pop(0)) output_list = tmp_operands[-1] if len(tmp_operands) else None - operands = [possibly_convert_to_numpy(x) for x in operand_list] # build a map from user symbols to single-character symbols based on `get_symbol` # The map retains the intrinsic order of user symbols @@ -259,39 +289,36 @@ def convert_interleaved_input(operands: Union[List[Any], Tuple[Any]]) -> Tuple[s subscripts += "->" subscripts += convert_subscripts(output_list, symbol_map) - return subscripts, tuple(operands) + return subscripts, tuple(operand_list) def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]: """ A reproduction of einsum c side einsum parsing in python. - **Parameters:** - Intakes the same inputs as `contract_path`, but NOT the keyword args. The only - supported keyword argument is: - - **shapes** - *(bool, optional)* Whether ``parse_einsum_input`` should assume arrays (the default) or - array shapes have been supplied. + Parameters: + operands: Intakes the same inputs as `contract_path`, but NOT the keyword args. The only + supported keyword argument is: + shapes: Whether ``parse_einsum_input`` should assume arrays (the default) or + array shapes have been supplied. 
Returns - ------- - input_strings : str - Parsed input strings - output_string : str - Parsed output string - operands : list of array_like - The operands to use in the numpy contraction - - Examples - -------- - The operand list is simplified to reduce printing: - - >>> a = np.random.rand(4, 4) - >>> b = np.random.rand(4, 4, 4) - >>> parse_einsum_input(('...a,...a->...', a, b)) - ('za,xza', 'xz', [a, b]) - - >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0])) - ('za,xza', 'xz', [a, b]) + input_strings: Parsed input strings + output_string: Parsed output string + operands: The operands to use in the numpy contraction + + Examples: + The operand list is simplified to reduce printing: + + ```python + >>> a = np.random.rand(4, 4) + >>> b = np.random.rand(4, 4, 4) + >>> parse_einsum_input(('...a,...a->...', a, b)) + ('za,xza', 'xz', [a, b]) + + >>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0])) + ('za,xza', 'xz', [a, b]) + ``` """ if len(operands) == 0: @@ -305,14 +332,14 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L "shapes is set to True but given at least one operand looks like an array" " (at least one operand has a shape attribute). " ) - operands = [possibly_convert_to_numpy(x) for x in operands[1:]] + operands = operands[1:] else: subscripts, operands = convert_interleaved_input(operands) if shapes: operand_shapes = operands else: - operand_shapes = [o.shape for o in operands] + operand_shapes = [get_shape(o) for o in operands] # Check for proper "->" if ("-" in subscripts) or (">" in subscripts): diff --git a/opt_einsum/path_random.py b/opt_einsum/path_random.py index ae7eff5a..3ed51e64 100644 --- a/opt_einsum/path_random.py +++ b/opt_einsum/path_random.py @@ -12,8 +12,8 @@ from random import seed as random_seed from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union -from . import helpers, paths -from .typing import ArrayIndexType, ArrayType, PathType +from opt_einsum import helpers, paths +from opt_einsum.typing import ArrayIndexType, ArrayType, PathType __all__ = ["RandomGreedy", "random_greedy", "random_greedy_128"] diff --git a/opt_einsum/paths.py b/opt_einsum/paths.py index 902ace30..ac330b8a 100644 --- a/opt_einsum/paths.py +++ b/opt_einsum/paths.py @@ -2,6 +2,7 @@ Contains the path technology behind opt_einsum in addition to several path helpers """ +import bisect import functools import heapq import itertools @@ -13,8 +14,6 @@ from typing import Counter as CounterType from typing import Dict, FrozenSet, Generator, List, Optional, Sequence, Set, Tuple, Union -import numpy as np - from opt_einsum.helpers import compute_size_by_dict, flop_count from opt_einsum.typing import ArrayIndexType, PathSearchFunctionType, PathType, TensorShapeType @@ -39,17 +38,13 @@ class PathOptimizer: Subclassed optimizers should define a call method with signature: ```python - def __call__(self, inputs, output, size_dict, memory_limit=None): + def __call__(self, inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: dict[str, int], memory_limit: int | None = None) -> list[tuple[int, ...]]: \"\"\" Parameters: - inputs : list[set[str]] - The indices of each input array. - outputs : set[str] - The output indices - size_dict : dict[str, int] - The size of each index - memory_limit : int, optional - If given, the maximum allowed memory. + inputs: The indices of each input array. + outputs: The output indices + size_dict: The size of each index + memory_limit: If given, the maximum allowed memory. \"\"\" # ... 
compute path here ... return path @@ -98,13 +93,40 @@ def ssa_to_linear(ssa_path: PathType) -> PathType: #> [(0, 3), (1, 2), (0, 1)] ``` """ - ids = np.arange(1 + max(map(max, ssa_path)), dtype=np.int32) + # ids = np.arange(1 + max(map(max, ssa_path)), dtype=np.int32) # type: ignore + # path = [] + # for ssa_ids in ssa_path: + # path.append(tuple(int(ids[ssa_id]) for ssa_id in ssa_ids)) + # for ssa_id in ssa_ids: + # ids[ssa_id:] -= 1 + # return path + + N = sum(map(len, ssa_path)) - len(ssa_path) + 1 + ids = list(range(N)) path = [] - for ssa_ids in ssa_path: - path.append(tuple(int(ids[ssa_id]) for ssa_id in ssa_ids)) - for ssa_id in ssa_ids: - ids[ssa_id:] -= 1 - return path + ssa = N + for scon in ssa_path: + con = sorted([bisect.bisect_left(ids, s) for s in scon]) + for j in reversed(con): + ids.pop(j) + ids.append(ssa) + path.append(con) + ssa += 1 + return [tuple(x) for x in path] + + # N = sum(map(len, ssa_path)) - len(ssa_path) + 1 + # ids = list(range(N)) + # ids = np.arange(1 + max(map(max, ssa_path)), dtype=np.int32) + # path = [] + # ssa = N + # for scon in ssa_path: + # con = sorted(map(ids.index, scon)) + # for j in reversed(con): + # ids.pop(j) + # ids.append(ssa) + # path.append(con) + # ssa += 1 + # return path def linear_to_ssa(path: PathType) -> PathType: @@ -1361,7 +1383,7 @@ def auto_hq( how many input arguments there are, but targeting a more generous amount of search time than ``'auto'``. """ - from .path_random import random_greedy_128 + from opt_einsum.path_random import random_greedy_128 N = len(inputs) return _AUTO_HQ_CHOICES.get(N, random_greedy_128)(inputs, output, size_dict, memory_limit) diff --git a/opt_einsum/sharing.py b/opt_einsum/sharing.py index 7bb6e22f..721d227e 100644 --- a/opt_einsum/sharing.py +++ b/opt_einsum/sharing.py @@ -13,8 +13,8 @@ from typing import Counter as CounterType from typing import Dict, Generator, List, Optional, Tuple, Union -from .parser import alpha_canonicalize, parse_einsum_input -from .typing import ArrayType +from opt_einsum.parser import alpha_canonicalize, parse_einsum_input +from opt_einsum.typing import ArrayType CacheKeyType = Union[Tuple[str, str, int, Tuple[int, ...]], Tuple[str, int]] CacheType = Dict[CacheKeyType, ArrayType] diff --git a/opt_einsum/testing.py b/opt_einsum/testing.py new file mode 100644 index 00000000..5c41bf9a --- /dev/null +++ b/opt_einsum/testing.py @@ -0,0 +1,229 @@ +""" +Testing routines for opt_einsum. +""" + +import random +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload + +import pytest + +from opt_einsum.parser import get_symbol +from opt_einsum.typing import ArrayType, PathType, TensorShapeType + +_valid_chars = "abcdefghijklmopqABC" +_sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4] +_default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)} + + +def build_shapes(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> Tuple[TensorShapeType, ...]: + """ + Builds random tensor shapes for testing. + + Parameters: + string: List of tensor strings to build + dimension_dict: Dictionary of index sizes, defaults to indices size of 2-7 + + Returns + The resulting shapes. 
+ + Examples: + ```python + >>> shapes = build_shapes('abbc', {'a': 2, 'b':3, 'c':5}) + >>> shapes + [(2, 3), (3, 3, 5), (5,)] + ``` + + """ + + if dimension_dict is None: + dimension_dict = _default_dim_dict + + shapes = [] + terms = string.split("->")[0].split(",") + for term in terms: + dims = [dimension_dict[x] for x in term] + shapes.append(tuple(dims)) + return tuple(shapes) + + +def build_views( + string: str, dimension_dict: Optional[Dict[str, int]] = None, array_function: Optional[Any] = None +) -> Tuple[ArrayType]: + """ + Builds random numpy arrays for testing. + + Parameters: + string: List of tensor strings to build + dimension_dict: Dictionary of index _sizes + array_function: Function to build the arrays, defaults to np.random.rand + + Returns + The resulting views. + + Examples: + ```python + >>> view = build_views('abbc', {'a': 2, 'b':3, 'c':5}) + >>> view[0].shape + (2, 3, 3, 5) + ``` + + """ + if array_function is None: + np = pytest.importorskip("numpy") + array_function = np.random.rand + + views = [] + for shape in build_shapes(string, dimension_dict=dimension_dict): + if shape: + views.append(array_function(*shape)) + else: + views.append(random.random()) + return tuple(views) + + +@overload +def rand_equation( + n: int, + regularity: int, + n_out: int = ..., + d_min: int = ..., + d_max: int = ..., + seed: Optional[int] = ..., + global_dim: bool = ..., + *, + return_size_dict: Literal[True], +) -> Tuple[str, PathType, Dict[str, int]]: ... + + +@overload +def rand_equation( + n: int, + regularity: int, + n_out: int = ..., + d_min: int = ..., + d_max: int = ..., + seed: Optional[int] = ..., + global_dim: bool = ..., + return_size_dict: Literal[False] = ..., +) -> Tuple[str, PathType]: ... + + +def rand_equation( + n: int, + regularity: int, + n_out: int = 0, + d_min: int = 2, + d_max: int = 9, + seed: Optional[int] = None, + global_dim: bool = False, + return_size_dict: bool = False, +) -> Union[Tuple[str, PathType, Dict[str, int]], Tuple[str, PathType]]: + """Generate a random contraction and shapes. + + Parameters: + n: Number of array arguments. + regularity: 'Regularity' of the contraction graph. This essentially determines how + many indices each tensor shares with others on average. + n_out: Number of output indices (i.e. the number of non-contracted indices). + Defaults to 0, i.e., a contraction resulting in a scalar. + d_min: Minimum dimension size. + d_max: Maximum dimension size. + seed: If not None, seed numpy's random generator with this. + global_dim: Add a global, 'broadcast', dimension to every operand. + return_size_dict: Return the mapping of indices to sizes. + + Returns: + eq: The equation string. + shapes: The array shapes. + size_dict: The dict of index sizes, only returned if ``return_size_dict=True``. 
+ + Examples: + ```python + >>> eq, shapes = rand_equation(n=10, regularity=4, n_out=5, seed=42) + >>> eq + 'oyeqn,tmaq,skpo,vg,hxui,n,fwxmr,hitplcj,kudlgfv,rywjsb->cebda' + + >>> shapes + [(9, 5, 4, 5, 4), + (4, 4, 8, 5), + (9, 4, 6, 9), + (6, 6), + (6, 9, 7, 8), + (4,), + (9, 3, 9, 4, 9), + (6, 8, 4, 6, 8, 6, 3), + (4, 7, 8, 8, 6, 9, 6), + (9, 5, 3, 3, 9, 5)] + ``` + """ + + np = pytest.importorskip("numpy") + if seed is not None: + np.random.seed(seed) + + # total number of indices + num_inds = n * regularity // 2 + n_out + inputs = ["" for _ in range(n)] + output = [] + + size_dict = {get_symbol(i): np.random.randint(d_min, d_max + 1) for i in range(num_inds)} + + # generate a list of indices to place either once or twice + def gen(): + for i, ix in enumerate(size_dict): + # generate an outer index + if i < n_out: + output.append(ix) + yield ix + # generate a bond + else: + yield ix + yield ix + + # add the indices randomly to the inputs + for i, ix in enumerate(np.random.permutation(list(gen()))): + # make sure all inputs have at least one index + if i < n: + inputs[i] += ix + else: + # don't add any traces on same op + where = np.random.randint(0, n) + while ix in inputs[where]: + where = np.random.randint(0, n) + + inputs[where] += ix + + # possibly add the same global dim to every arg + if global_dim: + gdim = get_symbol(num_inds) + size_dict[gdim] = np.random.randint(d_min, d_max + 1) + for i in range(n): + inputs[i] += gdim + output += gdim + + # randomly transpose the output indices and form equation + output = "".join(np.random.permutation(output)) # type: ignore + eq = "{}->{}".format(",".join(inputs), output) + + # make the shapes + shapes = [tuple(size_dict[ix] for ix in op) for op in inputs] + + ret = (eq, shapes) + + if return_size_dict: + return ret + (size_dict,) + else: + return ret + + +def build_arrays_from_tuples(path: PathType) -> List[Any]: + """Build random numpy arrays from a path. + + Parameters: + path: The path to build arrays from. 
+ + Returns: + The resulting arrays.""" + np = pytest.importorskip("numpy") + + return [np.random.rand(*x) for x in path] diff --git a/opt_einsum/tests/test_backends.py b/opt_einsum/tests/test_backends.py index 3481f86d..90883354 100644 --- a/opt_einsum/tests/test_backends.py +++ b/opt_einsum/tests/test_backends.py @@ -1,62 +1,24 @@ from typing import Set -import numpy as np import pytest -from opt_einsum import backends, contract, contract_expression, helpers, sharing -from opt_einsum.contract import Shaped, infer_backend, parse_backend +from opt_einsum import backends, contract, contract_expression, sharing +from opt_einsum.contract import ArrayShaped, infer_backend, parse_backend +from opt_einsum.testing import build_views try: - import cupy - - found_cupy = True -except ImportError: - found_cupy = False - -try: - import tensorflow as tf - # needed so tensorflow doesn't allocate all gpu mem try: from tensorflow import ConfigProto + from tensorflow import Session as TFSession except ImportError: - from tf.compat.v1 import ConfigProto + from tensorflow.compat.v1 import ConfigProto + from tensorflow.compat.v1 import Session as TFSession _TF_CONFIG = ConfigProto() _TF_CONFIG.gpu_options.allow_growth = True - found_tensorflow = True except ImportError: - found_tensorflow = False - -try: - import os + pass - os.environ["MKL_THREADING_LAYER"] = "GNU" - import theano - - found_theano = True -except ImportError: - found_theano = False - -try: - import torch - - found_torch = True -except ImportError: - found_torch = False - -try: - import jax - - found_jax = True -except ImportError: - found_jax = False - -try: - import autograd - - found_autograd = True -except ImportError: - found_autograd = False tests = [ "ab,bc->ca", @@ -71,17 +33,19 @@ ] -@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.") @pytest.mark.parametrize("string", tests) def test_tensorflow(string: str) -> None: - views = helpers.build_views(string) + np = pytest.importorskip("numpy") + pytest.importorskip("tensorflow") + + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) opt = np.empty_like(ein) shps = [v.shape for v in views] expr = contract_expression(string, *shps, optimize=True) - sess = tf.Session(config=_TF_CONFIG) + sess = TFSession(config=_TF_CONFIG) with sess.as_default(): expr(*views, backend="tensorflow", out=opt) sess.close() @@ -93,9 +57,11 @@ def test_tensorflow(string: str) -> None: expr(*tensorflow_views) -@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) def test_tensorflow_with_constants(constants: Set[int]) -> None: + np = pytest.importorskip("numpy") + tf = pytest.importorskip("tensorflow") + eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants @@ -106,7 +72,7 @@ def test_tensorflow_with_constants(constants: Set[int]) -> None: expr = contract_expression(eq, *ops, constants=constants) # check tensorflow - with tf.Session(config=_TF_CONFIG).as_default(): + with TFSession(config=_TF_CONFIG).as_default(): res_got = expr(var, backend="tensorflow") assert all( array is None or infer_backend(array) == "tensorflow" for array in expr._evaluated_constants["tensorflow"] @@ -122,16 +88,18 @@ def test_tensorflow_with_constants(constants: Set[int]) -> None: assert isinstance(res_got3, tf.Tensor) -@pytest.mark.skipif(not found_tensorflow, reason="Tensorflow not installed.") @pytest.mark.parametrize("string", tests) def 
test_tensorflow_with_sharing(string: str) -> None: - views = helpers.build_views(string) + np = pytest.importorskip("numpy") + tf = pytest.importorskip("tensorflow") + + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] expr = contract_expression(string, *shps, optimize=True) - sess = tf.Session(config=_TF_CONFIG) + sess = TFSession(config=_TF_CONFIG) with sess.as_default(), sharing.shared_intermediates() as cache: tfl1 = expr(*views, backend="tensorflow") @@ -147,10 +115,12 @@ def test_tensorflow_with_sharing(string: str) -> None: assert np.allclose(ein, tfl2) -@pytest.mark.skipif(not found_theano, reason="Theano not installed.") @pytest.mark.parametrize("string", tests) def test_theano(string: str) -> None: - views = helpers.build_views(string) + np = pytest.importorskip("numpy") + theano = pytest.importorskip("theano") + + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] @@ -165,9 +135,11 @@ def test_theano(string: str) -> None: assert isinstance(theano_opt, theano.tensor.TensorVariable) -@pytest.mark.skipif(not found_theano, reason="theano not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) def test_theano_with_constants(constants: Set[int]) -> None: + np = pytest.importorskip("numpy") + theano = pytest.importorskip("theano") + eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants @@ -191,10 +163,12 @@ def test_theano_with_constants(constants: Set[int]) -> None: assert isinstance(res_got3, theano.tensor.TensorVariable) -@pytest.mark.skipif(not found_theano, reason="Theano not installed.") @pytest.mark.parametrize("string", tests) def test_theano_with_sharing(string: str) -> None: - views = helpers.build_views(string) + np = pytest.importorskip("numpy") + theano = pytest.importorskip("theano") + + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] @@ -214,10 +188,12 @@ def test_theano_with_sharing(string: str) -> None: assert np.allclose(ein, thn2) -@pytest.mark.skipif(not found_cupy, reason="Cupy not installed.") @pytest.mark.parametrize("string", tests) -def test_cupy(string: str) -> None: # pragma: no cover - views = helpers.build_views(string) +def test_cupy(string: str) -> None: + np = pytest.importorskip("numpy") # pragma: no cover + cupy = pytest.importorskip("cupy") + + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] @@ -233,9 +209,11 @@ def test_cupy(string: str) -> None: # pragma: no cover assert np.allclose(ein, cupy.asnumpy(cupy_opt)) -@pytest.mark.skipif(not found_cupy, reason="Cupy not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) -def test_cupy_with_constants(constants: Set[int]) -> None: # pragma: no cover +def test_cupy_with_constants(constants: Set[int]) -> None: + np = pytest.importorskip("numpy") # pragma: no cover + cupy = pytest.importorskip("cupy") + eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants @@ -261,10 +239,12 @@ def test_cupy_with_constants(constants: Set[int]) -> None: # pragma: no cover assert np.allclose(res_exp, res_got3.get()) -@pytest.mark.skipif(not found_jax, reason="jax not installed.") @pytest.mark.parametrize("string", tests) -def test_jax(string: str) -> None: # pragma: no cover - views = helpers.build_views(string) +def 
test_jax(string: str) -> None: + np = pytest.importorskip("numpy") # pragma: no cover + pytest.importorskip("jax") + + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] @@ -275,14 +255,16 @@ def test_jax(string: str) -> None: # pragma: no cover assert isinstance(opt, np.ndarray) -@pytest.mark.skipif(not found_jax, reason="jax not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) -def test_jax_with_constants(constants: Set[int]) -> None: # pragma: no cover +def test_jax_with_constants(constants: Set[int]) -> None: + jax = pytest.importorskip("jax") + key = jax.random.PRNGKey(42) + eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants - ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)] - var = np.random.rand(*shapes[non_const]) + ops = [jax.random.uniform(key, shp) if i in constants else shp for i, shp in enumerate(shapes)] + var = jax.random.uniform(key, shapes[non_const]) res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3))) expr = contract_expression(eq, *ops, constants=constants) @@ -291,15 +273,16 @@ def test_jax_with_constants(constants: Set[int]) -> None: # pragma: no cover res_got = expr(var, backend="jax") # check jax versions of constants exist assert all(array is None or infer_backend(array).startswith("jax") for array in expr._evaluated_constants["jax"]) - - assert np.allclose(res_exp, res_got) + assert jax.numpy.sum(jax.numpy.abs(res_exp - res_got)) < 1e-8 -@pytest.mark.skipif(not found_jax, reason="jax not installed.") def test_jax_jit_gradient() -> None: + jax = pytest.importorskip("jax") + key = jax.random.PRNGKey(42) + eq = "ij,jk,kl->" shapes = (2, 3), (3, 4), (4, 2) - views = [np.random.randn(*s) for s in shapes] + views = [jax.random.uniform(key, s) for s in shapes] expr = contract_expression(eq, *shapes) x0 = expr(*views) @@ -318,8 +301,10 @@ def test_jax_jit_gradient() -> None: assert x2 < x1 -@pytest.mark.skipif(not found_autograd, reason="autograd not installed.") def test_autograd_gradient() -> None: + np = pytest.importorskip("numpy") + autograd = pytest.importorskip("autograd") + eq = "ij,jk,kl->" shapes = (2, 3), (3, 4), (4, 2) views = [np.random.randn(*s) for s in shapes] @@ -339,9 +324,10 @@ def test_autograd_gradient() -> None: @pytest.mark.parametrize("string", tests) def test_dask(string: str) -> None: + np = pytest.importorskip("numpy") da = pytest.importorskip("dask.array") - views = helpers.build_views(string) + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] expr = contract_expression(string, *shps, optimize=True) @@ -363,9 +349,10 @@ def test_dask(string: str) -> None: @pytest.mark.parametrize("string", tests) def test_sparse(string: str) -> None: + np = pytest.importorskip("numpy") sparse = pytest.importorskip("sparse") - views = helpers.build_views(string) + views = build_views(string) # sparsify views so they don't become dense during contraction for view in views: @@ -396,56 +383,56 @@ def test_sparse(string: str) -> None: assert np.allclose(ein, sparse_opt.todense()) -@pytest.mark.skipif(not found_torch, reason="Torch not installed.") @pytest.mark.parametrize("string", tests) def test_torch(string: str) -> None: + torch = pytest.importorskip("torch") - views = helpers.build_views(string) - ein = contract(string, *views, optimize=False, use_blas=False) - shps = [v.shape for v in 
views] + views = build_views(string, array_function=torch.rand) + ein = torch.einsum(string, *views) + shps = [v.shape for v in views] expr = contract_expression(string, *shps, optimize=True) opt = expr(*views, backend="torch") - assert np.allclose(ein, opt) + torch.testing.assert_close(ein, opt) # test non-conversion mode torch_views = [backends.to_torch(view) for view in views] torch_opt = expr(*torch_views) assert isinstance(torch_opt, torch.Tensor) - assert np.allclose(ein, torch_opt.cpu().numpy()) + torch.testing.assert_close(ein, torch_opt) -@pytest.mark.skipif(not found_torch, reason="Torch not installed.") @pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) def test_torch_with_constants(constants: Set[int]) -> None: + torch = pytest.importorskip("torch") + eq = "ij,jk,kl->li" shapes = (2, 3), (3, 4), (4, 5) (non_const,) = {0, 1, 2} - constants - ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)] - var = np.random.rand(*shapes[non_const]) - res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3))) + ops = [torch.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)] + var = torch.rand(*shapes[non_const]) + res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)), backend="torch") expr = contract_expression(eq, *ops, constants=constants) # check torch res_got = expr(var, backend="torch") assert all(array is None or infer_backend(array) == "torch" for array in expr._evaluated_constants["torch"]) - assert np.allclose(res_exp, res_got) + torch.testing.assert_close(res_exp, res_got) # check can call with numpy still - res_got2 = expr(var, backend="numpy") - assert np.allclose(res_exp, res_got2) + res_got2 = expr(var, backend="torch") + torch.testing.assert_close(res_exp, res_got2) # check torch call returns torch still res_got3 = expr(backends.to_torch(var)) assert isinstance(res_got3, torch.Tensor) - res_got3 = res_got3.numpy() if res_got3.device.type == "cpu" else res_got3.cpu().numpy() - assert np.allclose(res_exp, res_got3) + torch.testing.assert_close(res_exp, res_got3) def test_auto_backend_custom_array_no_tensordot() -> None: - x = Shaped((1, 2, 3)) + x = ArrayShaped((1, 2, 3)) # Shaped is an array-like object defined by opt_einsum - which has no TDOT assert infer_backend(x) == "opt_einsum" assert parse_backend([x], "auto") == "numpy" @@ -454,7 +441,8 @@ def test_auto_backend_custom_array_no_tensordot() -> None: @pytest.mark.parametrize("string", tests) def test_object_arrays_backend(string: str) -> None: - views = helpers.build_views(string) + np = pytest.importorskip("numpy") + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) assert ein.dtype != object diff --git a/opt_einsum/tests/test_blas.py b/opt_einsum/tests/test_blas.py index e72b8a56..28b723a5 100644 --- a/opt_einsum/tests/test_blas.py +++ b/opt_einsum/tests/test_blas.py @@ -4,10 +4,9 @@ from typing import Any -import numpy as np import pytest -from opt_einsum import blas, contract, helpers +from opt_einsum import blas, contract blas_tests = [ # DOT @@ -66,29 +65,9 @@ def test_can_blas(inp: Any, benchmark: bool) -> None: assert result == benchmark -@pytest.mark.parametrize("inp,benchmark", blas_tests) -def test_tensor_blas(inp: Any, benchmark: bool) -> None: - - # Weed out non-blas cases - if benchmark is False: - return - - tensor_strs, output, reduced_idx = inp - einsum_str = ",".join(tensor_strs) + "->" + output - - # Only binary operations should be here - if 
len(tensor_strs) != 2: - assert False - - view_left, view_right = helpers.build_views(einsum_str) - - einsum_result = np.einsum(einsum_str, view_left, view_right) - blas_result = blas.tensor_blas(view_left, tensor_strs[0], view_right, tensor_strs[1], output, reduced_idx) - - np.testing.assert_allclose(einsum_result, blas_result) - - def test_blas_out() -> None: + np = pytest.importorskip("numpy") + a = np.random.rand(4, 4) b = np.random.rand(4, 4) c = np.random.rand(4, 4) diff --git a/opt_einsum/tests/test_contract.py b/opt_einsum/tests/test_contract.py index 8e7cec10..30375033 100644 --- a/opt_einsum/tests/test_contract.py +++ b/opt_einsum/tests/test_contract.py @@ -4,13 +4,16 @@ from typing import Any, List -import numpy as np import pytest -from opt_einsum import contract, contract_expression, contract_path, helpers +from opt_einsum import contract, contract_expression, contract_path from opt_einsum.paths import _PATH_OPTIONS, linear_to_ssa, ssa_to_linear +from opt_einsum.testing import build_views, rand_equation from opt_einsum.typing import OptimizeKind +# NumPy is required for the majority of this file +np = pytest.importorskip("numpy") + tests = [ # Test scalar-like operations "a,->a", @@ -99,7 +102,7 @@ @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", _PATH_OPTIONS) def test_compare(optimize: OptimizeKind, string: str) -> None: - views = helpers.build_views(string) + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) opt = contract(string, *views, optimize=optimize, use_blas=False) @@ -108,7 +111,7 @@ def test_compare(optimize: OptimizeKind, string: str) -> None: @pytest.mark.parametrize("string", tests) def test_drop_in_replacement(string: str) -> None: - views = helpers.build_views(string) + views = build_views(string) opt = contract(string, *views) assert np.allclose(opt, np.einsum(string, *views)) @@ -116,7 +119,7 @@ def test_drop_in_replacement(string: str) -> None: @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", _PATH_OPTIONS) def test_compare_greek(optimize: OptimizeKind, string: str) -> None: - views = helpers.build_views(string) + views = build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) @@ -130,7 +133,7 @@ def test_compare_greek(optimize: OptimizeKind, string: str) -> None: @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", _PATH_OPTIONS) def test_compare_blas(optimize: OptimizeKind, string: str) -> None: - views = helpers.build_views(string) + views = build_views(string) ein = contract(string, *views, optimize=False) opt = contract(string, *views, optimize=optimize) @@ -140,7 +143,7 @@ def test_compare_blas(optimize: OptimizeKind, string: str) -> None: @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", _PATH_OPTIONS) def test_compare_blas_greek(optimize: OptimizeKind, string: str) -> None: - views = helpers.build_views(string) + views = build_views(string) ein = contract(string, *views, optimize=False) @@ -162,7 +165,7 @@ def test_some_non_alphabet_maintains_order() -> None: def test_printing(): string = "bbd,bda,fc,db->acf" - views = helpers.build_views(string) + views = build_views(string) ein = contract_path(string, *views) assert len(str(ein[1])) == 728 @@ -173,14 +176,14 @@ def test_printing(): @pytest.mark.parametrize("use_blas", [False, True]) @pytest.mark.parametrize("out_spec", [False, True]) def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: bool, 
out_spec: bool) -> None: - views = helpers.build_views(string) + views = build_views(string) shapes = [view.shape if hasattr(view, "shape") else tuple() for view in views] expected = contract(string, *views, optimize=False, use_blas=False) expr = contract_expression(string, *shapes, optimize=optimize, use_blas=use_blas) if out_spec and ("->" in string) and (string[-2:] != "->"): - (out,) = helpers.build_views(string.split("->")[1]) + (out,) = build_views(string.split("->")[1]) expr(*views, out=out) else: out = expr(*views) @@ -194,7 +197,7 @@ def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: boo def test_contract_expression_interleaved_input() -> None: x, y, z = (np.random.randn(2, 2) for _ in "xyz") - expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0]) # type: ignore + expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0]) xshp, yshp, zshp = ((2, 2) for _ in "xyz") expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0]) out = expr(x, y, z) @@ -214,7 +217,7 @@ def test_contract_expression_interleaved_input() -> None: ], ) def test_contract_expression_with_constants(string: str, constants: List[int]) -> None: - views = helpers.build_views(string) + views = build_views(string) expected = contract(string, *views, optimize=False, use_blas=False) shapes = [view.shape if hasattr(view, "shape") else tuple() for view in views] @@ -239,8 +242,8 @@ def test_contract_expression_with_constants(string: str, constants: List[int]) - @pytest.mark.parametrize("n_out", [0, 2, 4]) @pytest.mark.parametrize("global_dim", [False, True]) def test_rand_equation(optimize: OptimizeKind, n: int, reg: int, n_out: int, global_dim: bool) -> None: - eq, _, size_dict = helpers.rand_equation(n, reg, n_out, d_min=2, d_max=5, seed=42, return_size_dict=True) - views = helpers.build_views(eq, size_dict) + eq, _, size_dict = rand_equation(n, reg, n_out, d_min=2, d_max=5, seed=42, return_size_dict=True) + views = build_views(eq, size_dict) expected = contract(eq, *views, optimize=False) actual = contract(eq, *views, optimize=optimize) @@ -250,7 +253,7 @@ def test_rand_equation(optimize: OptimizeKind, n: int, reg: int, n_out: int, glo @pytest.mark.parametrize("equation", tests) def test_linear_vs_ssa(equation: str) -> None: - views = helpers.build_views(equation) + views = build_views(equation) linear_path, _ = contract_path(equation, *views) ssa_path = linear_to_ssa(linear_path) linear_path2 = ssa_to_linear(ssa_path) diff --git a/opt_einsum/tests/test_edge_cases.py b/opt_einsum/tests/test_edge_cases.py index 80942495..48355317 100644 --- a/opt_einsum/tests/test_edge_cases.py +++ b/opt_einsum/tests/test_edge_cases.py @@ -2,12 +2,16 @@ Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths """ -import numpy as np +from typing import Any, Tuple + import pytest from opt_einsum import contract, contract_expression, contract_path from opt_einsum.typing import PathType +# NumPy is required for the majority of this file +np = pytest.importorskip("numpy") + def test_contract_expression_checks() -> None: # check optimize needed @@ -129,3 +133,20 @@ def test_pathinfo_for_empty_contraction() -> None: # some info is built lazily, so check repr assert repr(info) assert info.largest_intermediate == 1 + + +@pytest.mark.parametrize( + "expression, operands", + [ + [",,->", (5, 5.0, 2.0j)], + ["ab,->", ([[5, 5], [2.0, 1]], 2.0j)], + ["ab,bc->ac", ([[5, 5], [2.0, 1]], [[2.0, 1], [3.0, 4]])], + ["ab,->", ([[5, 5], [2.0, 1]], True)], + ], 
+)
+def test_contract_with_assumed_shapes(expression: str, operands: Tuple[Any, ...]) -> None:
+    """Test that we can contract with assumed shapes, and that the output is correct. This is required as we need to infer intermediate shape sizes."""
+
+    benchmark = np.einsum(expression, *operands)
+    result = contract(expression, *operands, optimize=True)
+    assert np.allclose(benchmark, result)
diff --git a/opt_einsum/tests/test_input.py b/opt_einsum/tests/test_input.py
index 6f1ecc13..fefbf575 100644
--- a/opt_einsum/tests/test_input.py
+++ b/opt_einsum/tests/test_input.py
@@ -4,14 +4,17 @@
 
 from typing import Any
 
-import numpy as np
 import pytest
 
 from opt_einsum import contract, contract_path
 from opt_einsum.typing import ArrayType
 
+np = pytest.importorskip("numpy")
+
 
 def build_views(string: str) -> list[ArrayType]:
+    """Builds random numpy arrays for testing by using a fixed size dictionary and an input string."""
+
     chars = "abcdefghij"
     sizes_array = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4])
     sizes = {c: s for c, s in zip(chars, sizes_array)}
diff --git a/opt_einsum/tests/test_parser.py b/opt_einsum/tests/test_parser.py
index d582ca4d..6fcd3b2e 100644
--- a/opt_einsum/tests/test_parser.py
+++ b/opt_einsum/tests/test_parser.py
@@ -2,10 +2,12 @@
 Directly tests various parser utility functions.
 """
 
-import numpy as np
+from typing import Any, Tuple
+
 import pytest
 
-from opt_einsum.parser import get_symbol, parse_einsum_input, possibly_convert_to_numpy
+from opt_einsum.parser import get_shape, get_symbol, parse_einsum_input
+from opt_einsum.testing import build_arrays_from_tuples
 
 
 def test_get_symbol() -> None:
@@ -19,7 +21,7 @@ def test_get_symbol() -> None:
 
 def test_parse_einsum_input() -> None:
     eq = "ab,bc,cd"
-    ops = [np.random.rand(2, 3), np.random.rand(3, 4), np.random.rand(4, 5)]
+    ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])
     input_subscripts, output_subscript, operands = parse_einsum_input([eq, *ops])
     assert input_subscripts == eq
     assert output_subscript == "ad"
@@ -28,7 +30,7 @@ def test_parse_einsum_input() -> None:
 
 def test_parse_einsum_input_shapes_error() -> None:
     eq = "ab,bc,cd"
-    ops = [np.random.rand(2, 3), np.random.rand(3, 4), np.random.rand(4, 5)]
+    ops = build_arrays_from_tuples([(2, 3), (3, 4), (4, 5)])
 
     with pytest.raises(ValueError):
         _ = parse_einsum_input([eq, *ops], shapes=True)
@@ -36,8 +38,37 @@ def test_parse_einsum_input_shapes_error() -> None:
 
 def test_parse_einsum_input_shapes() -> None:
     eq = "ab,bc,cd"
-    shps = [(2, 3), (3, 4), (4, 5)]
-    input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shps], shapes=True)
+    shapes = [(2, 3), (3, 4), (4, 5)]
+    input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
     assert input_subscripts == eq
     assert output_subscript == "ad"
-    assert np.allclose([possibly_convert_to_numpy(shp) for shp in shps], operands)
+    assert shapes == operands
+
+
+def test_parse_with_ellipsis() -> None:
+    eq = "...a,ab"
+    shapes = [(2, 3), (3, 4)]
+    input_subscripts, output_subscript, operands = parse_einsum_input([eq, *shapes], shapes=True)
+    assert input_subscripts == "da,ab"
+    assert output_subscript == "db"
+    assert shapes == operands
+
+
+@pytest.mark.parametrize(
+    "array, shape",
+    [
+        [[5], (1,)],
+        [[5, 5], (2,)],
+        [(5, 5), (2,)],
+        [[[[[[5, 2]]]]], (1, 1, 1, 1, 2)],
+        [[[[[["abcdef", "b"]]]]], (1, 1, 1, 1, 2)],
+        ["A", tuple()],
+        [b"A", tuple()],
+        [True, tuple()],
+        [5, tuple()],
+        [5.0, tuple()],
+        [5.0 + 0j, tuple()],
+    ],
+)
+def test_get_shapes(array: Any, shape: Tuple[int, ...]) -> 
None: + assert get_shape(array) == shape diff --git a/opt_einsum/tests/test_paths.py b/opt_einsum/tests/test_paths.py index 4566e8bd..70f0904e 100644 --- a/opt_einsum/tests/test_paths.py +++ b/opt_einsum/tests/test_paths.py @@ -4,13 +4,13 @@ """ import itertools -import sys +from concurrent.futures import ProcessPoolExecutor from typing import Any, Dict, List, Optional -import numpy as np import pytest import opt_einsum as oe +from opt_einsum.testing import build_shapes, rand_equation from opt_einsum.typing import ArrayIndexType, OptimizeKind, PathType, TensorShapeType explicit_path_tests = { @@ -127,11 +127,12 @@ def test_flop_cost() -> None: def test_bad_path_option() -> None: - with pytest.raises(KeyError): - oe.contract("a,b,c", [1], [2], [3], optimize="optimall") # type: ignore + with pytest.raises(TypeError): + oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True) # type: ignore def test_explicit_path() -> None: + pytest.importorskip("numpy") x = oe.contract("a,b,c", [1], [2], [3], optimize=[(1, 2), (0, 1)]) assert x.item() == 6 @@ -158,39 +159,39 @@ def test_memory_paths() -> None: expression = "abc,bdef,fghj,cem,mhk,ljk->adgl" - views = oe.helpers.build_views(expression) + views = build_shapes(expression) # Test tiny memory limit - path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=5) + path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=5, shapes=True) assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)]) - path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=5) + path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=5, shapes=True) assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)]) # Check the possibilities, greedy is capped - path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=-1) + path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=-1, shapes=True) assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)]) - path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=-1) + path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=-1, shapes=True) assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)]) @pytest.mark.parametrize("alg,expression,order", path_edge_tests) def test_path_edge_cases(alg: OptimizeKind, expression: str, order: PathType) -> None: - views = oe.helpers.build_views(expression) + views = build_shapes(expression) # Test tiny memory limit - path_ret = oe.contract_path(expression, *views, optimize=alg) + path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True) assert check_path(path_ret[0], order) @pytest.mark.parametrize("expression,order", path_scalar_tests) @pytest.mark.parametrize("alg", oe.paths._PATH_OPTIONS) def test_path_scalar_cases(alg: OptimizeKind, expression: str, order: PathType) -> None: - views = oe.helpers.build_views(expression) + views = build_shapes(expression) # Test tiny memory limit - path_ret = oe.contract_path(expression, *views, optimize=alg) + path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True) # print(path_ret[0]) assert len(path_ret[0]) == order @@ -199,11 +200,11 @@ def test_optimal_edge_cases() -> None: # Edge test5 expression = "a,ac,ab,ad,cd,bd,bc->" - edge_test4 = oe.helpers.build_views(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20}) - path, path_str = oe.contract_path(expression, *edge_test4, optimize="greedy", 
memory_limit="max_input") + edge_test4 = build_shapes(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20}) + path, _ = oe.contract_path(expression, *edge_test4, optimize="greedy", memory_limit="max_input", shapes=True) assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)]) - path, path_str = oe.contract_path(expression, *edge_test4, optimize="optimal", memory_limit="max_input") + path, _ = oe.contract_path(expression, *edge_test4, optimize="optimal", memory_limit="max_input", shapes=True) assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)]) @@ -211,12 +212,12 @@ def test_greedy_edge_cases() -> None: expression = "abc,cfd,dbe,efa" dim_dict = {k: 20 for k in expression.replace(",", "")} - tensors = oe.helpers.build_views(expression, dimension_dict=dim_dict) + tensors = build_shapes(expression, dimension_dict=dim_dict) - path, path_str = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit="max_input") + path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit="max_input", shapes=True) assert check_path(path, [(0, 1, 2, 3)]) - path, path_str = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit=-1) + path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit=-1, shapes=True) assert check_path(path, [(0, 1), (0, 2), (0, 1)]) @@ -250,7 +251,7 @@ def test_custom_dp_can_optimize_for_outer_products() -> None: def test_custom_dp_can_optimize_for_size() -> None: - eq, shapes = oe.helpers.rand_equation(10, 4, seed=43) + eq, shapes = rand_equation(10, 4, seed=43) opt1 = oe.DynamicProgramming(minimize="flops") opt2 = oe.DynamicProgramming(minimize="size") @@ -263,7 +264,7 @@ def test_custom_dp_can_optimize_for_size() -> None: def test_custom_dp_can_set_cost_cap() -> None: - eq, shapes = oe.helpers.rand_equation(5, 3, seed=42) + eq, shapes = rand_equation(5, 3, seed=42) opt1 = oe.DynamicProgramming(cost_cap=True) opt2 = oe.DynamicProgramming(cost_cap=False) opt3 = oe.DynamicProgramming(cost_cap=100) @@ -286,7 +287,7 @@ def test_custom_dp_can_set_cost_cap() -> None: ], ) def test_custom_dp_can_set_minimize(minimize: str, cost: int, width: int, path: PathType) -> None: - eq, shapes = oe.helpers.rand_equation(10, 4, seed=43) + eq, shapes = rand_equation(10, 4, seed=43) opt = oe.DynamicProgramming(minimize=minimize) info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1] assert info.path == path @@ -295,7 +296,7 @@ def test_custom_dp_can_set_minimize(minimize: str, cost: int, width: int, path: def test_dp_errors_when_no_contractions_found() -> None: - eq, shapes = oe.helpers.rand_equation(10, 3, seed=42) + eq, shapes = rand_equation(10, 3, seed=42) # first get the actual minimum cost opt = oe.DynamicProgramming(minimize="size") @@ -312,9 +313,11 @@ def test_dp_errors_when_no_contractions_found() -> None: @pytest.mark.parametrize("optimize", ["greedy", "branch-2", "branch-all", "optimal", "dp"]) def test_can_optimize_outer_products(optimize: OptimizeKind) -> None: - a, b, c = [np.random.randn(10, 10) for _ in range(3)] - d = np.random.randn(10, 2) - assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize)[0] == [ + + a, b, c = [(10, 10) for _ in range(3)] + d = (10, 2) + + assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize, shapes=True)[0] == [ (2, 3), (0, 2), (0, 1), @@ -326,14 +329,16 @@ def test_large_path(num_symbols: int) -> None: symbols = "".join(oe.get_symbol(i) for i in range(num_symbols)) dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4]))) expression = 
",".join(symbols[t : t + 2] for t in range(num_symbols - 1)) - tensors = oe.helpers.build_views(expression, dimension_dict=dimension_dict) + tensors = build_shapes(expression, dimension_dict=dimension_dict) # Check that path construction does not crash - oe.contract_path(expression, *tensors, optimize="greedy") + oe.contract_path(expression, *tensors, optimize="greedy", shapes=True) def test_custom_random_greedy() -> None: - eq, shapes = oe.helpers.rand_equation(10, 4, seed=42) + np = pytest.importorskip("numpy") + + eq, shapes = rand_equation(10, 4, seed=42) views = list(map(np.ones, shapes)) with pytest.raises(ValueError): @@ -364,14 +369,16 @@ def test_custom_random_greedy() -> None: assert path_info.opt_cost == optimizer.best["flops"] # check error if we try and reuse the optimizer on a different expression - eq, shapes = oe.helpers.rand_equation(10, 4, seed=41) + eq, shapes = rand_equation(10, 4, seed=41) views = list(map(np.ones, shapes)) with pytest.raises(ValueError): path, path_info = oe.contract_path(eq, *views, optimize=optimizer) def test_custom_branchbound() -> None: - eq, shapes = oe.helpers.rand_equation(8, 4, seed=42) + np = pytest.importorskip("numpy") + + eq, shapes = rand_equation(8, 4, seed=42) views = list(map(np.ones, shapes)) optimizer = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize="size") @@ -391,7 +398,7 @@ def test_custom_branchbound() -> None: assert path_info.opt_cost == optimizer.best["flops"] # check error if we try and reuse the optimizer on a different expression - eq, shapes = oe.helpers.rand_equation(8, 4, seed=41) + eq, shapes = rand_equation(8, 4, seed=41) views = list(map(np.ones, shapes)) with pytest.raises(ValueError): path, path_info = oe.contract_path(eq, *views, optimize=optimizer) @@ -402,13 +409,12 @@ def test_branchbound_validation() -> None: oe.BranchBound(nbranch=0) -@pytest.mark.skipif(sys.version_info < (3, 2), reason="requires python3.2 or higher") def test_parallel_random_greedy() -> None: - from concurrent.futures import ProcessPoolExecutor + np = pytest.importorskip("numpy") pool = ProcessPoolExecutor(2) - eq, shapes = oe.helpers.rand_equation(10, 4, seed=42) + eq, shapes = rand_equation(10, 4, seed=42) views = list(map(np.ones, shapes)) optimizer = oe.RandomGreedy(max_repeats=10, parallel=pool) @@ -448,6 +454,7 @@ def test_parallel_random_greedy() -> None: def test_custom_path_optimizer() -> None: + np = pytest.importorskip("numpy") class NaiveOptimizer(oe.paths.PathOptimizer): def __call__( @@ -460,7 +467,7 @@ def __call__( self.was_used = True return [(0, 1)] * (len(inputs) - 1) - eq, shapes = oe.helpers.rand_equation(5, 3, seed=42, d_max=3) + eq, shapes = rand_equation(5, 3, seed=42, d_max=3) views = list(map(np.ones, shapes)) exp = oe.contract(eq, *views, optimize=False) @@ -472,6 +479,8 @@ def __call__( def test_custom_random_optimizer() -> None: + np = pytest.importorskip("numpy") + class NaiveRandomOptimizer(oe.path_random.RandomOptimizer): @staticmethod def random_path( @@ -497,7 +506,7 @@ def setup(self, inputs: Any, output: Any, size_dict: Any) -> Any: trial_args = (n, inputs, output, size_dict) return trial_fn, trial_args - eq, shapes = oe.helpers.rand_equation(5, 3, seed=42, d_max=3) + eq, shapes = rand_equation(5, 3, seed=42, d_max=3) views = list(map(np.ones, shapes)) exp = oe.contract(eq, *views, optimize=False) @@ -527,3 +536,9 @@ def custom_optimizer( path, _ = oe.contract_path(eq, *shapes, shapes=True, optimize="custom") # type: ignore assert path == [(0, 1), (0, 1)] del oe.paths._PATH_OPTIONS["custom"] + 
+ +def test_path_with_assumed_shapes() -> None: + + path, _ = oe.contract_path("ab,bc,cd", [[5, 3]], [[2], [4]], [[3, 2]]) + assert path == [(0, 1), (0, 1)] diff --git a/opt_einsum/tests/test_sharing.py b/opt_einsum/tests/test_sharing.py index dcf880c3..42717fbc 100644 --- a/opt_einsum/tests/test_sharing.py +++ b/opt_einsum/tests/test_sharing.py @@ -3,16 +3,25 @@ from collections import Counter from typing import Any -import numpy as np import pytest -from opt_einsum import contract, contract_expression, contract_path, get_symbol, helpers, shared_intermediates +from opt_einsum import contract, contract_expression, contract_path, get_symbol, shared_intermediates from opt_einsum.backends import to_cupy, to_torch from opt_einsum.contract import _einsum from opt_einsum.parser import parse_einsum_input from opt_einsum.sharing import count_cached_ops, currently_sharing, get_sharing_cache +from opt_einsum.testing import build_views from opt_einsum.typing import BackendType +pytest.importorskip("numpy") + +try: + import numpy as np # noqa # type: ignore + + numpy_if_found = "numpy" +except ImportError: + numpy_if_found = pytest.param("numpy", marks=[pytest.mark.skip(reason="NumPy not installed.")]) # type: ignore + try: import cupy # noqa @@ -27,7 +36,7 @@ except ImportError: torch_if_found = pytest.param("torch", marks=[pytest.mark.skip(reason="PyTorch not installed.")]) # type: ignore -backends = ["numpy", torch_if_found, cupy_if_found] +backends = [numpy_if_found, torch_if_found, cupy_if_found] equations = [ "ab,bc->ca", "abc,bcd,dea", @@ -49,7 +58,7 @@ @pytest.mark.parametrize("eq", equations) @pytest.mark.parametrize("backend", backends) def test_sharing_value(eq: str, backend: BackendType) -> None: - views = helpers.build_views(eq) + views = build_views(eq) shapes = [v.shape for v in views] expr = contract_expression(eq, *shapes) @@ -63,7 +72,7 @@ def test_sharing_value(eq: str, backend: BackendType) -> None: @pytest.mark.parametrize("backend", backends) def test_complete_sharing(backend: BackendType) -> None: eq = "ab,bc,cd->" - views = helpers.build_views(eq) + views = build_views(eq) expr = contract_expression(eq, *(v.shape for v in views)) print("-" * 40) @@ -88,7 +97,7 @@ def test_complete_sharing(backend: BackendType) -> None: @pytest.mark.parametrize("backend", backends) def test_sharing_reused_cache(backend: BackendType) -> None: eq = "ab,bc,cd->" - views = helpers.build_views(eq) + views = build_views(eq) expr = contract_expression(eq, *(v.shape for v in views)) print("-" * 40) @@ -114,7 +123,7 @@ def test_sharing_reused_cache(backend: BackendType) -> None: @pytest.mark.parametrize("backend", backends) def test_no_sharing_separate_cache(backend: BackendType) -> None: eq = "ab,bc,cd->" - views = helpers.build_views(eq) + views = build_views(eq) expr = contract_expression(eq, *(v.shape for v in views)) print("-" * 40) @@ -142,7 +151,7 @@ def test_no_sharing_separate_cache(backend: BackendType) -> None: @pytest.mark.parametrize("backend", backends) def test_sharing_nesting(backend: BackendType) -> None: eqs = ["ab,bc,cd->a", "ab,bc,cd->b", "ab,bc,cd->c", "ab,bc,cd->c"] - views = helpers.build_views(eqs[0]) + views = build_views(eqs[0]) shapes = [v.shape for v in views] refs: Any = weakref.WeakValueDictionary() @@ -181,9 +190,8 @@ def method2(views): @pytest.mark.parametrize("eq", equations) @pytest.mark.parametrize("backend", backends) def test_sharing_modulo_commutativity(eq: str, backend: BackendType) -> None: - ops = helpers.build_views(eq) - ops = [to_backend[backend](x) for x in 
ops] - inputs, output, _ = parse_einsum_input([eq] + ops) + ops = tuple(to_backend[backend](x) for x in build_views(eq)) + inputs, output, _ = parse_einsum_input([eq] + list(ops)) inputs_list = inputs.split(",") print("-" * 40) @@ -211,7 +219,7 @@ def test_sharing_modulo_commutativity(eq: str, backend: BackendType) -> None: @pytest.mark.parametrize("backend", backends) def test_partial_sharing(backend: BackendType) -> None: eq = "ab,bc,de->" - x, y, z1 = helpers.build_views(eq) + x, y, z1 = build_views(eq) # type: ignore z2 = 2.0 * z1 - 1.0 expr = contract_expression(eq, x.shape, y.shape, z1.shape) @@ -366,7 +374,7 @@ def test_multithreaded_sharing() -> None: from multiprocessing.pool import ThreadPool def fn(): - X, Y, Z = helpers.build_views("ab,bc,cd") + X, Y, Z = build_views("ab,bc,cd") with shared_intermediates(): contract("ab,bc,cd->a", X, Y, Z) diff --git a/opt_einsum/typing.py b/opt_einsum/typing.py index 175bb480..5057011f 100644 --- a/opt_einsum/typing.py +++ b/opt_einsum/typing.py @@ -2,13 +2,16 @@ Types used in the opt_einsum package """ +from collections import namedtuple from typing import Any, Callable, Collection, Dict, FrozenSet, List, Literal, Optional, Tuple, Union TensorShapeType = Tuple[int, ...] PathType = Collection[TensorShapeType] -ArrayType = Any # TODO +ArrayType = Any + ArrayIndexType = FrozenSet[str] +ArrayShaped = namedtuple("ArrayShaped", ["shape"]) ContractionListType = List[Tuple[Any, ArrayIndexType, str, Optional[Tuple[str, ...]], Union[str, bool]]] PathSearchFunctionType = Callable[[List[ArrayIndexType], ArrayIndexType, Dict[str, int], Optional[int]], PathType] diff --git a/setup.cfg b/setup.cfg index 4718da8f..b65d8159 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,6 +63,9 @@ ignore_missing_imports = True [mypy-cupy.*] ignore_missing_imports = True +[mypy-numpy.*] +ignore_missing_imports = True + [mypy-tensorflow.*] ignore_missing_imports = True
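
A minimal usage sketch of the NumPy-free workflow these tests exercise (illustrative only, not part of the patch; it assumes an environment with just `torch` installed, like the new torch-only CI job, and uses only the public `contract_path` / `contract_expression` API shown above):

    import opt_einsum as oe
    import torch

    # Plan a contraction from shapes alone -- no arrays of any kind are required.
    path, info = oe.contract_path("ab,bc,cd->ad", (2, 3), (3, 4), (4, 5), shapes=True)

    # Build a reusable expression and evaluate it with torch tensors.
    expr = oe.contract_expression("ab,bc,cd->ad", (2, 3), (3, 4), (4, 5))
    a, b, c = torch.rand(2, 3), torch.rand(3, 4), torch.rand(4, 5)
    out = expr(a, b, c, backend="torch")
    torch.testing.assert_close(out, torch.einsum("ab,bc,cd->ad", a, b, c))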