Skip to content

Commit

Permalink
update lazy imports, add linalg tests (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanhill1 authored Oct 18, 2024
1 parent fccfeee commit ecaa681
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 8 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Physics"
]
Expand Down
31 changes: 24 additions & 7 deletions qbraid_qir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
QirConversionError
"""
import importlib
from typing import TYPE_CHECKING

from ._version import __version__
from .exceptions import QbraidQirError, QirConversionError
from .serialization import dumps
Expand All @@ -41,20 +44,34 @@
"QbraidQirError",
"QirConversionError",
"dumps",
"qasm3_to_qir",
"cirq_to_qir",
]

_lazy_mods = ["cirq", "qasm3"]
_lazy = {"cirq": "cirq_to_qir", "qasm3": "qasm3_to_qir"}

if TYPE_CHECKING:
from .cirq import cirq_to_qir
from .qasm3 import qasm3_to_qir


def __getattr__(name):
if name in _lazy_mods:
import importlib # pylint: disable=import-outside-toplevel
for mod_name, objects in _lazy.items():
if name == mod_name:
module = importlib.import_module(f".{mod_name}", __name__)
globals()[mod_name] = module
return module

if name in objects:
module = importlib.import_module(f".{mod_name}", __name__)
obj = getattr(module, name)
globals()[name] = obj
return obj

module = importlib.import_module(f".{name}", __name__)
globals()[name] = module
return module
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
return sorted(__all__ + _lazy_mods)
return sorted(
__all__ + list(_lazy.keys()) + [item for sublist in _lazy.values() for item in sublist]
)
71 changes: 70 additions & 1 deletion tests/qasm3_qir/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
"""
import numpy as np

from qbraid_qir.qasm3.linalg import _kak_canonicalize_vector
from qbraid_qir.qasm3.linalg import (
_block_diag,
_helper_svd,
_kak_canonicalize_vector,
_orthogonal_diagonalize,
_so4_to_su2,
kak_decomposition_angles,
orthogonal_bidiagonalize,
)


def test_kak_canonicalize_vector():
Expand All @@ -26,3 +34,64 @@ def test_kak_canonicalize_vector():
x, y, z = 1, 2, 1
result = _kak_canonicalize_vector(x, y, z)
assert result["single_qubit_operations_before"][0][0][0] == -np.sqrt(2) / 2


def test_helper_svd():
"""Test _helper_svd function."""
mat = np.random.rand(4, 4)
u, s, vh = _helper_svd(mat)
assert np.allclose(np.dot(u, np.dot(np.diag(s), vh)), mat)

mat_empty = np.array([[]])
u, s, vh = _helper_svd(mat_empty)
assert u.shape == (0, 0)
assert vh.shape == (0, 0)
assert len(s) == 0


def test_block_diag():
"""Test block diagonalization of matrices."""
a = np.random.rand(2, 2)
b = np.random.rand(3, 3)
res = _block_diag(a, b)

assert res.shape == (5, 5)
assert np.allclose(res[:2, :2], a)
assert np.allclose(res[2:, 2:], b)


def test_orthogonal_diagonalize():
"""Test orthogonal diagonalization of matrices."""
mat1 = np.eye(3)
mat2 = np.diag([1, 2, 3])
p = _orthogonal_diagonalize(mat1, mat2)

assert np.allclose(np.dot(p.T, np.dot(mat1, p)), np.eye(3))


def test_orthogonal_bidiagonalize():
"""Test orthogonal bidiagonalization of matrices."""
mat1 = np.random.rand(4, 4)
mat2 = np.random.rand(4, 4)
left, right = orthogonal_bidiagonalize(mat1, mat2)

assert left.shape == (4, 4)
assert right.shape == (4, 4)


def test_so4_to_su2():
"""Test SO4 to SU2 conversion."""
mat = np.eye(4)
a, b = _so4_to_su2(mat)

assert a.shape == (2, 2)
assert b.shape == (2, 2)


def test_kak_decomposition_angles():
"""Test KAK decomposition angles."""
mat = np.eye(4)
angles = kak_decomposition_angles(mat)

assert len(angles) == 4
assert all(len(a) == 3 for a in angles)

0 comments on commit ecaa681

Please sign in to comment.