Skip to content

Commit

Permalink
complete static types
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Aug 28, 2024
1 parent 9899fd7 commit 7b4b142
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 59 deletions.
4 changes: 0 additions & 4 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,3 @@ follow_imports = normal

# Enable cache
cache_dir = .mypy_cache

# TODO: fix typing for cirq
[mypy-qbraid_qir.cirq.*]
ignore_errors = True
6 changes: 3 additions & 3 deletions qbraid_qir/cirq/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def cirq_to_qir(circuit: cirq.Circuit, name: Optional[str] = None, **kwargs) ->
visitor = BasicCirqVisitor(**kwargs)
module.accept(visitor)

err = llvm_module.verify()
if err is not None:
raise CirqConversionError(err)
error = llvm_module.verify()
if error is not None:
raise CirqConversionError(error)

Check warning on line 69 in qbraid_qir/cirq/convert.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/cirq/convert.py#L69

Added line #L69 was not covered by tests
return llvm_module
10 changes: 5 additions & 5 deletions qbraid_qir/cirq/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import hashlib
from abc import ABCMeta, abstractmethod
from typing import FrozenSet, List, Optional
from typing import Optional

import cirq
from pyqir import Context, Module
Expand Down Expand Up @@ -50,7 +50,7 @@ def accept(self, visitor):


class _Register(_CircuitElement):
def __init__(self, register: FrozenSet[cirq.Qid]):
def __init__(self, register: list[cirq.Qid]):
self._register = register

def accept(self, visitor):
Expand All @@ -77,7 +77,7 @@ class CirqModule:
name (str): Name of the module.
module (Module): QIR Module instance.
num_qubits (int): Number of qubits in the circuit.
elements (List[_CircuitElement]): List of circuit elements.
elements (list[_CircuitElement]): list of circuit elements.
Example:
>>> circuit = cirq.Circuit()
Expand All @@ -90,7 +90,7 @@ def __init__(
name: str,
module: Module,
num_qubits: int,
elements: List[_CircuitElement],
elements: list[_CircuitElement],
):
self._name = name
self._module = module
Expand Down Expand Up @@ -122,7 +122,7 @@ def num_clbits(self) -> int:
def from_circuit(cls, circuit: cirq.Circuit, module: Optional[Module] = None) -> "CirqModule":
"""Class method. Constructs a CirqModule from a given cirq.Circuit object
and an optional QIR Module."""
elements: List[_CircuitElement] = []
elements: list[_CircuitElement] = []

# Register(s). Tentatively using cirq.Qid as input. Better approaches might exist tbd.
elements.append(_Register(list(circuit.all_qubits())))
Expand Down
8 changes: 4 additions & 4 deletions qbraid_qir/cirq/opsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Module mapping supported Cirq gates/operations to pyqir functions.
"""
from typing import Callable, Tuple
from typing import Callable

import cirq
import pyqir._native
Expand All @@ -38,7 +38,7 @@ def measure_y(builder, qubit, result):
pyqir._native.mz(builder, qubit, result)


PYQIR_OP_MAP = {
PYQIR_OP_MAP: dict[str, Callable] = {
# Identity Gate
"I": i,
# Single-Qubit Clifford Gates
Expand Down Expand Up @@ -71,15 +71,15 @@ def measure_y(builder, qubit, result):

def map_cirq_op_to_pyqir_callable(
operation: cirq.Operation,
) -> Tuple[Callable, str]:
) -> tuple[Callable, str]:
"""
Maps a Cirq operation to its corresponding PyQIR callable function.
Args:
operation (cirq.Operation): The Cirq operation to map.
Returns:
Tuple[Callable, str]: Tuple containing the corresponding PyQIR callable function,
tuple[Callable, str]: tuple containing the corresponding PyQIR callable function,
and a string representing the gate/operation type.
Raises:
Expand Down
10 changes: 5 additions & 5 deletions qbraid_qir/cirq/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,23 @@
"""
import itertools
from typing import List
from typing import Iterable

import cirq

from .exceptions import CirqConversionError
from .opsets import map_cirq_op_to_pyqir_callable


def _decompose_gate_op(operation: cirq.GateOperation) -> List[cirq.OP_TREE]:
def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]:
"""Decomposes a single Cirq gate operation into a sequence of operations
that are directly supported by PyQIR.
Args:
operation (cirq.GateOperation): The gate operation to decompose.
operation (cirq.Operation): The gate operation to decompose.
Returns:
List[cirq.OP_TREE]: A list of decomposed gate operations.
Iterable[cirq.OP_TREE]: A list of decomposed gate operations.
"""
try:
# Try converting to PyQIR. If successful, keep the operation.
Expand Down Expand Up @@ -58,7 +58,7 @@ def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit:
new_ops = []
for operation in moment:
if isinstance(operation, cirq.GateOperation):
decomposed_ops = _decompose_gate_op(operation)
decomposed_ops = list(_decompose_gate_op(operation))
new_ops.extend(decomposed_ops)
elif isinstance(operation, cirq.ClassicallyControlledOperation):
new_ops.append(operation)
Expand Down
27 changes: 13 additions & 14 deletions qbraid_qir/cirq/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""
import logging
from abc import ABCMeta, abstractmethod
from typing import List

import cirq
import pyqir
Expand Down Expand Up @@ -49,11 +48,11 @@ class BasicCirqVisitor(CircuitElementVisitor):
"""

