Skip to content

Commit

Permalink
update test_peephole_optimizations.py and MergeRotationsPatterns.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
Mohxen committed Nov 4, 2024
1 parent d4dd01a commit a1d5504
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 45 deletions.
73 changes: 28 additions & 45 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,73 +27,56 @@
# Complex_merging_rotations
#

# Parameterize test with different angle sets for qml.Rot and qml.CRot to ensure coverage of complex cases.
@pytest.mark.parametrize("params1, params2", [
((0.5, 1.0, 1.5), (0.6, 0.8, 0.7)), # Arbitrary angles for general coverage
((np.pi / 2, np.pi / 4, np.pi / 6), (np.pi, 3 * np.pi / 4, np.pi / 3)) # Important angles with multiples of π
((0.5, 1.0, 1.5), (0.6, 0.8, 0.7)),

Check notice on line 31 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L31

Trailing whitespace (trailing-whitespace)
((np.pi / 2, np.pi / 4, np.pi / 6), (np.pi, 3 * np.pi / 4, np.pi / 3))

Check notice on line 32 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L32

Trailing whitespace (trailing-whitespace)
])
def test_complex_merge_rotation(params1, params2, backend):
"""Comprehensive test for complex merge rotations with qml.Rot and qml.CRot using full-angle formulas."""

# Test for qml.Rot
@qjit
def rot_workflow():
@qml.qnode(qml.device(backend, wires=1))
def f():
qml.Rot(params1[0], params1[1], params1[2], wires=0)
qml.Rot(params2[0], params2[1], params2[2], wires=0)
return qml.probs()

@merge_rotations
def create_rot_circuit():
"""Helper function to create qml.Rot circuit for testing."""
@qml.qnode(qml.device(backend, wires=1))
def g():
def circuit():
qml.Rot(params1[0], params1[1], params1[2], wires=0)
qml.Rot(params2[0], params2[1], params2[2], wires=0)
return qml.probs()
return circuit

return f(), g()
# Create unmerged and merged circuits
unmerged_rot_circuit = create_rot_circuit()
merged_rot_circuit = qml.transforms.merge_rotations(create_rot_circuit())

# Reference function for qml.Rot without merging
@qml.qnode(qml.device("default.qubit", wires=1))
def rot_reference():
qml.Rot(params1[0], params1[1], params1[2], wires=0)
qml.Rot(params2[0], params2[1], params2[2], wires=0)
return qml.probs()
# Verify that the circuits produce the same results
assert np.allclose(unmerged_rot_circuit(), merged_rot_circuit()), "Merged result for qml.Rot differs from unmerged."

Check notice on line 51 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L51

Line too long (120/100) (line-too-long)

# Verify results for qml.Rot
rot_results = rot_workflow()
assert np.allclose(rot_results[0], rot_results[1]), "Merged result for qml.Rot differs from unmerged."
assert np.allclose(rot_results[1], rot_reference()), "Merged result for qml.Rot differs from reference."
# Check if the merged circuit has fewer rotation gates
unmerged_rot_count = sum(1 for op in unmerged_rot_circuit.tape.operations if op.name == "Rot")
merged_rot_count = sum(1 for op in merged_rot_circuit.tape.operations if op.name == "Rot")
assert merged_rot_count < unmerged_rot_count, "Rotation gates were not merged in qml.Rot."

# Test for qml.CRot
@qjit
def crot_workflow():
def create_crot_circuit():
"""Helper function to create qml.CRot circuit for testing."""
@qml.qnode(qml.device(backend, wires=2))
def f():
def circuit():
qml.CRot(params1[0], params1[1], params1[2], wires=[0, 1])
qml.CRot(params2[0], params2[1], params2[2], wires=[0, 1])
return qml.probs()
return circuit

@merge_rotations
@qml.qnode(qml.device(backend, wires=2))
def g():
qml.CRot(params1[0], params1[1], params1[2], wires=[0, 1])
qml.CRot(params2[0], params2[1], params2[2], wires=[0, 1])
return qml.probs()
# Create unmerged and merged circuits for qml.CRot
unmerged_crot_circuit = create_crot_circuit()
merged_crot_circuit = qml.transforms.merge_rotations(create_crot_circuit())

return f(), g()
# Verify that the circuits produce the same results
assert np.allclose(unmerged_crot_circuit(), merged_crot_circuit()), "Merged result for qml.CRot differs from unmerged."

Check notice on line 73 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L73

Line too long (123/100) (line-too-long)

# Reference function for qml.CRot without merging
@qml.qnode(qml.device("default.qubit", wires=2))
def crot_reference():
qml.CRot(params1[0], params1[1], params1[2], wires=[0, 1])
qml.CRot(params2[0], params2[1], params2[2], wires=[0, 1])
return qml.probs()
# Check if the merged circuit has fewer controlled rotation gates
unmerged_crot_count = sum(1 for op in unmerged_crot_circuit.tape.operations if op.name == "CRot")

Check notice on line 76 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L76

Line too long (101/100) (line-too-long)
merged_crot_count = sum(1 for op in merged_crot_circuit.tape.operations if op.name == "CRot")
assert merged_crot_count < unmerged_crot_count, "Controlled rotation gates were not merged in qml.CRot."

Check notice on line 78 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L78

Line too long (108/100) (line-too-long)

# Verify results for qml.CRot
crot_results = crot_workflow()
assert np.allclose(crot_results[0], crot_results[1]), "Merged result for qml.CRot differs from unmerged."
assert np.allclose(crot_results[1], crot_reference()), "Merged result for qml.CRot differs from reference."

#
# cancel_inverses
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern<CustomOp> {
}

if (opGateName == "qml.Rot" || opGateName == "qml.CRot") {

Check notice on line 56 in mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp

View check run for this annotation

codefactor.io / CodeFactor

mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp#L56

Redundant blank line at the start of a code block should be deleted. (whitespace/blank_line)
if (op.getAdjoint()) {
LLVM_DEBUG(dbgs() << "Skipping adjointed operation:\n" << op << "\n");
return failure();
}

LLVM_DEBUG(dbgs() << "Applying scalar formula for combined rotation operation:\n" << op << "\n");
auto params = op.getParams();
auto parentParams = parentOp.getParams();
Expand Down

0 comments on commit a1d5504

Please sign in to comment.