Skip to content

Commit

Permalink
tests: test non-contiguous arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
janden committed Aug 23, 2023
1 parent 8df027a commit ab64862
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions python/cufinufft/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MS = [256, 1024, 4096]
TOLS = [1e-2, 1e-3]
OUTPUT_ARGS = [False, True]
CONTIGUOUS = [False, True]

def _transfer_funcs(module_name):
if module_name == "pycuda":
Expand Down Expand Up @@ -83,25 +84,39 @@ def test_type1(framework, dtype, shape, M, tol, output_arg):
@pytest.mark.parametrize("M", MS)
@pytest.mark.parametrize("tol", TOLS)
@pytest.mark.parametrize("output_arg", OUTPUT_ARGS)
def test_type2(framework, dtype, shape, M, tol, output_arg):
@pytest.mark.parametrize("contiguous", CONTIGUOUS)
def test_type2(framework, dtype, shape, M, tol, output_arg, contiguous):
if framework == "pycuda" and not contiguous:
pytest.skip("Pycuda does not support copy to contiguous")

to_gpu, to_cpu = _transfer_funcs(framework)

complex_dtype = utils._complex_dtype(dtype)

k, fk = utils.type2_problem(dtype, shape, M)

plan = Plan(2, shape, eps=tol, dtype=complex_dtype)

if not contiguous and len(shape) > 1:
fk = fk.copy(order="F")

def _execute(*args, **kwargs):
with pytest.warns(UserWarning, match="requirement: C. Copying"):
return plan.execute(*args, **kwargs)
else:
def _execute(*args, **kwargs):
return plan.execute(*args, **kwargs)

k_gpu = to_gpu(k)
fk_gpu = to_gpu(fk)

plan = Plan(2, shape, eps=tol, dtype=complex_dtype)

plan.setpts(*k_gpu)

if output_arg:
c_gpu = _compat.array_empty_like(fk_gpu, (M,), dtype=complex_dtype)
plan.execute(fk_gpu, out=c_gpu)
_execute(fk_gpu, out=c_gpu)
else:
c_gpu = plan.execute(fk_gpu)
c_gpu = _execute(fk_gpu)

c = to_cpu(c_gpu)

Expand Down

0 comments on commit ab64862

Please sign in to comment.