Skip to content

Commit

Permalink
array support for QASM3 subroutines (#151)
Browse files Browse the repository at this point in the history
* start array support for subroutine

* add support for range based array assignment, lil buggy but

* complete support for arrays
  • Loading branch information
TheGupta2012 authored Sep 2, 2024
1 parent 1a983e0 commit 1dea538
Show file tree
Hide file tree
Showing 16 changed files with 1,089 additions and 292 deletions.
99 changes: 68 additions & 31 deletions qbraid_qir/qasm3/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
Module with analysis functions for QASM3 visitor
"""
# pylint: disable=cyclic-import, import-outside-toplevel

from typing import Any, Optional, Union

import numpy as np
from openqasm3.ast import (
BinaryExpression,
DiscreteSet,
Expression,
Identifier,
IndexExpression,
IntegerLiteral,
IntType,
RangeDefinition,
UnaryExpression,
)
Expand All @@ -31,27 +36,27 @@
class Qasm3Analyzer:
"""Class with utility functions for analyzing QASM3 elements"""

@staticmethod
def analyze_classical_indices(indices: list[Any], var: Variable) -> list:
@classmethod
def analyze_classical_indices(cls, indices: list[Any], var: Variable) -> list:
"""Validate the indices for a classical variable.
Args:
indices (list[list[Any]]): The indices to validate.
var_name (Variable): The variable to verify
var (Variable): The variable to verify
Raises:
Qasm3ConversionError: If the indices are invalid.
Returns:
list: The list of indices.
list[list]: The list of indices. Note, we can also have a list of indices within
a list if the variable is a multi-dimensional array.
"""
indices_list = []
var_name = var.name
var_dimensions: Optional[list[int]] = var.dims

if var_dimensions is None or len(var_dimensions) == 0:
raise_qasm3_error(
message=f"Indexing error. Variable {var_name} is not an array",
message=f"Indexing error. Variable {var.name} is not an array",
err_type=Qasm3ConversionError,
span=indices[0].span,
)
Expand All @@ -60,53 +65,85 @@ def analyze_classical_indices(indices: list[Any], var: Variable) -> list:

if len(indices) != len(var_dimensions): # type: ignore[arg-type]
raise_qasm3_error(
message=f"Invalid number of indices for variable {var_name}. "
message=f"Invalid number of indices for variable {var.name}. "
f"Expected {len(var_dimensions)} but got {len(indices)}", # type: ignore[arg-type]
err_type=Qasm3ConversionError,
span=indices[0].span,
)

for i, index in enumerate(indices):
if isinstance(index, RangeDefinition):
def _validate_index(index, dimension, var_name, span, dim_num):
if index < 0 or index >= dimension:
raise_qasm3_error(
message=f"Range based indexing {index} not supported for "
f"classical variable {var_name}",
message=f"Index {index} out of bounds for dimension {dim_num} "
f"of variable {var_name}",
err_type=Qasm3ConversionError,
span=index.span,
span=span,
)

if not isinstance(index, IntegerLiteral):
def _validate_step(start_id, end_id, step, span):
if (step < 0 and start_id < end_id) or (step > 0 and start_id > end_id):
direction = "less than" if step < 0 else "greater than"
raise_qasm3_error(
message=f"Unsupported index type {type(index)} for "
f"classical variable {var_name}",
message=f"Index {start_id} is {direction} {end_id} but step"
f" is {'negative' if step < 0 else 'positive'}",
err_type=Qasm3ConversionError,
span=index.span,
span=span,
)
index_value = index.value
curr_dimension = var_dimensions[i] # type: ignore[index]

if index_value < 0 or index_value >= curr_dimension:
from .expressions import Qasm3ExprEvaluator

for i, index in enumerate(indices):
if not isinstance(index, (Identifier, Expression, RangeDefinition, IntegerLiteral)):
raise_qasm3_error(
message=f"Index {index_value} out of bounds for dimension {i+1} "
f"of variable {var_name}",
message=f"Unsupported index type {type(index)} for "
f"classical variable {var.name}",
err_type=Qasm3ConversionError,
span=index.span,
)
indices_list.append(index_value)

if isinstance(index, RangeDefinition):
assert var_dimensions is not None

start_id = 0
if index.start is not None:
start_id = Qasm3ExprEvaluator.evaluate_expression(
index.start, reqd_type=IntType
)

end_id = var_dimensions[i] - 1
if index.end is not None:
end_id = Qasm3ExprEvaluator.evaluate_expression(index.end, reqd_type=IntType)

step = 1
if index.step is not None:
step = Qasm3ExprEvaluator.evaluate_expression(index.step, reqd_type=IntType)

_validate_index(start_id, var_dimensions[i], var.name, index.span, i)
_validate_index(end_id, var_dimensions[i], var.name, index.span, i)
_validate_step(start_id, end_id, step, index.span)

indices_list.append((start_id, end_id, step))

if isinstance(index, (Identifier, IntegerLiteral, Expression)):
index_value = Qasm3ExprEvaluator.evaluate_expression(index, reqd_type=IntType)
curr_dimension = var_dimensions[i] # type: ignore[index]
_validate_index(index_value, curr_dimension, var.name, index.span, i)

indices_list.append((index_value, index_value, 1))

return indices_list

@staticmethod
def analyze_index_expression(
index_expr: IndexExpression,
) -> tuple[str, list[Union[Any, Expression, RangeDefinition]]]:
"""analyze an index expression to get the variable name and indices.
"""Analyze an index expression to get the variable name and indices.
Args:
index_expr (IndexExpression): The index expression to analyze.
Returns:
tuple[str, list[Any]]: The variable name and indices.
tuple[str, list[Any]]: The variable name and indices in openqasm objects
"""
indices: list[Any] = []
Expand All @@ -132,20 +169,20 @@ def analyze_index_expression(
return var_name, indices

@staticmethod
def find_array_element(multi_dim_arr: list[Any], indices: list[int]) -> Any:
def find_array_element(multi_dim_arr: np.ndarray, indices: list[tuple[int, int, int]]) -> Any:
"""Find the value of an array at the specified indices.
Args:
multi_dim_arr (list): The multi-dimensional list to search.
indices (list[int]): The indices to search.
multi_dim_arr (np.ndarray): The multi-dimensional list to search.
indices (list[tuple[int,int,int]]): The indices to search.
Returns:
Any: The value at the specified indices.
"""
temp = multi_dim_arr
for index in indices:
temp = temp[index]
return temp
slicing = tuple(
slice(start, end + 1, step) if start != end else start for start, end, step in indices
)
return multi_dim_arr[slicing] # type: ignore[index]

@staticmethod
def analyse_branch_condition(condition: Any) -> bool:
Expand Down
6 changes: 5 additions & 1 deletion qbraid_qir/qasm3/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from enum import Enum
from typing import Optional, Union

import numpy as np
from openqasm3.ast import BitType, ClassicalDeclaration, Program, QubitDeclaration, Statement
from pyqir import Context as qirContext
from pyqir import Module
Expand Down Expand Up @@ -60,6 +61,7 @@ class Variable:
dims (list[int]): Dimensions of the variable.
value (Optional[Union[int, float, list]]): Value of the variable.
is_constant (bool): Flag indicating if the variable is constant.
readonly(bool): Flag indicating if the variable is readonly.
"""

Expand All @@ -70,15 +72,17 @@ def __init__(
base_type,
base_size: int,
dims: Optional[list[int]] = None,
value: Optional[Union[int, float, list]] = None,
value: Optional[Union[int, float, np.ndarray]] = None,
is_constant: bool = False,
readonly: bool = False,
):
self.name = name
self.base_type = base_type
self.base_size = base_size
self.dims = dims
self.value = value
self.is_constant = is_constant
self.readonly = readonly


class _ProgramElement(metaclass=ABCMeta):
Expand Down
3 changes: 2 additions & 1 deletion qbraid_qir/qasm3/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ def _process_variable(var_name: str, indices=None):
if reqd_type == Qasm3FloatType and isinstance(expression, FloatLiteral):
return expression.value
raise_qasm3_error(
f"Invalid type {type(expression)} for required type {reqd_type}",
f"Invalid value {expression.value} with type {type(expression)} "
f"for required type {reqd_type}",
Qasm3ConversionError,
expression.span,
)
Expand Down
22 changes: 16 additions & 6 deletions qbraid_qir/qasm3/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,12 +670,22 @@ def qasm_variable_type_cast(openqasm_type, var_name, base_size, rhs_value):

# Reference: https://openqasm.com/language/types.html#allowed-casts
VARIABLE_TYPE_CAST_MAP = {
BoolType: [int, float, bool],
IntType: [bool, int, float],
BitType: [bool, int],
UintType: [bool, int, float],
FloatType: [bool, int, float],
AngleType: [float],
BoolType: (int, float, bool, np.int64, np.float64, np.bool_),
IntType: (bool, int, float, np.int64, np.float64, np.bool_),
BitType: (bool, int, np.int64, np.bool_),
UintType: (bool, int, float, np.int64, np.uint64, np.float64, np.bool_),
FloatType: (bool, int, float, np.int64, np.float64, np.bool_),
AngleType: (float, np.float64),
}

ARRAY_TYPE_MAP = {
BitType: np.bool_,
IntType: np.int64,
UintType: np.uint64,
FloatType: np.float64,
ComplexType: np.complex128,
BoolType: np.bool_,
AngleType: np.float64,
}


Expand Down
Loading

0 comments on commit 1dea538

Please sign in to comment.