Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring the Qasm3Visitor #142

Merged
merged 10 commits into from
Aug 22, 2024
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Types of changes:
* Add support for pauli measurement operators in `cirq` converter ( [#144](https://github.com/qBraid/qbraid-qir/pull/144) )

### 🌟 Improvements
* Re-factor the `BasicQasmVisitor` and improve modularity ( [#142](https://github.com/qBraid/qbraid-qir/pull/142) )

### 📜 Documentation
* Housekeeping updates for release ( [#135](https://github.com/qBraid/qbraid-qir/pull/135) )
Expand Down
10 changes: 5 additions & 5 deletions qbraid_qir/cirq/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .elements import CirqModule
from .opsets import map_cirq_op_to_pyqir_callable

_log = logging.getLogger(name=__name__)
logger = logging.getLogger(__name__)


class CircuitElementVisitor(metaclass=ABCMeta):
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(self, initialize_runtime: bool = True, record_output: bool = True):
self._record_output = record_output

def visit_cirq_module(self, module: CirqModule) -> None:
_log.debug("Visiting Cirq module '%s' (%d)", module.name, module.num_qubits)
logger.debug("Visiting Cirq module '%s' (%d)", module.name, module.num_qubits)
self._module = module.module
context = self._module.context
entry = pyqir.entry_point(self._module, module.name, module.num_qubits, module.num_clbits)
Expand Down Expand Up @@ -90,7 +90,7 @@ def record_output(self, module: CirqModule) -> None:
pyqir.rt.result_record_output(self._builder, result_ref, Constant.null(i8p))

def visit_register(self, qids: List[cirq.Qid]) -> None:
_log.debug("Visiting qids '%s'", str(qids))
logger.debug("Visiting qids '%s'", str(qids))

if not isinstance(qids, list):
raise TypeError("Parameter is not a list.")
Expand All @@ -99,15 +99,15 @@ def visit_register(self, qids: List[cirq.Qid]) -> None:
raise TypeError("All elements in the list must be of type cirq.Qid.")

self._qubit_labels.update({bit: n + len(self._qubit_labels) for n, bit in enumerate(qids)})
_log.debug("Added labels for qubits %s", str(qids))
logger.debug("Added labels for qubits %s", str(qids))

def visit_operation(self, operation: cirq.Operation) -> None:
qlabels = [self._qubit_labels.get(bit) for bit in operation.qubits]
qubits = [pyqir.qubit(self._module.context, n) for n in qlabels]
results = [pyqir.result(self._module.context, n) for n in qlabels]

def handle_measurement(pyqir_func):
_log.debug("Visiting measurement operation '%s'", str(operation))
logger.debug("Visiting measurement operation '%s'", str(operation))
for qubit, result in zip(qubits, results):
self._measured_qubits[pyqir.qubit_id(qubit)] = True
pyqir_func(self._builder, qubit, result)
Expand Down
2 changes: 1 addition & 1 deletion qbraid_qir/qasm3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
| BranchingStatement | 🔜 | In progress |
| SubroutineDefinition | 🔜 | In progress |
| Looping statements(eg. for) | 🔜 | In progress |
| RangeDefinition | 🔜 | In progress |
| IODeclaration | 📋 | Planned |
| RangeDefinition | 📋 | Planned |
| Pragma | ❓ | Unsure |
| Annotations | ❓ | Unsure |
| Pulse-level ops (e.g. delay) | ❌ | Not supported by QIR |
Expand Down
182 changes: 182 additions & 0 deletions qbraid_qir/qasm3/analyzer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
# Copyright (C) 2024 qBraid
#
# This file is part of the qBraid-SDK
#
# The qBraid-SDK is free software released under the GNU General Public License v3
# or later. You can redistribute and/or modify it under the terms of the GPL v3.
# See the LICENSE file in the project root or <https://www.gnu.org/licenses/gpl-3.0.html>.
#
# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3.

"""
Module with analysis functions for QASM3 visitor

"""
from typing import Any

from openqasm3.ast import (
BinaryExpression,
IndexExpression,
IntegerLiteral,
RangeDefinition,
UnaryExpression,
)

from .elements import Variable
from .exceptions import Qasm3ConversionError, raise_qasm3_error


class Qasm3Analyzer:
"""Class with utility functions for analyzing QASM3 elements"""

@staticmethod
def analyze_classical_indices(indices: list[IntegerLiteral], var: Variable) -> None:
"""Validate the indices for a classical variable.

Args:
indices (list[list[Any]]): The indices to validate.
var_name (Variable): The variable to verify

Raises:
Qasm3ConversionError: If the indices are invalid.

Returns:
list: The list of indices.
"""
indices_list = []
var_name = var.name
var_dimensions = var.dims

if not var_dimensions:
raise_qasm3_error(
message=f"Indexing error. Variable {var_name} is not an array",
err_type=Qasm3ConversionError,
span=indices[0].span,
)
if len(indices) != len(var_dimensions):
raise_qasm3_error(
message=f"Invalid number of indices for variable {var_name}. "
f"Expected {len(var_dimensions)} but got {len(indices)}",
err_type=Qasm3ConversionError,
span=indices[0].span,
)

for i, index in enumerate(indices):
if isinstance(index, RangeDefinition):
raise_qasm3_error(
message=f"Range based indexing {index} not supported for "
f"classical variable {var_name}",
err_type=Qasm3ConversionError,
span=index.span,
)

if not isinstance(index, IntegerLiteral):
raise_qasm3_error(
message=f"Unsupported index type {type(index)} for "
f"classical variable {var_name}",
err_type=Qasm3ConversionError,
span=index.span,
)
index_value = index.value
curr_dimension = var_dimensions[i]

if index_value < 0 or index_value >= curr_dimension:
raise_qasm3_error(
message=f"Index {index_value} out of bounds for dimension {i+1} "
f"of variable {var_name}",
err_type=Qasm3ConversionError,
span=index.span,
)
indices_list.append(index_value)

return indices_list

@staticmethod
def analyze_index_expression(index_expr: IndexExpression) -> tuple[str, list[list]]:
"""analyze an index expression to get the variable name and indices.

Args:
index_expr (IndexExpression): The index expression to analyze.

Returns:
tuple[str, list[list]]: The variable name and indices.

"""
indices = []
var_name = None
comma_separated = False

if isinstance(index_expr.collection, IndexExpression):
while isinstance(index_expr, IndexExpression):
indices.append(index_expr.index[0])
index_expr = index_expr.collection
else:
comma_separated = True
indices = index_expr.index

var_name = index_expr.collection.name if comma_separated else index_expr.name
if not comma_separated:
indices = indices[::-1]

return var_name, indices

@staticmethod
def find_array_element(multi_dim_arr: list[Any], indices: list[int]) -> Any:
"""Find the value of an array at the specified indices.

Args:
multi_dim_arr (list): The multi-dimensional list to search.
indices (list[int]): The indices to search.

Returns:
Any: The value at the specified indices.
"""
temp = multi_dim_arr
for index in indices:
temp = temp[index]
return temp

@staticmethod
def analyse_branch_condition(condition: Any) -> bool:
"""
analyze the branching condition to determine the branch to take

Args:
condition (Any): The condition to analyze

Returns:
bool: The branch to take
"""

if isinstance(condition, UnaryExpression):
if condition.op.name != "!":
raise_qasm3_error(
message=f"Unsupported unary expression '{condition.op.name}' in if condition",
err_type=Qasm3ConversionError,
span=condition.span,
)
return False
if isinstance(condition, BinaryExpression):
if condition.op.name != "==":
raise_qasm3_error(
message=f"Unsupported binary expression '{condition.op.name}' in if condition",
err_type=Qasm3ConversionError,
span=condition.span,
)
if not isinstance(condition.lhs, IndexExpression):
raise_qasm3_error(
message=f"Unsupported expression type '{type(condition.rhs)}' in if condition",
err_type=Qasm3ConversionError,
span=condition.span,
)
return condition.rhs.value != 0
if not isinstance(condition, IndexExpression):
raise_qasm3_error(
message=(
f"Unsupported expression type '{type(condition)}' in if condition. "
"Can only be a simple comparison"
),
err_type=Qasm3ConversionError,
span=condition.span,
)
TheGupta2012 marked this conversation as resolved.
Show resolved Hide resolved
return True
32 changes: 32 additions & 0 deletions qbraid_qir/qasm3/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,40 @@
Module defining exceptions for errors raised during QASM3 conversions.

"""
import logging
from typing import Optional, Type

from openqasm3.ast import Span

from qbraid_qir.exceptions import QirConversionError


class Qasm3ConversionError(QirConversionError):
"""Class for errors raised when converting an OpenQASM 3 program to QIR."""


def raise_qasm3_error(
message: Optional[str] = None,
err_type: Type[Exception] = Qasm3ConversionError,
span: Optional[Span] = None,
raised_from: Optional[Exception] = None,
) -> None:
"""Raises a QASM3 conversion error with optional chaining from another exception.

Args:
message: The error message. If not provided, a default message will be used.
err_type: The type of error to raise.
span: The span (location) in the QASM file where the error occurred.
raised_from: Optional exception from which this error was raised (chaining).

Raises:
err_type: The error type initialized with the specified message and chained exception.
"""
if span:
logging.error(
"Error at line %s, column %s in QASM file", span.start_line, span.start_column
)

if raised_from:
raise err_type(message) from raised_from
raise err_type(message)
Loading