diff --git a/frontend/test/pytest/test_peephole_optimizations.py b/frontend/test/pytest/test_peephole_optimizations.py index 59970cd028..11b6b17a46 100644 --- a/frontend/test/pytest/test_peephole_optimizations.py +++ b/frontend/test/pytest/test_peephole_optimizations.py @@ -23,6 +23,60 @@ # pylint: disable=missing-function-docstring +# +# Complex_merging_rotations +# + +@pytest.mark.parametrize("params1, params2", [ + ((0.5, 1.0, 1.5), (0.6, 0.8, 0.7)), + ((np.pi / 2, np.pi / 4, np.pi / 6), (np.pi, 3 * np.pi / 4, np.pi / 3)) +]) +def test_complex_merge_rotation(params1, params2, backend): + + # Test for qml.Rot + def create_rot_circuit(): + """Helper function to create qml.Rot circuit for testing.""" + @qml.qnode(qml.device(backend, wires=1)) + def circuit(): + qml.Rot(params1[0], params1[1], params1[2], wires=0) + qml.Rot(params2[0], params2[1], params2[2], wires=0) + return qml.probs() + return circuit + + # Create unmerged and merged circuits + unmerged_rot_circuit = create_rot_circuit() + merged_rot_circuit = qml.transforms.merge_rotations(create_rot_circuit()) + + # Verify that the circuits produce the same results + assert np.allclose(unmerged_rot_circuit(), merged_rot_circuit()), "Merged result for qml.Rot differs from unmerged." + + # Check if the merged circuit has fewer rotation gates + unmerged_rot_count = sum(1 for op in unmerged_rot_circuit.tape.operations if op.name == "Rot") + merged_rot_count = sum(1 for op in merged_rot_circuit.tape.operations if op.name == "Rot") + assert merged_rot_count < unmerged_rot_count, "Rotation gates were not merged in qml.Rot." + + # Test for qml.CRot + def create_crot_circuit(): + """Helper function to create qml.CRot circuit for testing.""" + @qml.qnode(qml.device(backend, wires=2)) + def circuit(): + qml.CRot(params1[0], params1[1], params1[2], wires=[0, 1]) + qml.CRot(params2[0], params2[1], params2[2], wires=[0, 1]) + return qml.probs() + return circuit + + # Create unmerged and merged circuits for qml.CRot + unmerged_crot_circuit = create_crot_circuit() + merged_crot_circuit = qml.transforms.merge_rotations(create_crot_circuit()) + + # Verify that the circuits produce the same results + assert np.allclose(unmerged_crot_circuit(), merged_crot_circuit()), "Merged result for qml.CRot differs from unmerged." + + # Check if the merged circuit has fewer controlled rotation gates + unmerged_crot_count = sum(1 for op in unmerged_crot_circuit.tape.operations if op.name == "CRot") + merged_crot_count = sum(1 for op in merged_crot_circuit.tape.operations if op.name == "CRot") + assert merged_crot_count < unmerged_crot_count, "Controlled rotation gates were not merged in qml.CRot." + # # cancel_inverses diff --git a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp index d449773e1b..63497bb52e 100644 --- a/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp +++ b/mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp @@ -21,13 +21,16 @@ #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Errc.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Arith/IR/Arith.h" using llvm::dbgs; using namespace mlir; using namespace catalyst::quantum; static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift", - "CRX", "CRY", "CRZ", "ControlledPhaseShift"}; + "CRX", "CRY", "CRZ", "ControlledPhaseShift", + "qml.Rot", "qml.CRot"}; namespace { @@ -49,27 +52,111 @@ struct MergeRotationsRewritePattern : public mlir::OpRewritePattern { return failure(); } - TypeRange outQubitsTypes = op.getOutQubits().getTypes(); - TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); - ValueRange parentInQubits = parentOp.getInQubits(); - ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); - ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); - - auto parentParams = parentOp.getParams(); - auto params = op.getParams(); - SmallVector sumParams; - for (auto [param, parentParam] : llvm::zip(params, parentParams)) { - mlir::Value sumParam = - rewriter.create(loc, parentParam, param).getResult(); - sumParams.push_back(sumParam); - }; - auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, - parentInQubits, opGateName, nullptr, - parentInCtrlQubits, parentInCtrlValues); - - op.replaceAllUsesWith(mergeOp); - - return success(); + if (opGateName == "qml.Rot" || opGateName == "qml.CRot") { + + if (op.getAdjoint()) { + LLVM_DEBUG(dbgs() << "Skipping adjointed operation:\n" << op << "\n"); + return failure(); + } + + LLVM_DEBUG(dbgs() << "Applying scalar formula for combined rotation operation:\n" << op << "\n"); + auto params = op.getParams(); + auto parentParams = parentOp.getParams(); + + // Assuming params[0] = alpha1, params[1] = theta1, params[2] = beta1 + // and parentParams[0] = alpha2, parentParams[1] = theta2, parentParams[2] = beta2 + + // Step 1: Calculate c1, c2, s1, s2 + auto c1 = rewriter.create(loc, params[1]); + auto s1 = rewriter.create(loc, params[1]); + auto c2 = rewriter.create(loc, parentParams[1]); + auto s2 = rewriter.create(loc, parentParams[1]); + + // Step 2: Calculate cf + auto c1Squared = rewriter.create(loc, c1, c1); + auto c2Squared = rewriter.create(loc, c2, c2); + auto s1Squared = rewriter.create(loc, s1, s1); + auto s2Squared = rewriter.create(loc, s2, s2); + auto cosAlphaDiff = rewriter.create(loc, rewriter.create(loc, params[0], parentParams[0])); + + auto term1 = rewriter.create(loc, c1Squared, c2Squared); + auto term2 = rewriter.create(loc, s1Squared, s2Squared); + auto product = rewriter.create(loc, c1, c2); + product = rewriter.create(loc, product, s1); + product = rewriter.create(loc, product, s2); + auto two = rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + auto term3 = rewriter.create(loc, two, rewriter.create(loc, product, cosAlphaDiff)); + + auto cfSquare = rewriter.create(loc, rewriter.create(loc, term1, term2), term3); + auto cf = rewriter.create(loc, cfSquare); + + // Step 3: Calculate theta_f = 2 * arccos(|cf|) + auto absCf = rewriter.create(loc, cf); + auto acosCf = rewriter.create(loc, absCf); + auto thetaF = rewriter.create(loc, two, acosCf); + + // Step 4: Calculate alpha_f + auto alphaSum = rewriter.create(loc, params[0], parentParams[0]); + auto betaDiff = rewriter.create(loc, parentParams[2], params[2]); + auto sinAlphaSum = rewriter.create(loc, alphaSum); + auto cosBetaDiff = rewriter.create(loc, betaDiff); + + auto term1_alpha = rewriter.create(loc, rewriter.create(loc, c1, s2), sinAlphaSum); + auto term2_alpha = rewriter.create(loc, rewriter.create(loc, s1, s2), cosBetaDiff); + auto numerator_alpha = rewriter.create(loc, rewriter.create(loc, term1_alpha), term2_alpha); + + auto cosAlphaSum = rewriter.create(loc, alphaSum); + auto denominator_alpha = rewriter.create(loc, rewriter.create(loc, rewriter.create(loc, c1, c2), cosAlphaSum), term2_alpha); + + auto alphaF = rewriter.create(loc, rewriter.create(loc, rewriter.create(loc, numerator_alpha, denominator_alpha))); + + // Step 5: Calculate beta_f + auto betaSum = rewriter.create(loc, params[2], parentParams[2]); + auto alphaDiffReversed = rewriter.create(loc, parentParams[0], params[0]); + auto sinBetaSum = rewriter.create(loc, betaSum); + auto cosAlphaDiffReversed = rewriter.create(loc, alphaDiffReversed); + + auto term1_beta = rewriter.create(loc, rewriter.create(loc, c1, s2), sinBetaSum); + auto term2_beta = rewriter.create(loc, rewriter.create(loc, s1, s2), cosAlphaDiffReversed); + auto numerator_beta = rewriter.create(loc, rewriter.create(loc, term1_beta), term2_beta); + + auto denominator_beta = denominator_alpha; // Reuse from alpha calculation if applicable + auto betaF = rewriter.create(loc, rewriter.create(loc, rewriter.create(loc, numerator_beta, denominator_beta))); + + // Step 6: Output angles (phi_f, theta_f, omega_f) + // Assign phi_f = alphaF, theta_f = thetaF, omega_f = betaF as the final values + SmallVector combinedAngles = {alphaF, thetaF, betaF}; + auto outQubitsTypes = op.getOutQubits().getTypes(); + auto outCtrlQubitsTypes = op.getOutCtrlQubits().getTypes(); + auto inQubits = op.getInQubits(); + auto inCtrlQubits = op.getInCtrlQubits(); + auto inCtrlValues = op.getInCtrlValues(); + rewriter.replaceOpWithNewOp(op, outQubitsTypes, outCtrlQubitsTypes, combinedAngles, inQubits, opGateName, nullptr, inCtrlQubits, inCtrlValues); + + return success(); + } + else { + TypeRange outQubitsTypes = op.getOutQubits().getTypes(); + TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); + ValueRange parentInQubits = parentOp.getInQubits(); + ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); + ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); + auto parentParams = parentOp.getParams(); + auto params = op.getParams(); + SmallVector sumParams; + for (auto [param, parentParam] : llvm::zip(params, parentParams)) { + mlir::Value sumParam = + rewriter.create(loc, parentParam, param).getResult(); + sumParams.push_back(sumParam); + }; + auto mergeOp = rewriter.create(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, + parentInQubits, opGateName, nullptr, + parentInCtrlQubits, parentInCtrlValues); + + op.replaceAllUsesWith(mergeOp); + + return success(); + } } };