From 807a45d9190817df6bc06df7003f258aa4b0cc53 Mon Sep 17 00:00:00 2001 From: Daniel Smith Date: Sun, 19 May 2024 21:35:14 -0400 Subject: [PATCH] All tests passing --- opt_einsum/backends/object_arrays.py | 4 +++- opt_einsum/backends/torch.py | 2 +- opt_einsum/contract.py | 24 ++++++++++++++++++++++-- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/opt_einsum/backends/object_arrays.py b/opt_einsum/backends/object_arrays.py index 308cb671..eae0e92f 100644 --- a/opt_einsum/backends/object_arrays.py +++ b/opt_einsum/backends/object_arrays.py @@ -7,8 +7,10 @@ import numpy as np +from opt_einsum.typing import ArrayType -def object_einsum(eq, *arrays): + +def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType: """A ``einsum`` implementation for ``numpy`` arrays with object dtype. The loop is performed in python, meaning the objects themselves need only to implement ``__mul__`` and ``__add__`` for the contraction to be diff --git a/opt_einsum/backends/torch.py b/opt_einsum/backends/torch.py index ed92fd53..c3ae9b5e 100644 --- a/opt_einsum/backends/torch.py +++ b/opt_einsum/backends/torch.py @@ -41,7 +41,7 @@ def transpose(a, axes): return a.permute(*axes) -def einsum(equation, *operands): +def einsum(equation, *operands, **kwargs): """Variadic version of torch.einsum to match numpy api.""" # rename symbols to support PyTorch 0.4.1 and earlier, # which allow only symbols a-z. diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index bc86b80a..7270ce47 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -123,6 +123,26 @@ def _choose_memory_arg(memory_limit: _MemoryLimit, size_list: List[int]) -> Opti return int(memory_limit) +def _filter_einsum_defaults(kwargs: Dict[Literal["order", "casting", "dtype", "out"], Any]) -> Dict[str, Any]: + """Filters out default contract kwargs to pass to various backends.""" + kwargs = kwargs.copy() + ret = {} + if (order := kwargs.pop("order", "K")) != "K": + ret["order"] = order + + if (casting := kwargs.pop("casting", "safe")) != "safe": + ret["casting"] = casting + + if (dtype := kwargs.pop("dtype", None)) is not None: + ret["dtype"] = dtype + + if (out := kwargs.pop("out", None)) is not None: + ret["out"] = out + + ret.update(kwargs) + return ret + + @overload def contract_path( subscripts: str, @@ -330,7 +350,7 @@ def contract_path( path_tuple = [tuple(range(num_ops))] elif isinstance(optimize, paths.PathOptimizer): # Custom path optimizer supplied - path_tuple = optimize(input_sets, output_set, size_dict, memory_arg) # type: ignore + path_tuple = optimize(input_sets, output_set, size_dict, memory_arg) else: path_optimizer = paths.get_path_fn(optimize) path_tuple = path_optimizer(input_sets, output_set, size_dict, memory_arg) @@ -427,6 +447,7 @@ def _einsum(*operands: Any, **kwargs: Any) -> ArrayType: einsum_str = parser.convert_to_valid_einsum_chars(einsum_str) + kwargs = _filter_einsum_defaults(kwargs) return fn(einsum_str, *operands, **kwargs) @@ -906,7 +927,6 @@ def __call__(self, *arrays: ArrayType, **kwargs: Any) -> ArrayType: return self._contract(ops, out=out, backend=backend, evaluate_constants=evaluate_constants) except ValueError as err: - raise original_msg = str(err.args) if err.args else "" msg = ( "Internal error while evaluating `ContractExpression`. Note that few checks are performed"