Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added merge rotation patterns for qml.Rot and qml.CRot #1270

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
54 changes: 54 additions & 0 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,60 @@

# pylint: disable=missing-function-docstring

#
# Complex_merging_rotations
#

@pytest.mark.parametrize("params1, params2", [
Copy link
Contributor

@paul0403 paul0403 Nov 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So testing this new merge rotation pattern is a bit tricky: we know that regardless of whether the merge rotation transformation took effect or not, the circuit will produce the same results. Given that, does this test here actually test for whether the rotation gates are merged? If not, what is the best way to test that the rotation gates are merged, and what is the purpose of these end-to-end circuit execution tests here?

Hint: search through the code base and look for how the existing merge rotation patterns are tested!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @paul0403

I’ve added explicit checks in test_complex_merge_rotation to verify that the merge_rotations transformation reduces the number of rotation gates (Rot and CRot) and preserves the circuit's functionality. By explicitly calling qml.transforms.merge_rotations, we can compare the unoptimized and optimized circuits directly. This allows us to confirm both that the rotation gates are actually merged (fewer gates) and that the results remain the same, addressing the need for both functionality and transformation verification in the test. Please let me know if this method is correct.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the test! An additional approach on top of plain functionality check is definitely good to have 💯

((0.5, 1.0, 1.5), (0.6, 0.8, 0.7)),

Check notice on line 31 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L31

Trailing whitespace (trailing-whitespace)
((np.pi / 2, np.pi / 4, np.pi / 6), (np.pi, 3 * np.pi / 4, np.pi / 3))

Check notice on line 32 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L32

Trailing whitespace (trailing-whitespace)
])
def test_complex_merge_rotation(params1, params2, backend):

# Test for qml.Rot
def create_rot_circuit():
"""Helper function to create qml.Rot circuit for testing."""
@qml.qnode(qml.device(backend, wires=1))
def circuit():
qml.Rot(params1[0], params1[1], params1[2], wires=0)
qml.Rot(params2[0], params2[1], params2[2], wires=0)
return qml.probs()
return circuit

# Create unmerged and merged circuits
unmerged_rot_circuit = create_rot_circuit()
merged_rot_circuit = qml.transforms.merge_rotations(create_rot_circuit())

# Verify that the circuits produce the same results
assert np.allclose(unmerged_rot_circuit(), merged_rot_circuit()), "Merged result for qml.Rot differs from unmerged."

Check notice on line 51 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L51

Line too long (120/100) (line-too-long)

# Check if the merged circuit has fewer rotation gates
unmerged_rot_count = sum(1 for op in unmerged_rot_circuit.tape.operations if op.name == "Rot")
merged_rot_count = sum(1 for op in merged_rot_circuit.tape.operations if op.name == "Rot")
assert merged_rot_count < unmerged_rot_count, "Rotation gates were not merged in qml.Rot."

# Test for qml.CRot
def create_crot_circuit():
"""Helper function to create qml.CRot circuit for testing."""
@qml.qnode(qml.device(backend, wires=2))
def circuit():
qml.CRot(params1[0], params1[1], params1[2], wires=[0, 1])
qml.CRot(params2[0], params2[1], params2[2], wires=[0, 1])
return qml.probs()
return circuit

# Create unmerged and merged circuits for qml.CRot
unmerged_crot_circuit = create_crot_circuit()
merged_crot_circuit = qml.transforms.merge_rotations(create_crot_circuit())

# Verify that the circuits produce the same results
assert np.allclose(unmerged_crot_circuit(), merged_crot_circuit()), "Merged result for qml.CRot differs from unmerged."

Check notice on line 73 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L73

Line too long (123/100) (line-too-long)

# Check if the merged circuit has fewer controlled rotation gates
unmerged_crot_count = sum(1 for op in unmerged_crot_circuit.tape.operations if op.name == "CRot")

Check notice on line 76 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L76

Line too long (101/100) (line-too-long)
merged_crot_count = sum(1 for op in merged_crot_circuit.tape.operations if op.name == "CRot")
assert merged_crot_count < unmerged_crot_count, "Controlled rotation gates were not merged in qml.CRot."

Check notice on line 78 in frontend/test/pytest/test_peephole_optimizations.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/test_peephole_optimizations.py#L78

Line too long (108/100) (line-too-long)


#
# cancel_inverses
Expand Down
131 changes: 109 additions & 22 deletions mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Arith/IR/Arith.h"

using llvm::dbgs;
using namespace mlir;
using namespace catalyst::quantum;

static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift",
"CRX", "CRY", "CRZ", "ControlledPhaseShift"};
"CRX", "CRY", "CRZ", "ControlledPhaseShift",
"qml.Rot", "qml.CRot"};

