Skip to content

Commit

Permalink
Decomposition for trapped ions hw {RX, RY, MS} (#1226)
Browse files Browse the repository at this point in the history
**Context:**
For the trapped ions quantum computer from OQD, the native gate set is
RX, RY, MS.

**Description of the Change:**
Add decomposition pass in MLIR that converts gates to their
decomposition in terms of RX, RY, MS.
  • Loading branch information
rmoyard authored Nov 7, 2024
1 parent 75dc517 commit 24cd67d
Show file tree
Hide file tree
Showing 8 changed files with 378 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlir/include/Quantum/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ std::unique_ptr<mlir::Pass> createRemoveChainedSelfInversePass();
std::unique_ptr<mlir::Pass> createAnnotateFunctionPass();
std::unique_ptr<mlir::Pass> createSplitMultipleTapesPass();
std::unique_ptr<mlir::Pass> createMergeRotationsPass();
std::unique_ptr<mlir::Pass> createIonsDecompositionPass();

} // namespace catalyst
6 changes: 6 additions & 0 deletions mlir/include/Quantum/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ def SplitMultipleTapesPass : Pass<"split-multiple-tapes"> {
let constructor = "catalyst::createSplitMultipleTapesPass()";
}

def IonsDecompositionPass : Pass<"ions-decomposition"> {
let summary = "Decompose the gates to the set {RX, RY, MS}";

let constructor = "catalyst::createIonsDecompositionPass()";
}

// ----- Quantum circuit transformation passes begin ----- //
// For example, automatic compiler peephole opts, etc.

Expand Down
1 change: 1 addition & 0 deletions mlir/include/Quantum/Transforms/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void populateQIRConversionPatterns(mlir::TypeConverter &, mlir::RewritePatternSe
void populateAdjointPatterns(mlir::RewritePatternSet &);
void populateSelfInversePatterns(mlir::RewritePatternSet &);
void populateMergeRotationsPatterns(mlir::RewritePatternSet &);
void populateIonsDecompositionPatterns(mlir::RewritePatternSet &);

} // namespace quantum
} // namespace catalyst
1 change: 1 addition & 0 deletions mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,5 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createScatterLoweringPass);
mlir::registerPass(catalyst::createSplitMultipleTapesPass);
mlir::registerPass(catalyst::createTestPass);
mlir::registerPass(catalyst::createIonsDecompositionPass);
}
2 changes: 2 additions & 0 deletions mlir/lib/Quantum/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ file(GLOB SRC
SplitMultipleTapes.cpp
merge_rotation.cpp
MergeRotationsPatterns.cpp
ions_decompositions.cpp
IonsDecompositionPatterns.cpp
)

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
Expand Down
176 changes: 176 additions & 0 deletions mlir/lib/Quantum/Transforms/IonsDecompositionPatterns.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#define DEBUG_TYPE "ions-decomposition"

#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"

using namespace mlir;
using namespace catalyst::quantum;

constexpr double PI = llvm::numbers::pi;

// This function creates the RX(phi) RY(theta) RX(lambda) decomposition
// Note that for parametric operations theta is a Value, for non parametric
// operation theta is an attribute, hence the use of std::variant.

void oneQubitDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewriter, double phi,
std::variant<mlir::Value, double> theta, double lambda)
{
TypeRange outQubitsTypes = op.getOutQubits().getTypes();

ValueRange inQubits = op.getInQubits();

TypedAttr phiAttr = rewriter.getF64FloatAttr(phi);
mlir::Value phiValue = rewriter.create<arith::ConstantOp>(op.getLoc(), phiAttr);

mlir::Value thetaValue;
if (std::holds_alternative<mlir::Value>(theta)) {
thetaValue = std::get<mlir::Value>(theta);
}
else if (std::holds_alternative<double>(theta)) {
TypedAttr thetaAttr = rewriter.getF64FloatAttr(std::get<double>(theta));
thetaValue = rewriter.create<arith::ConstantOp>(op.getLoc(), thetaAttr);
}
TypedAttr lambdaAttr = rewriter.getF64FloatAttr(lambda);
mlir::Value lambdaValue = rewriter.create<arith::ConstantOp>(op.getLoc(), lambdaAttr);

auto rxPhi = rewriter.create<CustomOp>(op.getLoc(), outQubitsTypes, ValueRange{}, phiValue,
inQubits, "RX", nullptr, ValueRange{}, ValueRange{});
auto ryTheta = rewriter.create<CustomOp>(op.getLoc(), outQubitsTypes, ValueRange{}, thetaValue,
rxPhi.getOutQubits(), "RY", nullptr,
rxPhi.getInCtrlQubits(), rxPhi.getInCtrlValues());
auto rxLambda = rewriter.create<CustomOp>(op.getLoc(), outQubitsTypes, ValueRange{},
lambdaValue, ryTheta.getOutQubits(), "RX", nullptr,
ValueRange{}, ValueRange{});
op.replaceAllUsesWith(rxLambda);
}

void tDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewriter)
{
if (op.getAdjoint()) {
oneQubitDecomp(op, rewriter, -PI / 2, -PI / 4, PI / 2);
}
else {
oneQubitDecomp(op, rewriter, -PI / 2, PI / 4, PI / 2);
}
}

void sDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewriter)
{
if (op.getAdjoint()) {
oneQubitDecomp(op, rewriter, -PI / 2, -PI / 2, PI / 2);
}
else {
oneQubitDecomp(op, rewriter, -PI / 2, PI / 2, PI / 2);
}
}

void zDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewriter)
{
oneQubitDecomp(op, rewriter, -PI / 2, PI, PI / 2);
}

void hDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewriter)
{
oneQubitDecomp(op, rewriter, 0.0, PI / 2, PI);
}

void psDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewriter)
{
oneQubitDecomp(op, rewriter, -PI / 2, op.getParams().front(), PI / 2);
}

void rzDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewriter)
{
oneQubitDecomp(op, rewriter, -PI / 2, op.getParams().front(), PI / 2);
}

void cnotDecomp(catalyst::quantum::CustomOp op, mlir::PatternRewriter &rewriter)
{
TypeRange outQubitsTypes = op.getOutQubits().getTypes();

mlir::Value inQubit0 = op.getInQubits().front();
mlir::Value inQubit1 = op.getInQubits().back();

TypedAttr piOver2Attr = rewriter.getF64FloatAttr(PI / 2);
mlir::Value piOver2 = rewriter.create<arith::ConstantOp>(op.getLoc(), piOver2Attr);
auto ryPiOver2 =
rewriter.create<CustomOp>(op.getLoc(), outQubitsTypes.front(), ValueRange{}, piOver2,
inQubit0, "RY", nullptr, ValueRange{}, ValueRange{});
SmallVector<mlir::Value> qubitsAfterRy;
qubitsAfterRy.push_back(ryPiOver2.getOutQubits().front());
qubitsAfterRy.push_back(inQubit1);
auto ms = rewriter.create<CustomOp>(op.getLoc(), outQubitsTypes, ValueRange{}, piOver2,
qubitsAfterRy, "MS", nullptr, ValueRange{}, ValueRange{});
mlir::Value qubit0AfterMs = ms.getOutQubits().front();
mlir::Value qubit1AfterMs = ms.getOutQubits().back();

TypedAttr minusPiOver2Attr = rewriter.getF64FloatAttr(-PI / 2);
mlir::Value minusPiOver2 = rewriter.create<arith::ConstantOp>(op.getLoc(), minusPiOver2Attr);
auto rxMinusPiOver2 =
rewriter.create<CustomOp>(op.getLoc(), outQubitsTypes.front(), ValueRange{}, minusPiOver2,
qubit0AfterMs, "RX", nullptr, ValueRange{}, ValueRange{});
auto firstRyMinusPiOver2 =
rewriter.create<CustomOp>(op.getLoc(), outQubitsTypes.front(), ValueRange{}, minusPiOver2,
qubit1AfterMs, "RY", nullptr, ValueRange{}, ValueRange{});

mlir::Value qubit0AfterRY = rxMinusPiOver2.getOutQubits().front();
auto secondRyMinusPiOver2 =
rewriter.create<CustomOp>(op.getLoc(), outQubitsTypes.front(), ValueRange{}, minusPiOver2,
qubit0AfterRY, "RY", nullptr, ValueRange{}, ValueRange{});

SmallVector<mlir::Value> qubitsEnd;
qubitsEnd.push_back(firstRyMinusPiOver2.getOutQubits().front());
qubitsEnd.push_back(secondRyMinusPiOver2.getOutQubits().front());
op.replaceAllUsesWith(qubitsEnd);
}

std::map<std::string, std::function<void(catalyst::quantum::CustomOp, mlir::PatternRewriter &)>>
funcMap = {{"T", &tDecomp}, {"S", &sDecomp}, {"Z", &zDecomp},
{"Hadamard", &hDecomp}, {"RZ", &rzDecomp}, {"PhaseShift", &psDecomp},
{"CNOT", &cnotDecomp}};

namespace {

struct IonsDecompositionRewritePattern : public mlir::OpRewritePattern<CustomOp> {
using mlir::OpRewritePattern<CustomOp>::OpRewritePattern;

mlir::LogicalResult matchAndRewrite(CustomOp op, mlir::PatternRewriter &rewriter) const override
{
auto it = funcMap.find(op.getGateName().str());
if (it != funcMap.end()) {
auto decompFunc = it->second;
decompFunc(op, rewriter);
return success();
}
else {
return failure();
}
}
};
} // namespace

namespace catalyst {
namespace quantum {

void populateIonsDecompositionPatterns(RewritePatternSet &patterns)
{
patterns.add<IonsDecompositionRewritePattern>(patterns.getContext(), 1);
}

} // namespace quantum
} // namespace catalyst
67 changes: 67 additions & 0 deletions mlir/lib/Quantum/Transforms/ions_decompositions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#define DEBUG_TYPE "ions-decomposition"

#include "Catalyst/IR/CatalystDialect.h"
#include "Quantum/IR/QuantumOps.h"
#include "Quantum/Transforms/Patterns.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

using namespace llvm;
using namespace mlir;
using namespace catalyst::quantum;

namespace catalyst {
namespace quantum {

#define GEN_PASS_DEF_IONSDECOMPOSITIONPASS
#define GEN_PASS_DECL_IONSDECOMPOSITIONPASS
#include "Quantum/Transforms/Passes.h.inc"

struct IonsDecompositionPass : impl::IonsDecompositionPassBase<IonsDecompositionPass> {
using IonsDecompositionPassBase::IonsDecompositionPassBase;

void runOnOperation() final
{
LLVM_DEBUG(dbgs() << "ions decomposition pass"
<< "\n");

Operation *module = getOperation();

RewritePatternSet patternsCanonicalization(&getContext());
catalyst::quantum::CustomOp::getCanonicalizationPatterns(patternsCanonicalization,
&getContext());
if (failed(applyPatternsAndFoldGreedily(module, std::move(patternsCanonicalization)))) {
return signalPassFailure();
}
RewritePatternSet patterns(&getContext());
populateIonsDecompositionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
return signalPassFailure();
}
}
};

} // namespace quantum

std::unique_ptr<Pass> createIonsDecompositionPass()
{
return std::make_unique<quantum::IonsDecompositionPass>();
}

} // namespace catalyst
Loading

0 comments on commit 24cd67d

Please sign in to comment.