Skip to content

Commit

Permalink
add {torch|tensorflow}.indices
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed May 14, 2024
1 parent b367fa8 commit 3b8efff
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1881,8 +1881,15 @@ def tensorflow_diag(x, **kwargs):
raise ValueError("Input must be 1- or 2-d.")


def tensorflow_indices(dimensions):
_meshgrid = get_lib_fn("tensorflow", "meshgrid")
_arange = get_lib_fn("tensorflow", "arange")
return _meshgrid(*map(_arange, dimensions))


_FUNCS["tensorflow", "to_numpy"] = tensorflow_to_numpy
_FUNCS["tensorflow", "diag"] = tensorflow_diag
_FUNCS["tensorflow", "indices"] = tensorflow_indices

_SUBMODULE_ALIASES["tensorflow", "log"] = "tensorflow.math"
_SUBMODULE_ALIASES["tensorflow", "conj"] = "tensorflow.math"
Expand All @@ -1893,6 +1900,7 @@ def tensorflow_diag(x, **kwargs):
_SUBMODULE_ALIASES["tensorflow", "trace"] = "tensorflow.linalg"
_SUBMODULE_ALIASES["tensorflow", "tril"] = "tensorflow.linalg"
_SUBMODULE_ALIASES["tensorflow", "triu"] = "tensorflow.linalg"
_SUBMODULE_ALIASES["tensorflow", "allclose"] = "tensorflow.experimental.numpy"

_FUNC_ALIASES["tensorflow", "sum"] = "reduce_sum"
_FUNC_ALIASES["tensorflow", "min"] = "reduce_min"
Expand All @@ -1905,6 +1913,7 @@ def tensorflow_diag(x, **kwargs):
_FUNC_ALIASES["tensorflow", "tril"] = "band_part"
_FUNC_ALIASES["tensorflow", "triu"] = "band_part"
_FUNC_ALIASES["tensorflow", "array"] = "convert_to_tensor"
_FUNC_ALIASES["tensorflow", "asarray"] = "convert_to_tensor"
_FUNC_ALIASES["tensorflow", "astype"] = "cast"
_FUNC_ALIASES["tensorflow", "power"] = "pow"
_FUNC_ALIASES["tensorflow", "take"] = "gather"
Expand Down Expand Up @@ -2105,6 +2114,12 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
return numpy_like


def torch_indices(dimensions):
_meshgrid = get_lib_fn("torch", "meshgrid")
_arange = get_lib_fn("torch", "arange")
return _meshgrid(*map(_arange, dimensions), indexing="ij")


_FUNCS["torch", "pad"] = torch_pad
_FUNCS["torch", "real"] = torch_real
_FUNCS["torch", "imag"] = torch_imag
Expand All @@ -2114,8 +2129,10 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
_FUNCS["torch", "transpose"] = torch_transpose
_FUNCS["torch", "count_nonzero"] = torch_count_nonzero
_FUNCS["torch", "get_dtype_name"] = torch_get_dtype_name
_FUNCS["torch", "indices"] = torch_indices

_FUNC_ALIASES["torch", "array"] = "tensor"
_FUNC_ALIASES["torch", "asarray"] = "as_tensor"
_FUNC_ALIASES["torch", "clip"] = "clamp"
_FUNC_ALIASES["torch", "concatenate"] = "cat"
_FUNC_ALIASES["torch", "conjugate"] = "conj"
Expand Down

0 comments on commit 3b8efff

Please sign in to comment.