-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added merge rotation patterns for qml.Rot and qml.CRot #1270
base: main
Are you sure you want to change the base?
Changes from all commits
f312d08
840c609
54fb65c
6014868
51ae22d
c323bf0
6810996
d4dd01a
a1d5504
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,13 +21,16 @@ | |
#include "llvm/ADT/StringSet.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "llvm/Support/Errc.h" | ||
#include "mlir/Dialect/Math/IR/Math.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
|
||
using llvm::dbgs; | ||
using namespace mlir; | ||
using namespace catalyst::quantum; | ||
|
||
static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift", | ||
"CRX", "CRY", "CRZ", "ControlledPhaseShift"}; | ||
"CRX", "CRY", "CRZ", "ControlledPhaseShift", | ||
"qml.Rot", "qml.CRot"}; | ||
|
||
namespace { | ||
|
||
|
@@ -49,27 +52,111 @@ | |
return failure(); | ||
} | ||
|
||
TypeRange outQubitsTypes = op.getOutQubits().getTypes(); | ||
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); | ||
ValueRange parentInQubits = parentOp.getInQubits(); | ||
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); | ||
ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); | ||
|
||
auto parentParams = parentOp.getParams(); | ||
auto params = op.getParams(); | ||
SmallVector<mlir::Value> sumParams; | ||
for (auto [param, parentParam] : llvm::zip(params, parentParams)) { | ||
mlir::Value sumParam = | ||
rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult(); | ||
sumParams.push_back(sumParam); | ||
}; | ||
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, | ||
parentInQubits, opGateName, nullptr, | ||
parentInCtrlQubits, parentInCtrlValues); | ||
|
||
op.replaceAllUsesWith(mergeOp); | ||
|
||
return success(); | ||
if (opGateName == "qml.Rot" || opGateName == "qml.CRot") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What happens if the rotation gates are adjointed? Should the merge still happen? (In Catalyst adjointed gates are indicated by a adjoint unit attribute, see for example here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you @paul0403 Please let me know if you did not consider whether the rotation gates are adjointed for regular merging rotations (not for merging non-commutative rotations), or if we do not have adjointed gates for regular merging rotations? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The merge rotation pass applies an rotation gate adjoint canonicalization. The canonicalization simply changes all angles to their negative and removes the adjoint attribute. See #1205 However, looking at the canonicalization pattern, you will find that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you @paul0403 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good insights! This is what I would do as well. Due to how the work is organized (aka adjoint canonicalization happens somewhere else, not here in the merge rotation patterns), in the merge rotation patterns, it suffices to assume that the rotation gates will not carry adjoint attributes when the patterns are hit. Thus the only thing needed here is a check that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you, added it. |
||
|
||
if (op.getAdjoint()) { | ||
LLVM_DEBUG(dbgs() << "Skipping adjointed operation:\n" << op << "\n"); | ||
return failure(); | ||
} | ||
|
||
LLVM_DEBUG(dbgs() << "Applying scalar formula for combined rotation operation:\n" << op << "\n"); | ||
auto params = op.getParams(); | ||
auto parentParams = parentOp.getParams(); | ||
|
||
// Assuming params[0] = alpha1, params[1] = theta1, params[2] = beta1 | ||
// and parentParams[0] = alpha2, parentParams[1] = theta2, parentParams[2] = beta2 | ||
|
||
// Step 1: Calculate c1, c2, s1, s2 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for adding the new merge rotation pattern! The formula is very long, so we appreciate the good work 🥳 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you so much for your help :) If there's anything specific you'd like me to refine or expand on, please let me know. |
||
auto c1 = rewriter.create<math::CosOp>(loc, params[1]); | ||
auto s1 = rewriter.create<math::SinOp>(loc, params[1]); | ||
auto c2 = rewriter.create<math::CosOp>(loc, parentParams[1]); | ||
auto s2 = rewriter.create<math::SinOp>(loc, parentParams[1]); | ||
|
||
// Step 2: Calculate cf | ||
auto c1Squared = rewriter.create<arith::MulFOp>(loc, c1, c1); | ||
auto c2Squared = rewriter.create<arith::MulFOp>(loc, c2, c2); | ||
auto s1Squared = rewriter.create<arith::MulFOp>(loc, s1, s1); | ||
auto s2Squared = rewriter.create<arith::MulFOp>(loc, s2, s2); | ||
auto cosAlphaDiff = rewriter.create<math::CosOp>(loc, rewriter.create<arith::SubFOp>(loc, params[0], parentParams[0])); | ||
|
||
auto term1 = rewriter.create<arith::MulFOp>(loc, c1Squared, c2Squared); | ||
auto term2 = rewriter.create<arith::MulFOp>(loc, s1Squared, s2Squared); | ||
auto product = rewriter.create<arith::MulFOp>(loc, c1, c2); | ||
product = rewriter.create<arith::MulFOp>(loc, product, s1); | ||
product = rewriter.create<arith::MulFOp>(loc, product, s2); | ||
auto two = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64FloatAttr(2.0)); | ||
auto term3 = rewriter.create<arith::MulFOp>(loc, two, rewriter.create<arith::MulFOp>(loc, product, cosAlphaDiff)); | ||
|
||
auto cfSquare = rewriter.create<arith::SubFOp>(loc, rewriter.create<arith::AddFOp>(loc, term1, term2), term3); | ||
auto cf = rewriter.create<math::SqrtOp>(loc, cfSquare); | ||
|
||
// Step 3: Calculate theta_f = 2 * arccos(|cf|) | ||
auto absCf = rewriter.create<math::AbsFOp>(loc, cf); | ||
auto acosCf = rewriter.create<math::AcosOp>(loc, absCf); | ||
auto thetaF = rewriter.create<arith::MulFOp>(loc, two, acosCf); | ||
|
||
// Step 4: Calculate alpha_f | ||
auto alphaSum = rewriter.create<arith::AddFOp>(loc, params[0], parentParams[0]); | ||
auto betaDiff = rewriter.create<arith::SubFOp>(loc, parentParams[2], params[2]); | ||
auto sinAlphaSum = rewriter.create<math::SinOp>(loc, alphaSum); | ||
auto cosBetaDiff = rewriter.create<math::CosOp>(loc, betaDiff); | ||
|
||
auto term1_alpha = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, s2), sinAlphaSum); | ||
auto term2_alpha = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, s1, s2), cosBetaDiff); | ||
auto numerator_alpha = rewriter.create<arith::SubFOp>(loc, rewriter.create<arith::NegFOp>(loc, term1_alpha), term2_alpha); | ||
|
||
auto cosAlphaSum = rewriter.create<math::CosOp>(loc, alphaSum); | ||
auto denominator_alpha = rewriter.create<arith::SubFOp>(loc, rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, c2), cosAlphaSum), term2_alpha); | ||
|
||
auto alphaF = rewriter.create<arith::NegFOp>(loc, rewriter.create<math::AtanOp>(loc, rewriter.create<arith::DivFOp>(loc, numerator_alpha, denominator_alpha))); | ||
|
||
// Step 5: Calculate beta_f | ||
auto betaSum = rewriter.create<arith::AddFOp>(loc, params[2], parentParams[2]); | ||
auto alphaDiffReversed = rewriter.create<arith::SubFOp>(loc, parentParams[0], params[0]); | ||
auto sinBetaSum = rewriter.create<math::SinOp>(loc, betaSum); | ||
auto cosAlphaDiffReversed = rewriter.create<math::CosOp>(loc, alphaDiffReversed); | ||
|
||
auto term1_beta = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, s2), sinBetaSum); | ||
auto term2_beta = rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, s1, s2), cosAlphaDiffReversed); | ||
auto numerator_beta = rewriter.create<arith::AddFOp>(loc, rewriter.create<arith::NegFOp>(loc, term1_beta), term2_beta); | ||
|
||
auto denominator_beta = denominator_alpha; // Reuse from alpha calculation if applicable | ||
auto betaF = rewriter.create<arith::NegFOp>(loc, rewriter.create<math::AtanOp>(loc, rewriter.create<arith::DivFOp>(loc, numerator_beta, denominator_beta))); | ||
|
||
// Step 6: Output angles (phi_f, theta_f, omega_f) | ||
// Assign phi_f = alphaF, theta_f = thetaF, omega_f = betaF as the final values | ||
SmallVector<mlir::Value> combinedAngles = {alphaF, thetaF, betaF}; | ||
auto outQubitsTypes = op.getOutQubits().getTypes(); | ||
auto outCtrlQubitsTypes = op.getOutCtrlQubits().getTypes(); | ||
auto inQubits = op.getInQubits(); | ||
auto inCtrlQubits = op.getInCtrlQubits(); | ||
auto inCtrlValues = op.getInCtrlValues(); | ||
rewriter.replaceOpWithNewOp<CustomOp>(op, outQubitsTypes, outCtrlQubitsTypes, combinedAngles, inQubits, opGateName, nullptr, inCtrlQubits, inCtrlValues); | ||
|
||
return success(); | ||
} | ||
else { | ||
TypeRange outQubitsTypes = op.getOutQubits().getTypes(); | ||
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes(); | ||
ValueRange parentInQubits = parentOp.getInQubits(); | ||
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits(); | ||
ValueRange parentInCtrlValues = parentOp.getInCtrlValues(); | ||
auto parentParams = parentOp.getParams(); | ||
auto params = op.getParams(); | ||
SmallVector<mlir::Value> sumParams; | ||
for (auto [param, parentParam] : llvm::zip(params, parentParams)) { | ||
mlir::Value sumParam = | ||
rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult(); | ||
sumParams.push_back(sumParam); | ||
}; | ||
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams, | ||
parentInQubits, opGateName, nullptr, | ||
parentInCtrlQubits, parentInCtrlValues); | ||
|
||
op.replaceAllUsesWith(mergeOp); | ||
|
||
return success(); | ||
} | ||
} | ||
}; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So testing this new merge rotation pattern is a bit tricky: we know that regardless of whether the merge rotation transformation took effect or not, the circuit will produce the same results. Given that, does this test here actually test for whether the rotation gates are merged? If not, what is the best way to test that the rotation gates are merged, and what is the purpose of these end-to-end circuit execution tests here?
Hint: search through the code base and look for how the existing merge rotation patterns are tested!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @paul0403
I’ve added explicit checks in test_complex_merge_rotation to verify that the merge_rotations transformation reduces the number of rotation gates (Rot and CRot) and preserves the circuit's functionality. By explicitly calling qml.transforms.merge_rotations, we can compare the unoptimized and optimized circuits directly. This allows us to confirm both that the rotation gates are actually merged (fewer gates) and that the results remain the same, addressing the need for both functionality and transformation verification in the test. Please let me know if this method is correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding the test! An additional approach on top of plain functionality check is definitely good to have 💯