diff --git a/python/cudaq/runtime/state.py b/python/cudaq/runtime/state.py index 5d6746c7e9..b95b768be2 100644 --- a/python/cudaq/runtime/state.py +++ b/python/cudaq/runtime/state.py @@ -5,6 +5,7 @@ # This source code and the accompanying materials are made available under # # the terms of the Apache License 2.0 which accompanies this distribution. # # ============================================================================ # +from ..mlir._mlir_libs._quakeDialects import cudaq_runtime def to_cupy(state, dtype=None): @@ -17,8 +18,11 @@ def to_cupy(state, dtype=None): except ImportError: print('to_cupy not supported, CuPy not available. Please install CuPy.') - if dtype == None: - dtype = cp.complex64 + if dtype is None: + # Determine the correct data type based on the cudaq target's precision + target = cudaq_runtime.get_target() + precision = target.get_precision() + dtype = cp.complex128 if precision == cudaq_runtime.SimulationPrecision.fp64 else cp.complex64 if not state.is_on_gpu(): raise RuntimeError( diff --git a/python/tests/builder/test_cupy_integration.py b/python/tests/builder/test_cupy_integration.py index 320ea4ffdf..0ad5ca3f67 100644 --- a/python/tests/builder/test_cupy_integration.py +++ b/python/tests/builder/test_cupy_integration.py @@ -107,6 +107,20 @@ def test_cupy_to_state(): assert np.isclose(result, 1.0, atol=1e-3) +def test_cupy_to_state_without_dtype(): + cp_data = cp.array([.707107, 0, 0, .707107]) + state_from_cupy = cudaq.State.from_data(cp_data) + state_from_cupy.dump() + kernel = cudaq.make_kernel() + q = kernel.qalloc(2) + kernel.h(q[0]) + kernel.cx(q[0], q[1]) + # State is on the GPU, this is nvidia target + state = cudaq.get_state(kernel) + result = state.overlap(state_from_cupy) + assert np.isclose(result, 1.0, atol=1e-3) + + # leave for gdb debugging if __name__ == "__main__": loc = os.path.abspath(__file__)