Skip to content

Commit

Permalink
pass dtype/device to full array creation routines only
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed May 10, 2024
1 parent be12e8f commit 3e04dbb
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 32 deletions.
9 changes: 5 additions & 4 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,16 +338,17 @@ def infer_backend_multi(*arrays):
# the set of functions that create new arrays, with `dtype` and possibly
# `device` kwargs, that should be inferred from the like argument
_CREATION_ROUTINES = {
"arange",
"empty",
"eye",
"full",
"geomspace",
"identity",
"linspace",
"logspace",
"ones",
"zeros",
# TODO: should these be included?
# "arange",
# "geomspace",
# "linspace",
# "logspace",
}

# cache for whether backends have a device attribute
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ write_to = "autoray/_version.py"

[tool.pytest.ini_options]
testpaths = "tests"
addopts = "--cov=autoray --cov-report term-missing --cov-report xml:coverage.xml --verbose --durations=10"
filterwarnings = "once"

[tool.coverage.run]
Expand Down
62 changes: 35 additions & 27 deletions tests/test_autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,18 +786,6 @@ def check_array_dtypes(x, y):
"dtype", ["float32", "float64", "complex64", "complex128"]
)
class TestCreationRoutines:
def test_arange_passes_dtype_device(self, backend, dtype):
if backend in ("sparse",):
pytest.xfail("Sparse doesn't support arange yet.")
if backend == "torch" and "complex" in dtype:
pytest.xfail("torch.arange doesn't support complex numbers yet.")
if backend == "tensorflow" and "complex" in dtype:
pytest.xfail("torch.arange doesn't support complex numbers yet.")

x = gen_rand((1,), backend, dtype)
y = ar.do("arange", 1, 10, like=x)
check_array_dtypes(x, y)

def test_empty_passes_dtype_device(self, backend, dtype):
if backend in ("tensorflow",):
pytest.xfail(f"{backend} doesn't support empty yet.")
Expand All @@ -822,21 +810,6 @@ def test_identity_passes_dtype_device(self, backend, dtype):
y = ar.do("identity", 4, like=x)
check_array_dtypes(x, y)

def test_linspace_passes_dtype_device(self, backend, dtype):
if backend in ("sparse", "tensorflow"):
pytest.xfail(f"{backend} doesn't support linspace yet.")
x = gen_rand((1,), backend, dtype)
y = ar.do("linspace", 10, 20, 11, like=x)
check_array_dtypes(x, y)

def test_logspace_passes_dtype_device(self, backend, dtype):
if backend in ("sparse", "tensorflow"):
pytest.xfail(f"{backend} doesn't support logspace yet.")
x = gen_rand((1,), backend, dtype)
if backend not in {"dask"}:
y = ar.do("logspace", 10, 20, 11, like=x)
check_array_dtypes(x, y)

def test_ones_passes_dtype_device(self, backend, dtype):
x = gen_rand((1,), backend, dtype)
y = ar.do("ones", (2, 3), like=x)
Expand All @@ -846,3 +819,38 @@ def test_zeros_passes_dtype_device(self, backend, dtype):
x = gen_rand((1,), backend, dtype)
y = ar.do("zeros", (2, 3), like=x)
check_array_dtypes(x, y)

# def test_arange_passes_dtype_device(self, backend, dtype):
# if backend in ("sparse",):
# pytest.xfail("Sparse doesn't support arange yet.")
# if backend == "torch" and "complex" in dtype:
# pytest.xfail("torch.arange doesn't support complex numbers yet.")
# if backend == "tensorflow" and "complex" in dtype:
# pytest.xfail("torch.arange doesn't support complex numbers yet.")

# x = gen_rand((1,), backend, dtype)
# y = ar.do("arange", 1, 10, like=x)
# check_array_dtypes(x, y)

# def test_linspace_passes_dtype_device(self, backend, dtype):
# if backend in ("sparse", "tensorflow"):
# pytest.xfail(f"{backend} doesn't support linspace yet.")
# x = gen_rand((1,), backend, dtype)
# y = ar.do("linspace", 10, 20, 11, like=x)
# check_array_dtypes(x, y)

# def test_logspace_passes_dtype_device(self, backend, dtype):
# if backend in ("sparse", "tensorflow"):
# pytest.xfail(f"{backend} doesn't support logspace yet.")
# x = gen_rand((1,), backend, dtype)
# if backend not in {"dask"}:
# y = ar.do("logspace", 10, 20, 11, like=x)
# check_array_dtypes(x, y)

# def test_geomspace_passes_dtype_device(self, backend, dtype):
# if backend in ("sparse", "tensorflow"):
# pytest.xfail(f"{backend} doesn't support logspace yet.")
# x = gen_rand((1,), backend, dtype)
# if backend not in {"dask"}:
# y = ar.do("logspace", 10, 20, 11, like=x)
# check_array_dtypes(x, y)

0 comments on commit 3e04dbb

Please sign in to comment.