namespace {

Expand All @@ -49,27 +52,111 @@
return failure();
}

TypeRange outQubitsTypes = op.getOutQubits().getTypes();
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes();
ValueRange parentInQubits = parentOp.getInQubits();
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits();
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();

auto parentParams = parentOp.getParams();
auto params = op.getParams();
SmallVector<mlir::Value> sumParams;
for (auto [param, parentParam] : llvm::zip(params, parentParams)) {
mlir::Value sumParam =
rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult();
sumParams.push_back(sumParam);
};
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
parentInQubits, opGateName, nullptr,
parentInCtrlQubits, parentInCtrlValues);

op.replaceAllUsesWith(mergeOp);

return success();
if (opGateName == "qml.Rot" || opGateName == "qml.CRot") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if the rotation gates are adjointed? Should the merge still happen?

(In Catalyst adjointed gates are indicated by a adjoint unit attribute, see for example here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @paul0403

Please let me know if you did not consider whether the rotation gates are adjointed for regular merging rotations (not for merging non-commutative rotations), or if we do not have adjointed gates for regular merging rotations?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The merge rotation pass applies an rotation gate adjoint canonicalization. The canonicalization simply changes all angles to their negative and removes the adjoint attribute. See #1205

However, looking at the canonicalization pattern, you will find that Rot and CRot are not canonicalized.
(a) Why do you think that is?
(b) Knowing this, what do you think you should do in your added pattern here (assuming no new canonicalization is added for (C)Rot)?

Copy link
Author

@Mohxen Mohxen Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @paul0403
(a) Rot and CRot are not canonicalized because they involve complex, multi-parameter rotations that cannot be standardized by simply negating a single parameter.
(b) First, check for the adjoint attribute on qml.Rot and qml.CRot. If the operation is adjointed, transform it into its non-adjointed, canonical form by reversing the order of the parameters and negating each parameter. After this transformation, remove the adjoint attribute to standardize the operation. Then, proceed with the merging process as if all rotations are in canonical form, ensuring consistency across operations.
Thus, if I have an operation qml.Rot(π/4, π/2, π/3), its adjoint will be qml.Rot(-π/3, -π/2, -π/4)
Please let me know if my method is correct.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good insights! This is what I would do as well.

Due to how the work is organized (aka adjoint canonicalization happens somewhere else, not here in the merge rotation patterns), in the merge rotation patterns, it suffices to assume that the rotation gates will not carry adjoint attributes when the patterns are hit.

Thus the only thing needed here is a check that the (C)Rot gates do not carry adjoint attributes. If they do, the pattern should do nothing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, added it.


Check notice on line 56 in mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp

View check run for this annotation

codefactor.io / CodeFactor

mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp#L56

Redundant blank line at the start of a code block should be deleted. (whitespace/blank_line)
if (op.getAdjoint()) {
LLVM_DEBUG(dbgs() << "Skipping adjointed operation:\n" << op << "\n");
return failure();
}

LLVM_DEBUG(dbgs() << "Applying scalar formula for combined rotation operation:\n" << op << "\n");
auto params = op.getParams();
auto parentParams = parentOp.getParams();

// Assuming params[0] = alpha1, params[1] = theta1, params[2] = beta1
// and parentParams[0] = alpha2, parentParams[1] = theta2, parentParams[2] = beta2

// Step 1: Calculate c1, c2, s1, s2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the new merge rotation pattern! The formula is very long, so we appreciate the good work 🥳

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for your help :) If there's anything specific you'd like me to refine or expand on, please let me know.

auto c1 = rewriter.create<math::CosOp>(loc, params[1]);
auto s1 = rewriter.create<math::SinOp>(loc, params[1]);
auto c2 = rewriter.create<math::CosOp>(loc, parentParams[1]);
auto s2 = rewriter.create<math::SinOp>(loc, parentParams[1]);

// Step 2: Calculate cf
auto c1Squared = rewriter.create<arith::MulFOp>(loc, c1, c1);
auto c2Squared = rewriter.create<arith::MulFOp>(loc, c2, c2);
auto s1Squared = rewriter.create<arith::MulFOp>(loc, s1, s1);
auto s2Squared = rewriter.create<arith::MulFOp>(loc, s2, s2);
auto cosAlphaDiff = rewriter.create<math::CosOp>(loc, rewriter.create<arith::SubFOp>(loc, params[0], parentParams[0]));

