From bfac925a30b27d286baeb35f24638fdbe0a13e3c Mon Sep 17 00:00:00 2001 From: Roland Siegbert Date: Thu, 22 Aug 2024 23:48:41 +0200 Subject: [PATCH] Add `reset` to QasmParser Add a test which can serve as example including the modifications to support the `reset` keyword in the import of QASM files in the PLY based lexer/parser. --- cirq-core/cirq/contrib/qasm_import/_lexer.py | 5 + cirq-core/cirq/contrib/qasm_import/_parser.py | 278 ++++++++++++------ .../cirq/contrib/qasm_import/_parser_test.py | 28 ++ 3 files changed, 223 insertions(+), 88 deletions(-) diff --git a/cirq-core/cirq/contrib/qasm_import/_lexer.py b/cirq-core/cirq/contrib/qasm_import/_lexer.py index 206d9e88d74f..a62ee52f84e8 100644 --- a/cirq-core/cirq/contrib/qasm_import/_lexer.py +++ b/cirq-core/cirq/contrib/qasm_import/_lexer.py @@ -29,6 +29,7 @@ def __init__(self): reserved = { 'qreg': 'QREG', 'creg': 'CREG', + 'reset': 'RESET', 'measure': 'MEASURE', 'if': 'IF', '->': 'ARROW', @@ -91,6 +92,10 @@ def t_CREG(self, t): r"""creg""" return t + def t_RESET(self, t): + r"""reset""" + return t + def t_MEASURE(self, t): r"""measure""" return t diff --git a/cirq-core/cirq/contrib/qasm_import/_parser.py b/cirq-core/cirq/contrib/qasm_import/_parser.py index e7bcdae06db2..d264619c896e 100644 --- a/cirq-core/cirq/contrib/qasm_import/_parser.py +++ b/cirq-core/cirq/contrib/qasm_import/_parser.py @@ -14,7 +14,17 @@ import functools import operator -from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Union, TYPE_CHECKING +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + Optional, + Union, + TYPE_CHECKING, +) import numpy as np import sympy @@ -34,7 +44,12 @@ class Qasm: """Qasm stores the final result of the Qasm parsing.""" def __init__( - self, supported_format: bool, qelib1_include: bool, qregs: dict, cregs: dict, c: Circuit + self, + supported_format: bool, + qelib1_include: bool, + qregs: dict, + cregs: dict, + c: Circuit, ): # defines whether the Quantum Experience standard header # is present or not @@ -109,7 +124,9 @@ def on( # the actual gate we'll apply the arguments to might be a parameterized # or non-parameterized gate final_gate: ops.Gate = ( - self.cirq_gate if isinstance(self.cirq_gate, ops.Gate) else self.cirq_gate(params) + self.cirq_gate + if isinstance(self.cirq_gate, ops.Gate) + else self.cirq_gate(params) ) # OpenQASM gates can be applied on single qubits and qubit registers. # We represent single qubits as registers of size 1. @@ -120,7 +137,10 @@ def on( # through each qubit of the registers 0 to n-1 and use the same one # qubit from the "single-qubit registers" for each operation. op_qubits = functools.reduce( - cast(Callable[[List['cirq.Qid'], List['cirq.Qid']], List['cirq.Qid']], np.broadcast), + cast( + Callable[[List["cirq.Qid"], List["cirq.Qid"]], List["cirq.Qid"]], + np.broadcast, + ), args, ) for qubits in op_qubits: @@ -152,29 +172,29 @@ def __init__(self) -> None: self.parsedQasm: Optional[Qasm] = None self.qubits: Dict[str, ops.Qid] = {} self.functions = { - 'sin': np.sin, - 'cos': np.cos, - 'tan': np.tan, - 'exp': np.exp, - 'ln': np.log, - 'sqrt': np.sqrt, - 'acos': np.arccos, - 'atan': np.arctan, - 'asin': np.arcsin, + "sin": np.sin, + "cos": np.cos, + "tan": np.tan, + "exp": np.exp, + "ln": np.log, + "sqrt": np.sqrt, + "acos": np.arccos, + "atan": np.arctan, + "asin": np.arcsin, } self.binary_operators = { - '+': operator.add, - '-': operator.sub, - '*': operator.mul, - '/': operator.truediv, - '^': operator.pow, + "+": operator.add, + "-": operator.sub, + "*": operator.mul, + "/": operator.truediv, + "^": operator.pow, } basic_gates: Dict[str, QasmGateStatement] = { - 'CX': QasmGateStatement(qasm_gate='CX', cirq_gate=CX, num_params=0, num_args=2), - 'U': QasmGateStatement( - qasm_gate='U', + "CX": QasmGateStatement(qasm_gate="CX", cirq_gate=CX, num_params=0, num_args=2), + "U": QasmGateStatement( + qasm_gate="U", num_params=3, num_args=1, # QasmUGate expects half turns @@ -183,81 +203,128 @@ def __init__(self) -> None: } qelib_gates = { - 'rx': QasmGateStatement( - qasm_gate='rx', cirq_gate=(lambda params: ops.rx(params[0])), num_params=1, num_args=1 + "rx": QasmGateStatement( + qasm_gate="rx", + cirq_gate=(lambda params: ops.rx(params[0])), + num_params=1, + num_args=1, ), - 'sx': QasmGateStatement( - qasm_gate='sx', num_params=0, num_args=1, cirq_gate=ops.XPowGate(exponent=0.5) + "sx": QasmGateStatement( + qasm_gate="sx", + num_params=0, + num_args=1, + cirq_gate=ops.XPowGate(exponent=0.5), ), - 'sxdg': QasmGateStatement( - qasm_gate='sxdg', num_params=0, num_args=1, cirq_gate=ops.XPowGate(exponent=-0.5) + "sxdg": QasmGateStatement( + qasm_gate="sxdg", + num_params=0, + num_args=1, + cirq_gate=ops.XPowGate(exponent=-0.5), ), - 'ry': QasmGateStatement( - qasm_gate='ry', cirq_gate=(lambda params: ops.ry(params[0])), num_params=1, num_args=1 + "ry": QasmGateStatement( + qasm_gate="ry", + cirq_gate=(lambda params: ops.ry(params[0])), + num_params=1, + num_args=1, ), - 'rz': QasmGateStatement( - qasm_gate='rz', cirq_gate=(lambda params: ops.rz(params[0])), num_params=1, num_args=1 + "rz": QasmGateStatement( + qasm_gate="rz", + cirq_gate=(lambda params: ops.rz(params[0])), + num_params=1, + num_args=1, ), - 'id': QasmGateStatement( - qasm_gate='id', cirq_gate=ops.IdentityGate(1), num_params=0, num_args=1 + "id": QasmGateStatement( + qasm_gate="id", cirq_gate=ops.IdentityGate(1), num_params=0, num_args=1 ), - 'u1': QasmGateStatement( - qasm_gate='u1', + "u1": QasmGateStatement( + qasm_gate="u1", cirq_gate=(lambda params: QasmUGate(0, 0, params[0] / np.pi)), num_params=1, num_args=1, ), - 'u2': QasmGateStatement( - qasm_gate='u2', - cirq_gate=(lambda params: QasmUGate(0.5, params[0] / np.pi, params[1] / np.pi)), + "u2": QasmGateStatement( + qasm_gate="u2", + cirq_gate=( + lambda params: QasmUGate(0.5, params[0] / np.pi, params[1] / np.pi) + ), num_params=2, num_args=1, ), - 'u3': QasmGateStatement( - qasm_gate='u3', + "u3": QasmGateStatement( + qasm_gate="u3", num_params=3, num_args=1, cirq_gate=(lambda params: QasmUGate(*[p / np.pi for p in params])), ), - 'r': QasmGateStatement( - qasm_gate='r', + "r": QasmGateStatement( + qasm_gate="r", num_params=2, num_args=1, cirq_gate=( lambda params: QasmUGate( - params[0] / np.pi, (params[1] / np.pi) - 0.5, (-params[1] / np.pi) + 0.5 + params[0] / np.pi, + (params[1] / np.pi) - 0.5, + (-params[1] / np.pi) + 0.5, ) ), ), - 'x': QasmGateStatement(qasm_gate='x', num_params=0, num_args=1, cirq_gate=ops.X), - 'y': QasmGateStatement(qasm_gate='y', num_params=0, num_args=1, cirq_gate=ops.Y), - 'z': QasmGateStatement(qasm_gate='z', num_params=0, num_args=1, cirq_gate=ops.Z), - 'h': QasmGateStatement(qasm_gate='h', num_params=0, num_args=1, cirq_gate=ops.H), - 's': QasmGateStatement(qasm_gate='s', num_params=0, num_args=1, cirq_gate=ops.S), - 't': QasmGateStatement(qasm_gate='t', num_params=0, num_args=1, cirq_gate=ops.T), - 'cx': QasmGateStatement(qasm_gate='cx', cirq_gate=CX, num_params=0, num_args=2), - 'cy': QasmGateStatement( - qasm_gate='cy', cirq_gate=ops.ControlledGate(ops.Y), num_params=0, num_args=2 + "x": QasmGateStatement( + qasm_gate="x", num_params=0, num_args=1, cirq_gate=ops.X + ), + "y": QasmGateStatement( + qasm_gate="y", num_params=0, num_args=1, cirq_gate=ops.Y + ), + "z": QasmGateStatement( + qasm_gate="z", num_params=0, num_args=1, cirq_gate=ops.Z + ), + "h": QasmGateStatement( + qasm_gate="h", num_params=0, num_args=1, cirq_gate=ops.H + ), + "s": QasmGateStatement( + qasm_gate="s", num_params=0, num_args=1, cirq_gate=ops.S + ), + "t": QasmGateStatement( + qasm_gate="t", num_params=0, num_args=1, cirq_gate=ops.T + ), + "cx": QasmGateStatement(qasm_gate="cx", cirq_gate=CX, num_params=0, num_args=2), + "cy": QasmGateStatement( + qasm_gate="cy", + cirq_gate=ops.ControlledGate(ops.Y), + num_params=0, + num_args=2, + ), + "cz": QasmGateStatement( + qasm_gate="cz", cirq_gate=ops.CZ, num_params=0, num_args=2 + ), + "ch": QasmGateStatement( + qasm_gate="ch", + cirq_gate=ops.ControlledGate(ops.H), + num_params=0, + num_args=2, + ), + "swap": QasmGateStatement( + qasm_gate="swap", cirq_gate=ops.SWAP, num_params=0, num_args=2 + ), + "cswap": QasmGateStatement( + qasm_gate="cswap", num_params=0, num_args=3, cirq_gate=ops.CSWAP ), - 'cz': QasmGateStatement(qasm_gate='cz', cirq_gate=ops.CZ, num_params=0, num_args=2), - 'ch': QasmGateStatement( - qasm_gate='ch', cirq_gate=ops.ControlledGate(ops.H), num_params=0, num_args=2 + "ccx": QasmGateStatement( + qasm_gate="ccx", num_params=0, num_args=3, cirq_gate=ops.CCX ), - 'swap': QasmGateStatement(qasm_gate='swap', cirq_gate=ops.SWAP, num_params=0, num_args=2), - 'cswap': QasmGateStatement( - qasm_gate='cswap', num_params=0, num_args=3, cirq_gate=ops.CSWAP + "sdg": QasmGateStatement( + qasm_gate="sdg", num_params=0, num_args=1, cirq_gate=ops.S**-1 + ), + "tdg": QasmGateStatement( + qasm_gate="tdg", num_params=0, num_args=1, cirq_gate=ops.T**-1 ), - 'ccx': QasmGateStatement(qasm_gate='ccx', num_params=0, num_args=3, cirq_gate=ops.CCX), - 'sdg': QasmGateStatement(qasm_gate='sdg', num_params=0, num_args=1, cirq_gate=ops.S**-1), - 'tdg': QasmGateStatement(qasm_gate='tdg', num_params=0, num_args=1, cirq_gate=ops.T**-1), } all_gates = {**basic_gates, **qelib_gates} tokens = QasmLexer.tokens - start = 'start' + start = "start" - precedence = (('left', '+', '-'), ('left', '*', '/'), ('right', '^')) + precedence = (("left", "+", "-"), ("left", "*", "/"), ("right", "^")) def p_start(self, p): """start : qasm""" @@ -266,7 +333,9 @@ def p_start(self, p): def p_qasm_format_only(self, p): """qasm : format""" self.supported_format = True - p[0] = Qasm(self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit) + p[0] = Qasm( + self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit + ) def p_qasm_no_format_specified_error(self, p): """qasm : QELIBINC @@ -277,7 +346,9 @@ def p_qasm_no_format_specified_error(self, p): def p_qasm_include(self, p): """qasm : qasm QELIBINC""" self.qelibinc = True - p[0] = Qasm(self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit) + p[0] = Qasm( + self.supported_format, self.qelibinc, self.qregs, self.cregs, self.circuit + ) def p_qasm_circuit(self, p): """qasm : qasm circuit""" @@ -293,6 +364,7 @@ def p_format(self, p): # circuit : new_reg circuit # | gate_op circuit # | measurement circuit + # | reset circuit # | if circuit # | empty @@ -303,8 +375,13 @@ def p_circuit_reg(self, p): def p_circuit_gate_or_measurement_or_if(self, p): """circuit : circuit gate_op | circuit measurement - | circuit if""" - self.circuit.append(p[2]) + | circuit reset + | circuit if + | circuit new_reg""" + if isinstance(p[2], list): + self.circuit.append(p[2]) + else: + self.circuit.append([p[2]]) p[0] = self.circuit def p_circuit_empty(self, p): @@ -320,11 +397,14 @@ def p_new_reg(self, p): if name in self.qregs.keys() or name in self.cregs.keys(): raise QasmException(f"{name} is already defined at line {p.lineno(2)}") if length == 0: - raise QasmException(f"Illegal, zero-length register '{name}' at line {p.lineno(4)}") + raise QasmException( + f"Illegal, zero-length register '{name}' at line {p.lineno(4)}" + ) if p[1] == "qreg": self.qregs[name] = length else: self.cregs[name] = length + print("Calling qrags to create a qubit") p[0] = (name, length) # gate operations @@ -378,13 +458,15 @@ def p_expr_function_call(self, p): """expr : ID '(' expr ')'""" func = p[1] if func not in self.functions.keys(): - raise QasmException(f"Function not recognized: '{func}' at line {p.lineno(1)}") + raise QasmException( + f"Function not recognized: '{func}' at line {p.lineno(1)}" + ) p[0] = self.functions[func](p[3]) def p_expr_unary(self, p): """expr : '-' expr | '+' expr""" - if p[1] == '-': + if p[1] == "-": p[0] = -p[2] else: p[0] = p[2] @@ -423,7 +505,9 @@ def p_quantum_arg_register(self, p): """qarg : ID""" reg = p[1] if reg not in self.qregs.keys(): - raise QasmException(f'Undefined quantum register "{reg}" at line {p.lineno(1)}') + raise QasmException( + f'Undefined quantum register "{reg}" at line {p.lineno(1)}' + ) qubits = [] for idx in range(self.qregs[reg]): arg_name = self.make_name(idx, reg) @@ -439,7 +523,9 @@ def p_classical_arg_register(self, p): """carg : ID""" reg = p[1] if reg not in self.cregs.keys(): - raise QasmException(f'Undefined classical register "{reg}" at line {p.lineno(1)}') + raise QasmException( + f'Undefined classical register "{reg}" at line {p.lineno(1)}' + ) p[0] = [self.make_name(idx, reg) for idx in range(self.cregs[reg])] @@ -452,13 +538,15 @@ def p_quantum_arg_bit(self, p): idx = p[3] arg_name = self.make_name(idx, reg) if reg not in self.qregs.keys(): - raise QasmException(f'Undefined quantum register "{reg}" at line {p.lineno(1)}') + raise QasmException( + f'Undefined quantum register "{reg}" at line {p.lineno(1)}' + ) size = self.qregs[reg] if idx >= size: raise QasmException( - f'Out of bounds qubit index {idx} ' - f'on register {reg} of size {size} ' - f'at line {p.lineno(1)}' + f"Out of bounds qubit index {idx} " + f"on register {reg} of size {size} " + f"at line {p.lineno(1)}" ) if arg_name not in self.qubits.keys(): self.qubits[arg_name] = NamedQubit(arg_name) @@ -470,14 +558,16 @@ def p_classical_arg_bit(self, p): idx = p[3] arg_name = self.make_name(idx, reg) if reg not in self.cregs.keys(): - raise QasmException(f'Undefined classical register "{reg}" at line {p.lineno(1)}') + raise QasmException( + f'Undefined classical register "{reg}" at line {p.lineno(1)}' + ) size = self.cregs[reg] if idx >= size: raise QasmException( - f'Out of bounds bit index {idx} ' - f'on classical register {reg} of size {size} ' - f'at line {p.lineno(1)}' + f"Out of bounds bit index {idx} " + f"on classical register {reg} of size {size} " + f"at line {p.lineno(1)}" ) p[0] = [arg_name] @@ -491,14 +581,24 @@ def p_measurement(self, p): if len(qreg) != len(creg): raise QasmException( - f'mismatched register sizes {len(qreg)} -> {len(creg)} for measurement ' - f'at line {p.lineno(1)}' + f"mismatched register sizes {len(qreg)} -> {len(creg)} for measurement " + f"at line {p.lineno(1)}" ) p[0] = [ - ops.MeasurementGate(num_qubits=1, key=creg[i]).on(qreg[i]) for i in range(len(qreg)) + ops.MeasurementGate(num_qubits=1, key=creg[i]).on(qreg[i]) + for i in range(len(qreg)) ] + # reset operations + # reset : RESET qarg + + def p_reset(self, p): + """reset : RESET qarg ';'""" + qreg = p[2] + + p[0] = [ops.ResetChannel().on(qreg[i]) for i in range(len(qreg))] + # if operations # if : IF '(' carg EQ NATURAL_NUMBER ')' ID qargs @@ -511,12 +611,14 @@ def p_if(self, p): v = (p[5] >> i) & 1 conditions.append(sympy.Eq(sympy.Symbol(key), v)) p[0] = [ - ops.ClassicallyControlledOperation(conditions=conditions, sub_operation=tuple(p[7])[0]) + ops.ClassicallyControlledOperation( + conditions=conditions, sub_operation=tuple(p[7])[0] + ) ] def p_error(self, p): if p is None: - raise QasmException('Unexpected end of file') + raise QasmException("Unexpected end of file") raise QasmException( f"""Syntax error: '{p.value}' @@ -525,7 +627,7 @@ def p_error(self, p): ) def find_column(self, p): - line_start = self.qasm.rfind('\n', 0, p.lexpos) + 1 + line_start = self.qasm.rfind("\n", 0, p.lexpos) + 1 return (p.lexpos - line_start) + 1 def p_empty(self, p): @@ -539,8 +641,8 @@ def parse(self, qasm: str) -> Qasm: return self.parsedQasm def debug_context(self, p): - debug_start = max(self.qasm.rfind('\n', 0, p.lexpos) + 1, p.lexpos - 5) - debug_end = min(self.qasm.find('\n', p.lexpos, p.lexpos + 5), p.lexpos + 5) + debug_start = max(self.qasm.rfind("\n", 0, p.lexpos) + 1, p.lexpos - 5) + debug_end = min(self.qasm.find("\n", p.lexpos, p.lexpos + 5), p.lexpos + 5) return ( "..." diff --git a/cirq-core/cirq/contrib/qasm_import/_parser_test.py b/cirq-core/cirq/contrib/qasm_import/_parser_test.py index 4b0ca8e50f1e..e334f2ccb856 100644 --- a/cirq-core/cirq/contrib/qasm_import/_parser_test.py +++ b/cirq-core/cirq/contrib/qasm_import/_parser_test.py @@ -408,6 +408,34 @@ def test_U_gate_too_much_params_error(): with pytest.raises(QasmException, match=r"U takes 3.*got.*4.*line 3"): parser.parse(qasm) +def test_reset(): + qasm =""" + OPENQASM 2.0; + include "qelib1.inc"; + qreg q[1]; + creg c[1]; + x q[0]; + reset q[0]; + measure q[0] -> c[0]; + """ + + parser = QasmParser() + + q_0 = cirq.NamedQubit('q_0') + + expected_circuit = Circuit() + expected_circuit.append(cirq.X(q_0)) + expected_circuit.append(cirq.ResetChannel().on(q_0)) + expected_circuit.append(cirq.MeasurementGate(num_qubits=1, key='c_0').on(q_0)) + + parsed_qasm = parser.parse(qasm) + + assert parsed_qasm.supportedFormat + assert parsed_qasm.qelib1Include + + ct.assert_same_circuits(parsed_qasm.circuit, expected_circuit) + assert parsed_qasm.qregs == {'q': 1} + assert parsed_qasm.cregs == {'c': 1} @pytest.mark.parametrize( 'expr',