From 3e04dbbd1d18d4ae7a30e2064c5176e055b4a84a Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Thu, 9 May 2024 17:39:34 -0700 Subject: [PATCH] pass dtype/device to full array creation routines only --- autoray/autoray.py | 9 ++++--- pyproject.toml | 1 - tests/test_autoray.py | 62 ++++++++++++++++++++++++------------------- 3 files changed, 40 insertions(+), 32 deletions(-) diff --git a/autoray/autoray.py b/autoray/autoray.py index a8b4efd..0081306 100644 --- a/autoray/autoray.py +++ b/autoray/autoray.py @@ -338,16 +338,17 @@ def infer_backend_multi(*arrays): # the set of functions that create new arrays, with `dtype` and possibly # `device` kwargs, that should be inferred from the like argument _CREATION_ROUTINES = { - "arange", "empty", "eye", "full", - "geomspace", "identity", - "linspace", - "logspace", "ones", "zeros", + # TODO: should these be included? + # "arange", + # "geomspace", + # "linspace", + # "logspace", } # cache for whether backends have a device attribute diff --git a/pyproject.toml b/pyproject.toml index 1e1d820..b07f3c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,6 @@ write_to = "autoray/_version.py" [tool.pytest.ini_options] testpaths = "tests" -addopts = "--cov=autoray --cov-report term-missing --cov-report xml:coverage.xml --verbose --durations=10" filterwarnings = "once" [tool.coverage.run] diff --git a/tests/test_autoray.py b/tests/test_autoray.py index dea0321..908aaac 100644 --- a/tests/test_autoray.py +++ b/tests/test_autoray.py @@ -786,18 +786,6 @@ def check_array_dtypes(x, y): "dtype", ["float32", "float64", "complex64", "complex128"] ) class TestCreationRoutines: - def test_arange_passes_dtype_device(self, backend, dtype): - if backend in ("sparse",): - pytest.xfail("Sparse doesn't support arange yet.") - if backend == "torch" and "complex" in dtype: - pytest.xfail("torch.arange doesn't support complex numbers yet.") - if backend == "tensorflow" and "complex" in dtype: - pytest.xfail("torch.arange doesn't support complex numbers yet.") - - x = gen_rand((1,), backend, dtype) - y = ar.do("arange", 1, 10, like=x) - check_array_dtypes(x, y) - def test_empty_passes_dtype_device(self, backend, dtype): if backend in ("tensorflow",): pytest.xfail(f"{backend} doesn't support empty yet.") @@ -822,21 +810,6 @@ def test_identity_passes_dtype_device(self, backend, dtype): y = ar.do("identity", 4, like=x) check_array_dtypes(x, y) - def test_linspace_passes_dtype_device(self, backend, dtype): - if backend in ("sparse", "tensorflow"): - pytest.xfail(f"{backend} doesn't support linspace yet.") - x = gen_rand((1,), backend, dtype) - y = ar.do("linspace", 10, 20, 11, like=x) - check_array_dtypes(x, y) - - def test_logspace_passes_dtype_device(self, backend, dtype): - if backend in ("sparse", "tensorflow"): - pytest.xfail(f"{backend} doesn't support logspace yet.") - x = gen_rand((1,), backend, dtype) - if backend not in {"dask"}: - y = ar.do("logspace", 10, 20, 11, like=x) - check_array_dtypes(x, y) - def test_ones_passes_dtype_device(self, backend, dtype): x = gen_rand((1,), backend, dtype) y = ar.do("ones", (2, 3), like=x) @@ -846,3 +819,38 @@ def test_zeros_passes_dtype_device(self, backend, dtype): x = gen_rand((1,), backend, dtype) y = ar.do("zeros", (2, 3), like=x) check_array_dtypes(x, y) + + # def test_arange_passes_dtype_device(self, backend, dtype): + # if backend in ("sparse",): + # pytest.xfail("Sparse doesn't support arange yet.") + # if backend == "torch" and "complex" in dtype: + # pytest.xfail("torch.arange doesn't support complex numbers yet.") + # if backend == "tensorflow" and "complex" in dtype: + # pytest.xfail("torch.arange doesn't support complex numbers yet.") + + # x = gen_rand((1,), backend, dtype) + # y = ar.do("arange", 1, 10, like=x) + # check_array_dtypes(x, y) + + # def test_linspace_passes_dtype_device(self, backend, dtype): + # if backend in ("sparse", "tensorflow"): + # pytest.xfail(f"{backend} doesn't support linspace yet.") + # x = gen_rand((1,), backend, dtype) + # y = ar.do("linspace", 10, 20, 11, like=x) + # check_array_dtypes(x, y) + + # def test_logspace_passes_dtype_device(self, backend, dtype): + # if backend in ("sparse", "tensorflow"): + # pytest.xfail(f"{backend} doesn't support logspace yet.") + # x = gen_rand((1,), backend, dtype) + # if backend not in {"dask"}: + # y = ar.do("logspace", 10, 20, 11, like=x) + # check_array_dtypes(x, y) + + # def test_geomspace_passes_dtype_device(self, backend, dtype): + # if backend in ("sparse", "tensorflow"): + # pytest.xfail(f"{backend} doesn't support logspace yet.") + # x = gen_rand((1,), backend, dtype) + # if backend not in {"dask"}: + # y = ar.do("logspace", 10, 20, 11, like=x) + # check_array_dtypes(x, y)