auto term1 = rewriter.create<arith::MulFOp>(loc, c1Squared, c2Squared);
auto term2 = rewriter.create<arith::MulFOp>(loc, s1Squared, s2Squared);
auto product = rewriter.create<arith::MulFOp>(loc, c1, c2);
product = rewriter.create<arith::MulFOp>(loc, product, s1);
product = rewriter.create<arith::MulFOp>(loc, product, s2);
auto two = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64FloatAttr(2.0));
auto term3 = rewriter.create<arith::MulFOp>(loc, two, rewriter.create<arith::MulFOp>(loc, product, cosAlphaDiff));

auto cfSquare = rewriter.create<arith::SubFOp>(loc, rewriter.create<arith::AddFOp>(loc, term1, term2), term3);
auto cf = rewriter.create<math::SqrtOp>(loc, cfSquare);

// Step 3: Calculate theta_f = 2 * arccos(|cf|)
auto absCf = rewriter.create<math::AbsFOp>(loc, cf);
auto acosCf = rewriter.create<math::AcosOp>(loc, absCf);
auto thetaF = rewriter.create<arith::MulFOp>(loc, two, acosCf);

// Step 4: Calculate alpha_f
auto alphaSum = rewriter.create<arith::AddFOp>(loc, params[0], parentParams[0]);
auto betaDiff = rewriter.create<arith::SubFOp>(loc, parentParams[2], params[2]);
auto sinAlphaSum = rewriter.create<math::SinOp>(loc, alphaSum);
auto cosBetaDiff = rewriter.create<math::CosOp>(loc, betaDiff);

auto term1_alpha = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, s2), sinAlphaSum);
auto term2_alpha = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, s1, s2), cosBetaDiff);
auto numerator_alpha = rewriter.create<arith::SubFOp>(loc, rewriter.create<arith::NegFOp>(loc, term1_alpha), term2_alpha);

auto cosAlphaSum = rewriter.create<math::CosOp>(loc, alphaSum);
auto denominator_alpha = rewriter.create<arith::SubFOp>(loc, rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, c2), cosAlphaSum), term2_alpha);

auto alphaF = rewriter.create<arith::NegFOp>(loc, rewriter.create<math::AtanOp>(loc, rewriter.create<arith::DivFOp>(loc, numerator_alpha, denominator_alpha)));

// Step 5: Calculate beta_f
auto betaSum = rewriter.create<arith::AddFOp>(loc, params[2], parentParams[2]);
auto alphaDiffReversed = rewriter.create<arith::SubFOp>(loc, parentParams[0], params[0]);
auto sinBetaSum = rewriter.create<math::SinOp>(loc, betaSum);
auto cosAlphaDiffReversed = rewriter.create<math::CosOp>(loc, alphaDiffReversed);

auto term1_beta = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, s2), sinBetaSum);
auto term2_beta = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, s1, s2), cosAlphaDiffReversed);
auto numerator_beta = rewriter.create<arith::AddFOp>(loc, rewriter.create<arith::NegFOp>(loc, term1_beta), term2_beta);

auto denominator_beta = denominator_alpha; // Reuse from alpha calculation if applicable
auto betaF = rewriter.create<arith::NegFOp>(loc, rewriter.create<math::AtanOp>(loc, rewriter.create<arith::DivFOp>(loc, numerator_beta, denominator_beta)));

// Step 6: Output angles (phi_f, theta_f, omega_f)
// Assign phi_f = alphaF, theta_f = thetaF, omega_f = betaF as the final values
SmallVector<mlir::Value> combinedAngles = {alphaF, thetaF, betaF};
auto outQubitsTypes = op.getOutQubits().getTypes();
auto outCtrlQubitsTypes = op.getOutCtrlQubits().getTypes();
auto inQubits = op.getInQubits();
auto inCtrlQubits = op.getInCtrlQubits();
auto inCtrlValues = op.getInCtrlValues();
rewriter.replaceOpWithNewOp<CustomOp>(op, outQubitsTypes, outCtrlQubitsTypes, combinedAngles, inQubits, opGateName, nullptr, inCtrlQubits, inCtrlValues);

return success();
}
else {
TypeRange outQubitsTypes = op.getOutQubits().getTypes();
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes();
ValueRange parentInQubits = parentOp.getInQubits();
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits();
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();
auto parentParams = parentOp.getParams();
auto params = op.getParams();
SmallVector<mlir::Value> sumParams;
for (auto [param, parentParam] : llvm::zip(params, parentParams)) {
mlir::Value sumParam =
rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult();
sumParams.push_back(sumParam);
};
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
parentInQubits, opGateName, nullptr,
parentInCtrlQubits, parentInCtrlValues);

op.replaceAllUsesWith(mergeOp);

return success();
}
}
};

Expand Down
Loading