Skip to content

Commit

Permalink
add more type fix
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGupta2012 committed Aug 26, 2024
1 parent 7473cde commit 9a53561
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 42 deletions.
2 changes: 1 addition & 1 deletion qbraid_qir/qasm3/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def transform_function_qubits(
q_op: Union[QuantumGate, QuantumBarrier, QuantumReset],
formal_qreg_sizes: dict[str, int],
qubit_map: dict[tuple, tuple],
) -> list:
) -> list[IndexedIdentifier]:
"""Transform the qubits of a function call to the actual qubits.
Args:
Expand Down
101 changes: 60 additions & 41 deletions qbraid_qir/qasm3/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

# pylint: disable=too-many-instance-attributes,too-many-lines
from collections import deque
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import openqasm3.ast as qasm3_ast
import pyqir
Expand Down Expand Up @@ -312,10 +312,16 @@ def visit_register(

current_size = len(self._qubit_labels) if is_qubit else len(self._clbit_labels)
if is_qubit:
register_size = 1 if register.size is None else register.size.value
register_size = (
1 if register.size is None else register.size.value
) # type: ignore[union-attr]
else:
register_size = 1 if register.type.size is None else register.type.size.value
register_name = register.qubit.name if is_qubit else register.identifier.name
register_size = (
1 if register.type.size is None else register.type.size.value
) # type: ignore[union-attr]
register_name = (
register.qubit.name if is_qubit else register.identifier.name
) # type: ignore[union-attr]

size_map = self._global_qreg_size_map if is_qubit else self._global_creg_size_map
label_map = self._qubit_labels if is_qubit else self._clbit_labels
Expand Down Expand Up @@ -364,7 +370,7 @@ def _check_if_name_in_scope(self, name: str, operation: Any) -> None:

def _get_op_qubits(
self, operation: Any, qreg_size_map: dict, qir_form: bool = True
) -> list[Union[pyqir.qubit, qasm3_ast.IndexedIdentifier]]:
) -> list[Union[Callable[[pyqir.Context, int], pyqir.Constant], qasm3_ast.IndexedIdentifier]]:
"""Get the qubits for the operation.
Args:
Expand Down Expand Up @@ -395,6 +401,7 @@ def _get_op_qubits(
qreg_size = qreg_size_map[qreg_name]

if isinstance(qubit, qasm3_ast.IndexedIdentifier):
assert not isinstance(qubit.indices[0], qasm3_ast.DiscreteSet)
if isinstance(qubit.indices[0][0], qasm3_ast.RangeDefinition):
qids = Qasm3Transformer.get_qubits_from_range_definition(
qubit.indices[0][0], qreg_size, is_qubit_reg=True
Expand All @@ -411,6 +418,7 @@ def _get_op_qubits(
for i in qids
]
)

else:
qids = list(range(qreg_size))
openqasm_qubits.extend(
Expand Down Expand Up @@ -449,8 +457,9 @@ def _visit_measurement(self, statement: qasm3_ast.QuantumMeasurementStatement) -
source_id, target_id = None, None
# TODO: handle in-function measurements
source_name = source.name
if isinstance(source, qasm3_ast.IndexedIdentifier):
source_name = source.name.name
if isinstance(source_name, qasm3_ast.Identifier):
source_name = source_name.name
assert source
if isinstance(source.indices[0][0], qasm3_ast.RangeDefinition):
raise_qasm3_error(
f"Range based measurement {statement} not supported at the moment",
Expand All @@ -459,8 +468,9 @@ def _visit_measurement(self, statement: qasm3_ast.QuantumMeasurementStatement) -
source_id = source.indices[0][0].value

target_name = target.name
if isinstance(target, qasm3_ast.IndexedIdentifier):
target_name = target.name.name
if isinstance(target_name, qasm3_ast.Identifier):
target_name = target_name.name
assert target
if isinstance(target.indices[0][0], qasm3_ast.RangeDefinition):
raise_qasm3_error(
f"Range based measurement {statement} not supported at the moment",
Expand Down Expand Up @@ -699,15 +709,14 @@ def _visit_custom_gate_operation(
self._push_context(Context.GATE)

for gate_op in gate_definition_ops:
if gate_op.name.name == gate_name:
raise_qasm3_error(
f"Recursive definitions not allowed for gate {gate_name}", span=gate_op.span
)

# necessary to avoid modifying the original gate definition
# in case the gate is reapplied
gate_op_copy = copy.deepcopy(gate_op)
if isinstance(gate_op, qasm3_ast.QuantumGate):
# necessary to avoid modifying the original gate definition
# in case the gate is reapplied
gate_op_copy = copy.deepcopy(gate_op)
if gate_op.name.name == gate_name:
raise_qasm3_error(
f"Recursive definitions not allowed for gate {gate_name}", span=gate_op.span
)
Qasm3Transformer.transform_gate_params(gate_op_copy, param_map)
Qasm3Transformer.transform_gate_qubits(gate_op_copy, qubit_map)
# need to trickle the inverse down to the child gates
Expand Down Expand Up @@ -815,14 +824,16 @@ def _visit_constant_declaration(self, statement: qasm3_ast.ConstantDeclaration)
base_type = statement.type
if isinstance(base_type, qasm3_ast.BoolType):
base_size = 1
elif base_type.size is None:
base_size = 32 # default for now
else:
base_size = Qasm3ExprEvaluator.evaluate_expression(base_type.size, const_expr=True)
if not isinstance(base_size, int) or base_size <= 0:
raise_qasm3_error(
f"Invalid base size {base_size} for variable {var_name}", span=statement.span
)
elif hasattr(base_type, "size"):
if base_type.size is None:
base_size = 32 # default for now
else:
base_size = Qasm3ExprEvaluator.evaluate_expression(base_type.size, const_expr=True)
if not isinstance(base_size, int) or base_size <= 0:
raise_qasm3_error(
f"Invalid base size {base_size} for variable {var_name}",
span=statement.span,
)

variable = Variable(var_name, base_type, base_size, [], init_value, is_constant=True)

Expand Down Expand Up @@ -899,7 +910,7 @@ def _visit_classical_declaration(self, statement: qasm3_ast.ClassicalDeclaration
if not isinstance(base_type, qasm3_ast.BoolType):
base_size = (
32
if base_type.size is None
if not hasattr(base_type, "size") or base_type.size is None
else Qasm3ExprEvaluator.evaluate_expression(base_type.size)
)

Expand All @@ -917,6 +928,7 @@ def _visit_classical_declaration(self, statement: qasm3_ast.ClassicalDeclaration

if statement.init_expression:
if isinstance(init_value, list):
assert variable.dims is not None
Qasm3Validator.validate_array_assignment_values(variable, variable.dims, init_value)
else:
variable.value = Qasm3Validator.validate_variable_assignment_value(
Expand All @@ -937,42 +949,45 @@ def _visit_classical_assignment(self, statement: qasm3_ast.ClassicalAssignment)
lvalue = statement.lvalue
var_name = lvalue.name

if isinstance(lvalue, qasm3_ast.IndexedIdentifier):
if isinstance(var_name, qasm3_ast.Identifier):
var_name = var_name.name

var = self._get_from_visible_scope(var_name)

if var is None:
if var is None: # we check for none here, so type errors are irrelevant afterwards
raise_qasm3_error(f"Undefined variable {var_name} in assignment", span=statement.span)

if var.is_constant:
if var.is_constant: # type: ignore[union-attr]
raise_qasm3_error(
f"Assignment to constant variable {var_name} not allowed", span=statement.span
)

var_value = Qasm3ExprEvaluator.evaluate_expression(statement.rvalue)

# currently we support single array assignment only
# range based assignment not supported yet
# currently we support single array assignment only.
# TODO: range based assignment

# cast + validation
var_value = Qasm3Validator.validate_variable_assignment_value(var, var_value)

var_value = Qasm3Validator.validate_variable_assignment_value(
var, var_value # type: ignore[arg-type]
)
# handle assignment for arrays
if isinstance(lvalue, qasm3_ast.IndexedIdentifier):
# stupid indices structure in openqasm :/
if len(lvalue.indices[0]) > 1:
if len(lvalue.indices[0]) > 1: # type: ignore[arg-type]
indices = lvalue.indices[0]
else:
indices = [idx[0] for idx in lvalue.indices]

validated_indices = Qasm3Analyzer.analyze_classical_indices(
indices, self._get_from_visible_scope(var_name)
indices, self._get_from_visible_scope(var_name) # type: ignore[arg-type]
) # type: ignore[arg-type]
Qasm3Transformer.update_array_element(
var.value, validated_indices, var_value # type: ignore[union-attr, arg-type]
)
Qasm3Transformer.update_array_element(var.value, validated_indices, var_value)
else:
var.value = var_value
self._update_var_in_scope(var)
var.value = var_value # type: ignore[union-attr]
self._update_var_in_scope(var) # type: ignore[arg-type]

def _evaluate_array_initialization(
self, array_literal: qasm3_ast.ArrayLiteral, dimensions: list[int], base_type: Any
Expand Down Expand Up @@ -1328,7 +1343,9 @@ def _process_quantum_arg(formal_arg, actual_arg, formal_reg_name, actual_arg_nam
actual_arg_name = None
if isinstance(actual_arg, qasm3_ast.Identifier):
actual_arg_name = actual_arg.name
elif isinstance(actual_arg, qasm3_ast.IndexExpression):
elif isinstance(actual_arg, qasm3_ast.IndexExpression) and isinstance(
actual_arg.collection, qasm3_ast.Identifier
):
actual_arg_name = actual_arg.collection.name

if isinstance(formal_arg, qasm3_ast.ClassicalArgument):
Expand Down Expand Up @@ -1402,7 +1419,9 @@ def _visit_alias_statement(self, statement: qasm3_ast.AliasStatement) -> None:

if isinstance(value, qasm3_ast.Identifier):
aliased_reg_name = value.name
elif isinstance(value, qasm3_ast.IndexExpression):
elif isinstance(value, qasm3_ast.IndexExpression) and isinstance(
value.collection, qasm3_ast.Identifier
):
aliased_reg_name = value.collection.name
else:
raise_qasm3_error(f"Unsupported aliasing {statement}", span=statement.span)
Expand Down Expand Up @@ -1570,7 +1589,7 @@ def visit_statement(self, statement: qasm3_ast.Statement) -> None:
visitor_function = visit_map.get(type(statement))

if visitor_function:
visitor_function(statement)
visitor_function(statement) # type: ignore[operator]
else:
raise_qasm3_error(
f"Unsupported statement of type {type(statement)}", span=statement.span
Expand Down

0 comments on commit 9a53561

Please sign in to comment.