def __init__(self, initialize_runtime: bool = True, record_output: bool = True):
self._module = None
self._builder = None
self._entry_point = None
self._qubit_labels = {}
self._measured_qubits = {}
self._module: pyqir.Module
self._builder: pyqir.Builder
self._entry_point: str
self._qubit_labels: dict[cirq.Qid, int] = {}
self._measured_qubits: dict = {}
self._initialize_runtime = initialize_runtime
self._record_output = record_output

Expand Down Expand Up @@ -89,20 +88,18 @@ def record_output(self, module: CirqModule) -> None:
result_ref = pyqir.result(self._module.context, i)
pyqir.rt.result_record_output(self._builder, result_ref, Constant.null(i8p))

def visit_register(self, qids: List[cirq.Qid]) -> None:
def visit_register(self, qids: list[cirq.Qid]) -> None:
logger.debug("Visiting qids '%s'", str(qids))

if not isinstance(qids, list):
raise TypeError("Parameter is not a list.")

if not all(isinstance(x, cirq.Qid) for x in qids):
raise TypeError("All elements in the list must be of type cirq.Qid.")

self._qubit_labels.update({bit: n + len(self._qubit_labels) for n, bit in enumerate(qids)})
logger.debug("Added labels for qubits %s", str(qids))

def visit_operation(self, operation: cirq.Operation) -> None:
qlabels = [self._qubit_labels.get(bit) for bit in operation.qubits]
qlabels = [self._qubit_labels[bit] for bit in operation.qubits]

qubits = [pyqir.qubit(self._module.context, n) for n in qlabels]
results = [pyqir.result(self._module.context, n) for n in qlabels]

Expand All @@ -125,7 +122,9 @@ def handle_measurement(pyqir_func):
# pylint: disable=unnecessary-lambda-assignment
if op_str in ["Rx", "Ry", "Rz"]:
pyqir_func = lambda: temp_pyqir_func(
self._builder, operation._sub_operation.gate._rads, *qubits
self._builder,
operation._sub_operation.gate._rads, # type: ignore[union-attr]
*qubits,
)
else:
pyqir_func = lambda: temp_pyqir_func(self._builder, *qubits)
Expand All @@ -149,12 +148,12 @@ def _branch(conds, pyqir_func):
if op_str.startswith("measure"):
handle_measurement(pyqir_func)
elif op_str in ["Rx", "Ry", "Rz"]:
pyqir_func(self._builder, operation.gate._rads, *qubits)
pyqir_func(self._builder, operation.gate._rads, *qubits) # type: ignore[union-attr]
else:
pyqir_func(self._builder, *qubits)

def ir(self) -> str:
return str(self._module)

def bitcode(self) -> bytes:
return self._module.bitcode()
return self._module.bitcode

Check warning on line 159 in qbraid_qir/cirq/visitor.py

View check run for this annotation

Codecov / codecov/patch

qbraid_qir/cirq/visitor.py#L159

Added line #L159 was not covered by tests
4 changes: 2 additions & 2 deletions qbraid_qir/qasm3/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Variable:
name (str): Name of the variable.
base_type (Any): Base type of the variable.
base_size (int): Base size of the variable.
dims (List[int]): Dimensions of the variable.
dims (list[int]): Dimensions of the variable.
value (Optional[Union[int, float, list]]): Value of the variable.
is_constant (bool): Flag indicating if the variable is constant.
Expand Down Expand Up @@ -122,7 +122,7 @@ class Qasm3Module:
module (Module): QIR Module instance.
num_qubits (int): Number of qubits in the circuit.
num_clbits (int): Number of classical bits in the circuit.
elements (List[Statement]): List of openqasm3 Statements.
elements (list[Statement]): list of openqasm3 Statements.
"""

# pylint: disable-next=too-many-arguments
Expand Down
2 changes: 1 addition & 1 deletion qbraid_qir/qasm3/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def transform_gate_params(gate_op: QuantumGate, param_map: dict[str, Expression]
Args:
gate_op (QuantumGate): The gate operation to transform.
param_map (Dict[str, Expression]): The parameter map to use for transformation.
param_map (dict[str, Expression]): The parameter map to use for transformation.
Returns:
None
Expand Down
Loading

0 comments on commit 7b4b142

Please sign in to comment.