Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mypy static types to qbraid-qir #150

Merged
merged 6 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ jobs:
run: |
python3 -m pip install --upgrade pip
python3 -m pip install tox>=4.2.0
- name: Check isort, black, headers
- name: Check isort, black, mypy, headers
run: |
tox -e format-check
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Types of changes:

### 🌟 Improvements
* Re-factor the `BasicQasmVisitor` and improve modularity ( [#142](https://github.com/qBraid/qbraid-qir/pull/142) )
* Add static type checking with `mypy` ( [#150](https://github.com/qBraid/qbraid-qir/pull/150) )
* Improve measurement statement parsing logic and add support for range definition and discrete set ( [#150](https://github.com/qBraid/qbraid-qir/pull/150) )

### 📜 Documentation
* Housekeeping updates for release ( [#135](https://github.com/qBraid/qbraid-qir/pull/135) )
Expand Down
16 changes: 16 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[mypy]

# Ignore missing imports
ignore_missing_imports = True

# Enable incremental mode
incremental = True

# Show error codes in output
show_error_codes = True

# Follow imports for type checking
follow_imports = normal

# Enable cache
cache_dir = .mypy_cache
6 changes: 3 additions & 3 deletions qbraid_qir/cirq/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
visitor = BasicCirqVisitor(**kwargs)
module.accept(visitor)

err = llvm_module.verify()
if err is not None:
raise CirqConversionError(err)
error = llvm_module.verify()
if error is not None:
raise CirqConversionError(error)

Check warning on line 69 in qbraid_qir/cirq/convert.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/cirq/convert.py#L69

Added line #L69 was not covered by tests
return llvm_module
10 changes: 5 additions & 5 deletions qbraid_qir/cirq/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import hashlib
from abc import ABCMeta, abstractmethod
from typing import FrozenSet, List, Optional
from typing import Optional

import cirq
from pyqir import Context, Module
Expand Down Expand Up @@ -50,7 +50,7 @@ def accept(self, visitor):


class _Register(_CircuitElement):
def __init__(self, register: FrozenSet[cirq.Qid]):
def __init__(self, register: list[cirq.Qid]):
self._register = register

def accept(self, visitor):
Expand All @@ -77,7 +77,7 @@ class CirqModule:
name (str): Name of the module.
module (Module): QIR Module instance.
num_qubits (int): Number of qubits in the circuit.
elements (List[_CircuitElement]): List of circuit elements.
elements (list[_CircuitElement]): list of circuit elements.

Example:
>>> circuit = cirq.Circuit()
Expand All @@ -90,7 +90,7 @@ def __init__(
name: str,
module: Module,
num_qubits: int,
elements: List[_CircuitElement],
elements: list[_CircuitElement],
):
self._name = name
self._module = module
Expand Down Expand Up @@ -122,7 +122,7 @@ def num_clbits(self) -> int:
def from_circuit(cls, circuit: cirq.Circuit, module: Optional[Module] = None) -> "CirqModule":
"""Class method. Constructs a CirqModule from a given cirq.Circuit object
and an optional QIR Module."""
elements: List[_CircuitElement] = []
elements: list[_CircuitElement] = []

# Register(s). Tentatively using cirq.Qid as input. Better approaches might exist tbd.
elements.append(_Register(list(circuit.all_qubits())))
Expand Down
8 changes: 4 additions & 4 deletions qbraid_qir/cirq/opsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Module mapping supported Cirq gates/operations to pyqir functions.

"""
from typing import Callable, Tuple
from typing import Callable

import cirq
import pyqir._native
Expand All @@ -38,7 +38,7 @@ def measure_y(builder, qubit, result):
pyqir._native.mz(builder, qubit, result)


PYQIR_OP_MAP = {
PYQIR_OP_MAP: dict[str, Callable] = {
# Identity Gate
"I": i,
# Single-Qubit Clifford Gates
Expand Down Expand Up @@ -71,15 +71,15 @@ def measure_y(builder, qubit, result):

def map_cirq_op_to_pyqir_callable(
operation: cirq.Operation,
) -> Tuple[Callable, str]:
) -> tuple[Callable, str]:
"""
Maps a Cirq operation to its corresponding PyQIR callable function.

Args:
operation (cirq.Operation): The Cirq operation to map.

Returns:
Tuple[Callable, str]: Tuple containing the corresponding PyQIR callable function,
tuple[Callable, str]: tuple containing the corresponding PyQIR callable function,
and a string representing the gate/operation type.

Raises:
Expand Down
10 changes: 5 additions & 5 deletions qbraid_qir/cirq/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,23 @@

"""
import itertools
from typing import List
from typing import Iterable

import cirq

from .exceptions import CirqConversionError
from .opsets import map_cirq_op_to_pyqir_callable


def _decompose_gate_op(operation: cirq.GateOperation) -> List[cirq.OP_TREE]:
def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]:
"""Decomposes a single Cirq gate operation into a sequence of operations
that are directly supported by PyQIR.

Args:
operation (cirq.GateOperation): The gate operation to decompose.
operation (cirq.Operation): The gate operation to decompose.

Returns:
List[cirq.OP_TREE]: A list of decomposed gate operations.
Iterable[cirq.OP_TREE]: A list of decomposed gate operations.
"""
try:
# Try converting to PyQIR. If successful, keep the operation.
Expand Down Expand Up @@ -58,7 +58,7 @@ def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit:
new_ops = []
for operation in moment:
if isinstance(operation, cirq.GateOperation):
decomposed_ops = _decompose_gate_op(operation)
decomposed_ops = list(_decompose_gate_op(operation))
new_ops.extend(decomposed_ops)
elif isinstance(operation, cirq.ClassicallyControlledOperation):
new_ops.append(operation)
Expand Down
28 changes: 12 additions & 16 deletions qbraid_qir/cirq/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""
import logging
from abc import ABCMeta, abstractmethod
from typing import List

import cirq
import pyqir
Expand Down Expand Up @@ -49,11 +48,11 @@ class BasicCirqVisitor(CircuitElementVisitor):
"""

def __init__(self, initialize_runtime: bool = True, record_output: bool = True):
self._module = None
self._builder = None
self._entry_point = None
self._qubit_labels = {}
self._measured_qubits = {}
self._module: pyqir.Module
self._builder: pyqir.Builder
self._entry_point: str
self._qubit_labels: dict[cirq.Qid, int] = {}
self._measured_qubits: dict = {}
self._initialize_runtime = initialize_runtime
self._record_output = record_output

Expand Down Expand Up @@ -89,20 +88,18 @@ def record_output(self, module: CirqModule) -> None:
result_ref = pyqir.result(self._module.context, i)
pyqir.rt.result_record_output(self._builder, result_ref, Constant.null(i8p))

def visit_register(self, qids: List[cirq.Qid]) -> None:
def visit_register(self, qids: list[cirq.Qid]) -> None:
logger.debug("Visiting qids '%s'", str(qids))

if not isinstance(qids, list):
raise TypeError("Parameter is not a list.")

if not all(isinstance(x, cirq.Qid) for x in qids):
raise TypeError("All elements in the list must be of type cirq.Qid.")

self._qubit_labels.update({bit: n + len(self._qubit_labels) for n, bit in enumerate(qids)})
logger.debug("Added labels for qubits %s", str(qids))

def visit_operation(self, operation: cirq.Operation) -> None:
qlabels = [self._qubit_labels.get(bit) for bit in operation.qubits]
qlabels = [self._qubit_labels[bit] for bit in operation.qubits]

qubits = [pyqir.qubit(self._module.context, n) for n in qlabels]
results = [pyqir.result(self._module.context, n) for n in qlabels]

Expand All @@ -125,7 +122,9 @@ def handle_measurement(pyqir_func):
# pylint: disable=unnecessary-lambda-assignment
if op_str in ["Rx", "Ry", "Rz"]:
pyqir_func = lambda: temp_pyqir_func(
self._builder, operation._sub_operation.gate._rads, *qubits
self._builder,
operation._sub_operation.gate._rads, # type: ignore[union-attr]
*qubits,
)
else:
pyqir_func = lambda: temp_pyqir_func(self._builder, *qubits)
Expand All @@ -149,12 +148,9 @@ def _branch(conds, pyqir_func):
if op_str.startswith("measure"):
handle_measurement(pyqir_func)
elif op_str in ["Rx", "Ry", "Rz"]:
pyqir_func(self._builder, operation.gate._rads, *qubits)
pyqir_func(self._builder, operation.gate._rads, *qubits) # type: ignore[union-attr]
else:
pyqir_func(self._builder, *qubits)

def ir(self) -> str:
return str(self._module)

def bitcode(self) -> bytes:
return self._module.bitcode()
45 changes: 28 additions & 17 deletions qbraid_qir/qasm3/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
Module with analysis functions for QASM3 visitor

"""
from typing import Any
from typing import Any, Optional, Union

from openqasm3.ast import (
BinaryExpression,
DiscreteSet,
Expression,
IndexExpression,
IntegerLiteral,
RangeDefinition,
Expand All @@ -30,7 +32,7 @@ class Qasm3Analyzer:
"""Class with utility functions for analyzing QASM3 elements"""

@staticmethod
def analyze_classical_indices(indices: list[IntegerLiteral], var: Variable) -> None:
def analyze_classical_indices(indices: list[Any], var: Variable) -> list:
"""Validate the indices for a classical variable.

Args:
Expand All @@ -45,18 +47,21 @@ def analyze_classical_indices(indices: list[IntegerLiteral], var: Variable) -> N
"""
indices_list = []
var_name = var.name
var_dimensions = var.dims
var_dimensions: Optional[list[int]] = var.dims

if not var_dimensions:
if var_dimensions is None or len(var_dimensions) == 0:
raise_qasm3_error(
message=f"Indexing error. Variable {var_name} is not an array",
err_type=Qasm3ConversionError,
span=indices[0].span,
)
if len(indices) != len(var_dimensions):
if isinstance(indices, DiscreteSet):
indices = indices.values

if len(indices) != len(var_dimensions): # type: ignore[arg-type]
raise_qasm3_error(
message=f"Invalid number of indices for variable {var_name}. "
f"Expected {len(var_dimensions)} but got {len(indices)}",
f"Expected {len(var_dimensions)} but got {len(indices)}", # type: ignore[arg-type]
err_type=Qasm3ConversionError,
span=indices[0].span,
)
Expand All @@ -78,7 +83,7 @@ def analyze_classical_indices(indices: list[IntegerLiteral], var: Variable) -> N
span=index.span,
)
index_value = index.value
curr_dimension = var_dimensions[i]
curr_dimension = var_dimensions[i] # type: ignore[index]

if index_value < 0 or index_value >= curr_dimension:
raise_qasm3_error(
Expand All @@ -92,29 +97,35 @@ def analyze_classical_indices(indices: list[IntegerLiteral], var: Variable) -> N
return indices_list

@staticmethod
def analyze_index_expression(index_expr: IndexExpression) -> tuple[str, list[list]]:
def analyze_index_expression(
index_expr: IndexExpression,
) -> tuple[str, list[Union[Any, Expression, RangeDefinition]]]:
"""analyze an index expression to get the variable name and indices.

Args:
index_expr (IndexExpression): The index expression to analyze.

Returns:
tuple[str, list[list]]: The variable name and indices.
tuple[str, list[Any]]: The variable name and indices.

"""
indices = []
var_name = None
indices: list[Any] = []
var_name = ""
comma_separated = False

if isinstance(index_expr.collection, IndexExpression):
while isinstance(index_expr, IndexExpression):
indices.append(index_expr.index[0])
index_expr = index_expr.collection
if isinstance(index_expr.index, list):
indices.append(index_expr.index[0])
index_expr = index_expr.collection
else:
comma_separated = True
indices = index_expr.index

var_name = index_expr.collection.name if comma_separated else index_expr.name
indices = index_expr.index # type: ignore[assignment]
var_name = (
index_expr.collection.name # type: ignore[attr-defined]
if comma_separated
else index_expr.name # type: ignore[attr-defined]
)
if not comma_separated:
indices = indices[::-1]

Expand Down Expand Up @@ -169,7 +180,7 @@ def analyse_branch_condition(condition: Any) -> bool:
err_type=Qasm3ConversionError,
span=condition.span,
)
return condition.rhs.value != 0
return condition.rhs.value != 0 # type: ignore[attr-defined]
if not isinstance(condition, IndexExpression):
raise_qasm3_error(
message=(
Expand Down
Loading