Skip to content

Commit

Permalink
All tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed May 20, 2024
1 parent 96c3b22 commit 807a45d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
4 changes: 3 additions & 1 deletion opt_einsum/backends/object_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import numpy as np

from opt_einsum.typing import ArrayType

def object_einsum(eq, *arrays):

def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType:
"""A ``einsum`` implementation for ``numpy`` arrays with object dtype.
The loop is performed in python, meaning the objects themselves need
only to implement ``__mul__`` and ``__add__`` for the contraction to be
Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/backends/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def transpose(a, axes):
return a.permute(*axes)


def einsum(equation, *operands):
def einsum(equation, *operands, **kwargs):
"""Variadic version of torch.einsum to match numpy api."""
# rename symbols to support PyTorch 0.4.1 and earlier,
# which allow only symbols a-z.
Expand Down
24 changes: 22 additions & 2 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,26 @@ def _choose_memory_arg(memory_limit: _MemoryLimit, size_list: List[int]) -> Opti
return int(memory_limit)


def _filter_einsum_defaults(kwargs: Dict[Literal["order", "casting", "dtype", "out"], Any]) -> Dict[str, Any]:
"""Filters out default contract kwargs to pass to various backends."""
kwargs = kwargs.copy()
ret = {}
if (order := kwargs.pop("order", "K")) != "K":
ret["order"] = order

if (casting := kwargs.pop("casting", "safe")) != "safe":
ret["casting"] = casting

if (dtype := kwargs.pop("dtype", None)) is not None:
ret["dtype"] = dtype

if (out := kwargs.pop("out", None)) is not None:
ret["out"] = out

ret.update(kwargs)
return ret


@overload
def contract_path(
subscripts: str,
Expand Down Expand Up @@ -330,7 +350,7 @@ def contract_path(
path_tuple = [tuple(range(num_ops))]
elif isinstance(optimize, paths.PathOptimizer):
# Custom path optimizer supplied
path_tuple = optimize(input_sets, output_set, size_dict, memory_arg) # type: ignore
path_tuple = optimize(input_sets, output_set, size_dict, memory_arg)
else:
path_optimizer = paths.get_path_fn(optimize)
path_tuple = path_optimizer(input_sets, output_set, size_dict, memory_arg)
Expand Down Expand Up @@ -427,6 +447,7 @@ def _einsum(*operands: Any, **kwargs: Any) -> ArrayType:

einsum_str = parser.convert_to_valid_einsum_chars(einsum_str)

kwargs = _filter_einsum_defaults(kwargs)
return fn(einsum_str, *operands, **kwargs)


Expand Down Expand Up @@ -906,7 +927,6 @@ def __call__(self, *arrays: ArrayType, **kwargs: Any) -> ArrayType:
return self._contract(ops, out=out, backend=backend, evaluate_constants=evaluate_constants)

except ValueError as err:
raise
original_msg = str(err.args) if err.args else ""
msg = (
"Internal error while evaluating `ContractExpression`. Note that few checks are performed"
Expand Down

0 comments on commit 807a45d

Please sign in to comment.