diff --git a/qbraid_qir/qasm3/utils/expressions.py b/qbraid_qir/qasm3/utils/expressions.py new file mode 100644 index 0000000..f4ef8a5 --- /dev/null +++ b/qbraid_qir/qasm3/utils/expressions.py @@ -0,0 +1,232 @@ +# Copyright (C) 2024 qBraid +# +# This file is part of the qBraid-SDK +# +# The qBraid-SDK is free software released under the GNU General Public License v3 +# or later. You can redistribute and/or modify it under the terms of the GPL v3. +# See the LICENSE file in the project root or . +# +# THERE IS NO WARRANTY for the qBraid-SDK, as per Section 15 of the GPL v3. + +""" +Module containing the class for evaluating QASM3 expressions. + +""" + +from openqasm3.ast import ( + BinaryExpression, + BooleanLiteral, + BoolType, + DurationLiteral, + FloatLiteral, + FunctionCall, + Identifier, + ImaginaryLiteral, + IndexExpression, + IntegerLiteral, + UnaryExpression, +) + +from ..exceptions import Qasm3ConversionError +from .imports import Qasm3FloatType, Qasm3IntType +from .maps import CONSTANTS_MAP, qasm3_expression_op_map +from .visitor_utils import Qasm3VisitorUtils + + +class Qasm3ExprEvaluator: + """Class for evaluating QASM3 expressions.""" + + @staticmethod + def _check_var_in_scope(visitor_obj, var_name, expression): + """ + Checks if a variable is in scope. + Args: + visitor_obj: The visitor object. + var_name: The name of the variable to check. + expression: The expression containing the variable. + Raises: + Qasm3ConversionError: If the variable is undefined in the current scope. + """ + + if not visitor_obj._check_in_scope(var_name, visitor_obj._get_curr_scope()): + Qasm3VisitorUtils.print_err_location(expression.span) + raise Qasm3ConversionError(f"Undefined identifier {var_name} in expression") + + @staticmethod + def _check_var_constant(visitor_obj, var_name, const_expr, expression): + """ + Checks if a variable is constant. + + Args: + visitor_obj: The visitor object. + var_name: The name of the variable to check. + const_expr: Whether the expression is a constant. + expression: The expression containing the variable. + + Raises: + Qasm3ConversionError: If the variable is not a constant in the given + expression. + """ + const_var = visitor_obj._get_from_visible_scope(var_name).is_constant + if const_expr and not const_var: + Qasm3VisitorUtils.print_err_location(expression.span) + raise Qasm3ConversionError( + f"Variable '{var_name}' is not a constant in given expression" + ) + + @staticmethod + def _check_var_type(visitor_obj, var_name, reqd_type, expression): + """ + Check the type of a variable and raise an error if it does not match the + required type. + + Args: + visitor_obj: The visitor object. + var_name: The name of the variable to check. + reqd_type: The required type of the variable. + expression: The expression where the variable is used. + + Raises: + Qasm3ConversionError: If the variable has an invalid type for the required type. + """ + + if not Qasm3VisitorUtils.validate_variable_type( + visitor_obj._get_from_visible_scope(var_name), reqd_type + ): + Qasm3VisitorUtils.print_err_location(expression.span) + raise Qasm3ConversionError( + f"Invalid type of variable {var_name} for required type {reqd_type}" + ) + + @staticmethod + def _check_var_initialized(var_name, var_value, expression): + """ + Checks if a variable is initialized and raises an error if it is not. + Args: + var_name (str): The name of the variable. + var_value: The value of the variable. + expression: The expression where the variable is used. + Raises: + Qasm3ConversionError: If the variable is uninitialized. + """ + + if var_value is None: + Qasm3VisitorUtils.print_err_location(expression.span) + raise Qasm3ConversionError(f"Uninitialized variable {var_name} in expression") + + @staticmethod + def _get_var_value(visitor_obj, var_name, indices, expression): + """ + Retrieves the value of a variable. + Args: + visitor_obj (Visitor): The visitor object. + var_name (str): The name of the variable. + indices (list): The indices of the variable (if it is an array). + expression (Identifier or Expression): The expression representing the variable. + Returns: + var_value: The value of the variable. + """ + + var_value = None + if isinstance(expression, Identifier): + var_value = visitor_obj._get_from_visible_scope(var_name).value + else: + validated_indices = Qasm3VisitorUtils.analyse_classical_indices( + indices, visitor_obj._get_from_visible_scope(var_name) + ) + var_value = Qasm3VisitorUtils.find_array_element( + visitor_obj._get_from_visible_scope(var_name).value, validated_indices + ) + return var_value + + # pylint: disable-next=too-many-return-statements, too-many-statements + @staticmethod + def evaluate_expression(visitor_obj, expression, const_expr: bool = False, reqd_type=None): + """Evaluate an expression. Scalar types are assigned by value. + + + Args: + expression (Any): The expression to evaluate. + const_expr (bool): Whether the expression is a constant. Defaults to False. + reqd_type (Any): The required type of the expression. Defaults to None. + + Returns: + Any : The result of the evaluation. + + Raises: + Qasm3ConversionError: If the expression is not supported. + """ + if expression is None: + return None + + if isinstance(expression, (ImaginaryLiteral, DurationLiteral)): + Qasm3VisitorUtils.print_err_location(expression.span) + raise Qasm3ConversionError(f"Unsupported expression type {type(expression)}") + + def _process_variable(var_name, indices=None): + Qasm3ExprEvaluator._check_var_in_scope(visitor_obj, var_name, expression) + Qasm3ExprEvaluator._check_var_constant(visitor_obj, var_name, const_expr, expression) + Qasm3ExprEvaluator._check_var_type(visitor_obj, var_name, reqd_type, expression) + var_value = Qasm3ExprEvaluator._get_var_value( + visitor_obj, var_name, indices, expression + ) + Qasm3ExprEvaluator._check_var_initialized(var_name, var_value, expression) + return var_value + + if isinstance(expression, Identifier): + var_name = expression.name + if var_name in CONSTANTS_MAP: + if not reqd_type or reqd_type == Qasm3FloatType: + return CONSTANTS_MAP[var_name] + Qasm3VisitorUtils.print_err_location(expression.span) + raise Qasm3ConversionError( + f"Constant {var_name} not allowed in non-float expression" + ) + return _process_variable(var_name) + + if isinstance(expression, IndexExpression): + var_name, indices = Qasm3VisitorUtils.analyse_index_expression(expression) + return _process_variable(var_name, indices) + + if isinstance(expression, (BooleanLiteral, IntegerLiteral, FloatLiteral)): + if reqd_type: + if reqd_type == BoolType and isinstance(expression, BooleanLiteral): + return expression.value + if reqd_type == Qasm3IntType and isinstance(expression, IntegerLiteral): + return expression.value + if reqd_type == Qasm3FloatType and isinstance(expression, FloatLiteral): + return expression.value + Qasm3VisitorUtils.print_err_location(expression.span) + raise Qasm3ConversionError( + f"Invalid type {type(expression)} for required type {reqd_type}" + ) + return expression.value + + if isinstance(expression, UnaryExpression): + operand = Qasm3ExprEvaluator.evaluate_expression( + visitor_obj, expression.expression, const_expr, reqd_type + ) + if expression.op.name == "~" and not isinstance(operand, int): + Qasm3VisitorUtils.print_err_location(expression.span) + raise Qasm3ConversionError( + f"Unsupported expression type {type(operand)} in ~ operation" + ) + return qasm3_expression_op_map( + "UMINUS" if expression.op.name == "-" else expression.op.name, operand + ) + if isinstance(expression, BinaryExpression): + lhs = Qasm3ExprEvaluator.evaluate_expression( + visitor_obj, expression.lhs, const_expr, reqd_type + ) + rhs = Qasm3ExprEvaluator.evaluate_expression( + visitor_obj, expression.rhs, const_expr, reqd_type + ) + return qasm3_expression_op_map(expression.op.name, lhs, rhs) + + if isinstance(expression, FunctionCall): + # function will not return a reqd / const type + # Reference : https://openqasm.com/language/types.html#compile-time-constants + # para : 5 + return visitor_obj._visit_function_call(expression) + + Qasm3VisitorUtils.print_err_location(expression.span) + raise Qasm3ConversionError(f"Unsupported expression type {type(expression)}") diff --git a/qbraid_qir/qasm3/utils/imports.py b/qbraid_qir/qasm3/utils/imports.py index 2190cba..f3acafa 100644 --- a/qbraid_qir/qasm3/utils/imports.py +++ b/qbraid_qir/qasm3/utils/imports.py @@ -17,8 +17,6 @@ AliasStatement, ArrayLiteral, ArrayType, - BinaryExpression, - BooleanLiteral, BoolType, BranchingStatement, ClassicalArgument, @@ -26,9 +24,7 @@ ClassicalDeclaration, ConstantDeclaration, DiscreteSet, - DurationLiteral, ExpressionStatement, - FloatLiteral, ) from openqasm3.ast import FloatType as Qasm3FloatType from openqasm3.ast import ( @@ -36,7 +32,6 @@ FunctionCall, GateModifierName, Identifier, - ImaginaryLiteral, Include, IndexedIdentifier, IndexExpression, @@ -57,7 +52,6 @@ Statement, SubroutineDefinition, SwitchStatement, - UnaryExpression, WhileLoop, ) @@ -65,8 +59,6 @@ "AliasStatement", "ArrayLiteral", "ArrayType", - "BinaryExpression", - "BooleanLiteral", "BoolType", "BranchingStatement", "ClassicalArgument", @@ -74,15 +66,12 @@ "ClassicalDeclaration", "ConstantDeclaration", "DiscreteSet", - "DurationLiteral", "ExpressionStatement", - "FloatLiteral", "Qasm3FloatType", "ForInLoop", "FunctionCall", "GateModifierName", "Identifier", - "ImaginaryLiteral", "Include", "IndexedIdentifier", "IndexExpression", @@ -101,6 +90,5 @@ "Statement", "SubroutineDefinition", "SwitchStatement", - "UnaryExpression", "WhileLoop", ] diff --git a/qbraid_qir/qasm3/visitor.py b/qbraid_qir/qasm3/visitor.py index 4cdc183..17ab9e2 100644 --- a/qbraid_qir/qasm3/visitor.py +++ b/qbraid_qir/qasm3/visitor.py @@ -29,6 +29,7 @@ from .elements import Context, InversionOp, Qasm3Module, Variable from .exceptions import Qasm3ConversionError +from .utils.expressions import Qasm3ExprEvaluator from .utils.imports import * from .utils.maps import ( CONSTANTS_MAP, @@ -36,7 +37,6 @@ SWITCH_BLACKLIST_STMTS, map_qasm_inv_op_to_pyqir_callable, map_qasm_op_to_pyqir_callable, - qasm3_expression_op_map, ) from .utils.visitor_utils import Qasm3VisitorUtils @@ -239,24 +239,6 @@ def _add_var_in_scope(self, variable: Variable) -> None: raise ValueError(f"Variable '{variable.name}' already exists in current scope") curr_scope[variable.name] = variable - def _delete_var_from_scope(self, var_name: str) -> None: - """ - Deletes a variable from the current scope. - - Args: - var_name (str): The name of the variable to be deleted. - - Raises: - ValueError: If the variable is not found in the current scope. - - Returns: - None - """ - curr_scope = self._get_curr_scope() - if var_name not in curr_scope: - raise ValueError(f"Variable '{var_name}' not found in current scope") - del curr_scope[var_name] - def _update_var_in_scope(self, variable: Variable) -> None: """ Updates the variable in the current scope. @@ -369,9 +351,21 @@ def _get_qubits_from_range_definition( Returns: list[int]: The list of qubit identifiers. """ - start_qid = 0 if range_def.start is None else self._evaluate_expression(range_def.start) - end_qid = qreg_size if range_def.end is None else self._evaluate_expression(range_def.end) - step = 1 if range_def.step is None else self._evaluate_expression(range_def.step) + start_qid = ( + 0 + if range_def.start is None + else Qasm3ExprEvaluator.evaluate_expression(self, range_def.start) + ) + end_qid = ( + qreg_size + if range_def.end is None + else Qasm3ExprEvaluator.evaluate_expression(self, range_def.end) + ) + step = ( + 1 + if range_def.step is None + else Qasm3ExprEvaluator.evaluate_expression(self, range_def.step) + ) Qasm3VisitorUtils.validate_register_index(start_qid, qreg_size, qubit=is_qubit_reg) Qasm3VisitorUtils.validate_register_index(end_qid - 1, qreg_size, qubit=is_qubit_reg) return list(range(start_qid, end_qid, step)) @@ -423,7 +417,7 @@ def _get_op_qubits(self, operation, qreg_size_map, qir_form: bool = True) -> lis qubit.indices[0][0], qreg_size, is_qubit_reg=True ) else: - qid = self._evaluate_expression(qubit.indices[0][0]) + qid = Qasm3ExprEvaluator.evaluate_expression(self, qubit.indices[0][0]) Qasm3VisitorUtils.validate_register_index(qid, qreg_size, qubit=True) qids = [qid] openqasm_qubits.extend( @@ -586,7 +580,7 @@ def _get_op_parameters(self, operation: QuantumGate) -> list[float]: """ param_list = [] for param in operation.arguments: - param_value = self._evaluate_expression(param) + param_value = Qasm3ExprEvaluator.evaluate_expression(self, param) param_list.append(param_value) return param_list @@ -744,7 +738,7 @@ def _collapse_gate_modifiers(self, operation: QuantumGate) -> tuple: for modifier in operation.modifiers: modifier_name = modifier.modifier if modifier_name == GateModifierName.pow and modifier.argument is not None: - current_power = self._evaluate_expression(modifier.argument) + current_power = Qasm3ExprEvaluator.evaluate_expression(self, modifier.argument) if current_power < 0: inverse_value = not inverse_value power_value = power_value * abs(current_power) @@ -809,7 +803,9 @@ def _visit_constant_declaration(self, statement: ConstantDeclaration) -> None: Qasm3VisitorUtils.print_err_location(statement.span) raise Qasm3ConversionError(f"Re-declaration of variable {var_name}") - init_value = self._evaluate_expression(statement.init_expression, const_expr=True) + init_value = Qasm3ExprEvaluator.evaluate_expression( + self, statement.init_expression, const_expr=True + ) base_type = statement.type if isinstance(base_type, BoolType): @@ -817,7 +813,9 @@ def _visit_constant_declaration(self, statement: ConstantDeclaration) -> None: elif base_type.size is None: base_size = 32 # default for now else: - base_size = self._evaluate_expression(base_type.size, const_expr=True) + base_size = Qasm3ExprEvaluator.evaluate_expression( + self, base_type.size, const_expr=True + ) if not isinstance(base_size, int) or base_size <= 0: Qasm3VisitorUtils.print_err_location(statement.span) raise Qasm3ConversionError(f"Invalid base size {base_size} for variable {var_name}") @@ -873,7 +871,7 @@ def _visit_classical_declaration(self, statement: ClassicalDeclaration) -> None: base_type = base_type.base_type num_elements = 1 for dim in dimensions: - dim_value = self._evaluate_expression(dim) + dim_value = Qasm3ExprEvaluator.evaluate_expression(self, dim) if not isinstance(dim_value, int) or dim_value <= 0: Qasm3VisitorUtils.print_err_location(statement.span) raise Qasm3ConversionError( @@ -892,10 +890,14 @@ def _visit_classical_declaration(self, statement: ClassicalDeclaration) -> None: statement.init_expression, final_dimensions, base_type ) else: - init_value = self._evaluate_expression(statement.init_expression) + init_value = Qasm3ExprEvaluator.evaluate_expression(self, statement.init_expression) base_size = 1 if not isinstance(base_type, BoolType): - base_size = 32 if base_type.size is None else self._evaluate_expression(base_type.size) + base_size = ( + 32 + if base_type.size is None + else Qasm3ExprEvaluator.evaluate_expression(self, base_type.size) + ) if not isinstance(base_size, int) or base_size <= 0: Qasm3VisitorUtils.print_err_location(statement.span) @@ -946,7 +948,7 @@ def _visit_classical_assignment(self, statement: ClassicalAssignment) -> None: Qasm3VisitorUtils.print_err_location(statement.span) raise Qasm3ConversionError(f"Assignment to constant variable {var_name} not allowed") - var_value = self._evaluate_expression(statement.rvalue) + var_value = Qasm3ExprEvaluator.evaluate_expression(self, statement.rvalue) # currently we support single array assignment only # range based assignment not supported yet @@ -992,134 +994,11 @@ def _evaluate_array_initialization( self._evaluate_array_initialization(value, dimensions[1:], base_type) ) else: - eval_value = self._evaluate_expression(value) + eval_value = Qasm3ExprEvaluator.evaluate_expression(self, value) init_values.append(eval_value) return init_values - # pylint: disable-next=too-many-return-statements, too-many-statements - def _evaluate_expression(self, expression, const_expr: bool = False, reqd_type=None): - """Evaluate an expression. Scalar types are assigned by value. - - Args: - expression (Any): The expression to evaluate. - const_expr (bool): Whether the expression is a constant. Defaults to False. - reqd_type (Any): The required type of the expression. Defaults to None. - - Returns: - bool: The result of the evaluation. - - Raises: - Qasm3ConversionError: If the expression is not supported. - """ - if expression is None: - return None - - if isinstance(expression, (ImaginaryLiteral, DurationLiteral)): - Qasm3VisitorUtils.print_err_location(expression.span) - raise Qasm3ConversionError(f"Unsupported expression type {type(expression)}") - - def _check_var_in_scope(var_name): - if not self._check_in_scope(var_name, self._get_curr_scope()): - Qasm3VisitorUtils.print_err_location(expression.span) - raise Qasm3ConversionError(f"Undefined identifier {var_name} in expression") - - def _check_var_constant(var_name): - const_var = self._get_from_visible_scope(var_name).is_constant - if const_expr and not const_var: - Qasm3VisitorUtils.print_err_location(expression.span) - raise Qasm3ConversionError( - f"Variable '{var_name}' is not a constant in given expression" - ) - - def _check_var_type(var_name, reqd_type): - if not Qasm3VisitorUtils.validate_variable_type( - self._get_from_visible_scope(var_name), reqd_type - ): - Qasm3VisitorUtils.print_err_location(expression.span) - raise Qasm3ConversionError( - f"Invalid type of variable {var_name} for required type {reqd_type}" - ) - - def _check_var_initialized(var_name, var_value): - if var_value is None: - Qasm3VisitorUtils.print_err_location(expression.span) - raise Qasm3ConversionError(f"Uninitialized variable {var_name} in expression") - - def _get_var_value(var_name, indices=None): - var_value = None - if isinstance(expression, Identifier): - var_value = self._get_from_visible_scope(var_name).value - else: - validated_indices = Qasm3VisitorUtils.analyse_classical_indices( - indices, self._get_from_visible_scope(var_name) - ) - var_value = Qasm3VisitorUtils.find_array_element( - self._get_from_visible_scope(var_name).value, validated_indices - ) - return var_value - - def process_variable(var_name, indices=None): - _check_var_in_scope(var_name) - _check_var_constant(var_name) - _check_var_type(var_name, reqd_type) - var_value = _get_var_value(var_name, indices) - _check_var_initialized(var_name, var_value) - return var_value - - if isinstance(expression, Identifier): - var_name = expression.name - if var_name in CONSTANTS_MAP: - if not reqd_type or reqd_type == Qasm3FloatType: - return CONSTANTS_MAP[var_name] - Qasm3VisitorUtils.print_err_location(expression.span) - raise Qasm3ConversionError( - f"Constant {var_name} not allowed in non-float expression" - ) - return process_variable(var_name) - - if isinstance(expression, IndexExpression): - var_name, indices = Qasm3VisitorUtils.analyse_index_expression(expression) - return process_variable(var_name, indices) - - if isinstance(expression, (BooleanLiteral, IntegerLiteral, FloatLiteral)): - if reqd_type: - if reqd_type == BoolType and isinstance(expression, BooleanLiteral): - return expression.value - if reqd_type == Qasm3IntType and isinstance(expression, IntegerLiteral): - return expression.value - if reqd_type == Qasm3FloatType and isinstance(expression, FloatLiteral): - return expression.value - Qasm3VisitorUtils.print_err_location(expression.span) - raise Qasm3ConversionError( - f"Invalid type {type(expression)} for required type {reqd_type}" - ) - return expression.value - - if isinstance(expression, UnaryExpression): - operand = self._evaluate_expression(expression.expression, const_expr, reqd_type) - if expression.op.name == "~" and not isinstance(operand, int): - Qasm3VisitorUtils.print_err_location(expression.span) - raise Qasm3ConversionError( - f"Unsupported expression type {type(operand)} in ~ operation" - ) - return qasm3_expression_op_map( - "UMINUS" if expression.op.name == "-" else expression.op.name, operand - ) - if isinstance(expression, BinaryExpression): - lhs = self._evaluate_expression(expression.lhs, const_expr, reqd_type) - rhs = self._evaluate_expression(expression.rhs, const_expr, reqd_type) - return qasm3_expression_op_map(expression.op.name, lhs, rhs) - - if isinstance(expression, FunctionCall): - # function will not return a reqd / const type - # Reference : https://openqasm.com/language/types.html#compile-time-constants - # para : 5 - return self._visit_function_call(expression) - - Qasm3VisitorUtils.print_err_location(expression.span) - raise Qasm3ConversionError(f"Unsupported expression type {type(expression)}") - def _visit_branching_statement(self, statement: BranchingStatement) -> None: """Visit a branching statement element. @@ -1176,14 +1055,21 @@ def _visit_forin_loop(self, statement: ForInLoop) -> None: # Compute loop variable values if isinstance(statement.set_declaration, RangeDefinition): init_exp = statement.set_declaration.start - startval = self._evaluate_expression(init_exp) + startval = Qasm3ExprEvaluator.evaluate_expression(self, init_exp) range_def = statement.set_declaration - stepval = 1 if range_def.step is None else self._evaluate_expression(range_def.step) - endval = self._evaluate_expression(range_def.end) + stepval = ( + 1 + if range_def.step is None + else Qasm3ExprEvaluator.evaluate_expression(self, range_def.step) + ) + endval = Qasm3ExprEvaluator.evaluate_expression(self, range_def.end) irange = list(range(startval, endval + stepval, stepval)) elif isinstance(statement.set_declaration, DiscreteSet): init_exp = statement.set_declaration.values[0] - irange = [self._evaluate_expression(exp) for exp in statement.set_declaration.values] + irange = [ + Qasm3ExprEvaluator.evaluate_expression(self, exp) + for exp in statement.set_declaration.values + ] else: raise Qasm3ConversionError( f"Unexpected type {type(statement.set_declaration)} of set_declaration in loop." @@ -1301,7 +1187,7 @@ def _get_target_qubits(self, target, qreg_size_map, target_name): ) target_qubits_size = len(target_qids) elif isinstance(target.index[0], (IntegerLiteral, Identifier)): # "(q[0]); OR (q[i]);" - target_qids = [self._evaluate_expression(target.index[0])] + target_qids = [Qasm3ExprEvaluator.evaluate_expression(self, target.index[0])] Qasm3VisitorUtils.validate_register_index( target_qids[0], qreg_size_map[target_name], qubit=True ) @@ -1366,7 +1252,6 @@ def _validate_unique_qubits(reg_name, indices): f"Duplicate qubit argument '{reg_name}[{idx}]' " f"in function call for '{fn_name}'" ) - duplicate_qubit_detect_map[reg_name].add(idx) def _process_classical_arg(formal_arg, actual_arg, actual_arg_name): """ @@ -1404,14 +1289,14 @@ def _process_classical_arg(formal_arg, actual_arg, actual_arg_name): # NOTE: actual_argument can also be an EXPRESSION # Better to just evaluate that expression and assign that value later to # the formal argument - actual_arg_value = self._evaluate_expression(actual_arg) + actual_arg_value = Qasm3ExprEvaluator.evaluate_expression(self, actual_arg) # save this value to be updated later in scope classical_vars.append( Variable( formal_arg.name.name, formal_arg.type, - self._evaluate_expression(formal_arg.type.size), + Qasm3ExprEvaluator.evaluate_expression(self, formal_arg.type.size), None, actual_arg_value, False, @@ -1436,8 +1321,8 @@ def _process_quantum_arg(formal_arg, actual_arg, formal_reg_name, actual_arg_nam if the actual argument is not a qubit register. """ - formal_qubit_size = self._evaluate_expression( - formal_arg.size, reqd_type=Qasm3IntType, const_expr=True + formal_qubit_size = Qasm3ExprEvaluator.evaluate_expression( + self, formal_arg.size, reqd_type=Qasm3IntType, const_expr=True ) if formal_qubit_size is None: formal_qubit_size = 1 @@ -1506,7 +1391,7 @@ def _process_quantum_arg(formal_arg, actual_arg, formal_reg_name, actual_arg_nam break self.visit_statement(copy.deepcopy(function_op)) - return_value = self._evaluate_expression(return_statement.expression) + return_value = Qasm3ExprEvaluator.evaluate_expression(self, return_statement.expression) return_value = Qasm3VisitorUtils.validate_return_statement( subroutine_def, return_statement, return_value ) @@ -1634,7 +1519,7 @@ def _visit_switch_statement(self, statement: SwitchStatement) -> None: Qasm3VisitorUtils.print_err_location(statement.span) raise Qasm3ConversionError(f"Switch target {switch_target_name} must be of type int") - switch_target_val = self._evaluate_expression(switch_target) + switch_target_val = Qasm3ExprEvaluator.evaluate_expression(self, switch_target) if len(statement.cases) == 0: Qasm3VisitorUtils.print_err_location(statement.span) @@ -1665,8 +1550,8 @@ def _evaluate_case(statements): # 3. evaluate and verify that it is a const_expression # using vars only within the scope AND each component is either a # literal OR type int - case_val = self._evaluate_expression( - case_expr, const_expr=True, reqd_type=Qasm3IntType + case_val = Qasm3ExprEvaluator.evaluate_expression( + self, case_expr, const_expr=True, reqd_type=Qasm3IntType ) if case_val in seen_values: diff --git a/tests/qasm3_qir/converter/test_switch.py b/tests/qasm3_qir/converter/test_switch.py index 0e3f616..089aa0c 100644 --- a/tests/qasm3_qir/converter/test_switch.py +++ b/tests/qasm3_qir/converter/test_switch.py @@ -206,15 +206,16 @@ def test_nested_switch(): case 1,3,5,7 { int j = 4; // definition inside scope switch(j) { - case 1,3,5,7 { - x q; - } - case 2,4,6,8 { - y q; // this will be executed - } - default { - z q; - } + case 1,3,5,7 { + x q; + } + case 2,4,6,8 { + j = 5; // assignment inside scope + y q; // this will be executed + } + default { + z q; + } } } case 2,4,6,8 {