Skip to content

Commit

Permalink
Get dtype from the set target (#2022)
Browse files Browse the repository at this point in the history
* Determine precision based on the set target

* adding a test

* fixing spelling in comment
  • Loading branch information
sacpis authored and bettinaheim committed Aug 5, 2024
1 parent 59a0843 commit 5f63a89
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/cudaq/runtime/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #
from ..mlir._mlir_libs._quakeDialects import cudaq_runtime


def to_cupy(state, dtype=None):
Expand All @@ -17,8 +18,11 @@ def to_cupy(state, dtype=None):
except ImportError:
print('to_cupy not supported, CuPy not available. Please install CuPy.')

if dtype == None:
dtype = cp.complex64
if dtype is None:
# Determine the correct data type based on the cudaq target's precision
target = cudaq_runtime.get_target()
precision = target.get_precision()
dtype = cp.complex128 if precision == cudaq_runtime.SimulationPrecision.fp64 else cp.complex64

if not state.is_on_gpu():
raise RuntimeError(
Expand Down
14 changes: 14 additions & 0 deletions python/tests/builder/test_cupy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def test_cupy_to_state():
assert np.isclose(result, 1.0, atol=1e-3)


def test_cupy_to_state_without_dtype():
cp_data = cp.array([.707107, 0, 0, .707107])
state_from_cupy = cudaq.State.from_data(cp_data)
state_from_cupy.dump()
kernel = cudaq.make_kernel()
q = kernel.qalloc(2)
kernel.h(q[0])
kernel.cx(q[0], q[1])
# State is on the GPU, this is nvidia target
state = cudaq.get_state(kernel)
result = state.overlap(state_from_cupy)
assert np.isclose(result, 1.0, atol=1e-3)


# leave for gdb debugging
if __name__ == "__main__":
loc = os.path.abspath(__file__)
Expand Down

0 comments on commit 5f63a89

Please sign in to comment.