Skip to content

Commit

Permalink
[Coroutines] Define ABI objects for each type of lowering
Browse files Browse the repository at this point in the history
* Define an ABI object for Switch, Retcon, and Async lowerings
* Perform initialization of each type of lowering as part of ABI
  initialization.
* Make buildCoroutineFrame and splitCoroutine interfaces to the ABI.
  • Loading branch information
tnowicki committed Sep 19, 2024
1 parent ac66469 commit 162db89
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 71 deletions.
15 changes: 11 additions & 4 deletions llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,26 @@

namespace llvm {

namespace coro {
class BaseABI;
class Shape;
} // namespace coro

struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
const std::function<bool(Instruction &)> MaterializableCallback;
// BaseABITy generates an instance of a coro ABI.
using BaseABITy = std::function<coro::BaseABI *(Function &, coro::Shape &)>;

CoroSplitPass(bool OptimizeFrame = false);
CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
bool OptimizeFrame = false)
: MaterializableCallback(MaterializableCallback),
OptimizeFrame(OptimizeFrame) {}
bool OptimizeFrame = false);

PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
LazyCallGraph &CG, CGSCCUpdateResult &UR);
static bool isRequired() { return true; }

// Generator for an ABI transformer
BaseABITy CreateAndInitABI;

// Would be true if the Optimization level isn't O0.
bool OptimizeFrame;
};
Expand Down
109 changes: 109 additions & 0 deletions llvm/lib/Transforms/Coroutines/ABI.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//===- ABI.h - Coroutine ABI Transformers ---------------------*- C++ -*---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This file declares the pass that analyzes a function for coroutine intrs and
// a transformer class that contains methods for handling different steps of
// coroutine lowering.
//===----------------------------------------------------------------------===//

#ifndef LIB_TRANSFORMS_COROUTINES_ABI_H
#define LIB_TRANSFORMS_COROUTINES_ABI_H

#include "CoroShape.h"
#include "MaterializationUtils.h"
#include "SuspendCrossingInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"

namespace llvm {

class Function;

namespace coro {

// This interface/API is to provide an object oriented way to implement ABI
// functionality. This is intended to replace use of the ABI enum to perform
// ABI operations. The ABIs (e.g. Switch, Async, Retcon{Once}) are the common
// ABIs.

class LLVM_LIBRARY_VISIBILITY BaseABI {
public:
BaseABI(Function &F, Shape &S)
: F(F), Shape(S), IsMaterializable(coro::isTriviallyMaterializable) {}

BaseABI(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMaterializable)
: F(F), Shape(S), IsMaterializable(IsMaterializable) {}

// Initialize the coroutine ABI
virtual void init() = 0;

// Allocate the coroutine frame and do spill/reload as needed.
virtual void buildCoroutineFrame();

// Perform the function splitting according to the ABI.
virtual void splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) = 0;

Function &F;
coro::Shape &Shape;

// Callback used by coro::BaseABI::buildCoroutineFrame for rematerialization.
// It is provided to coro::doMaterializations(..).
std::function<bool(Instruction &I)> IsMaterializable;
};

class LLVM_LIBRARY_VISIBILITY SwitchABI : public BaseABI {
public:
SwitchABI(Function &F, coro::Shape &S) : BaseABI(F, S) {}

SwitchABI(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMaterializable)
: BaseABI(F, S, IsMaterializable) {}

void init() override;

void splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) override;
};

class LLVM_LIBRARY_VISIBILITY AsyncABI : public BaseABI {
public:
AsyncABI(Function &F, coro::Shape &S) : BaseABI(F, S) {}

AsyncABI(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMaterializable)
: BaseABI(F, S, IsMaterializable) {}

void init() override;

void splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) override;
};

class LLVM_LIBRARY_VISIBILITY AnyRetconABI : public BaseABI {
public:
AnyRetconABI(Function &F, coro::Shape &S) : BaseABI(F, S) {}

AnyRetconABI(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMaterializable)
: BaseABI(F, S, IsMaterializable) {}

void init() override;

void splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) override;
};

} // end namespace coro

} // end namespace llvm

#endif // LLVM_TRANSFORMS_COROUTINES_ABI_H
7 changes: 3 additions & 4 deletions llvm/lib/Transforms/Coroutines/CoroFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// the value into the coroutine frame.
//===----------------------------------------------------------------------===//

#include "ABI.h"
#include "CoroInternal.h"
#include "MaterializationUtils.h"
#include "SpillUtils.h"
Expand Down Expand Up @@ -2055,11 +2056,9 @@ void coro::normalizeCoroutine(Function &F, coro::Shape &Shape,
rewritePHIs(F);
}

void coro::buildCoroutineFrame(
Function &F, Shape &Shape,
const std::function<bool(Instruction &)> &MaterializableCallback) {
void coro::BaseABI::buildCoroutineFrame() {
SuspendCrossingInfo Checker(F, Shape.CoroSuspends, Shape.CoroEnds);
doRematerializations(F, Checker, MaterializableCallback);
doRematerializations(F, Checker, IsMaterializable);

const DominatorTree DT(F);
if (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon &&
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Transforms/Coroutines/CoroInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ struct LowererBase {
bool defaultMaterializable(Instruction &V);
void normalizeCoroutine(Function &F, coro::Shape &Shape,
TargetTransformInfo &TTI);
void buildCoroutineFrame(
Function &F, Shape &Shape,
const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
TargetTransformInfo &TTI,
ArrayRef<Value *> Arguments, IRBuilder<> &);
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Transforms/Coroutines/CoroShape.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
invalidateCoroutine(F, CoroFrames);
return;
}
initABI();
cleanCoroutine(CoroFrames, UnusedCoroSaves);
}
};
Expand Down
117 changes: 81 additions & 36 deletions llvm/lib/Transforms/Coroutines/CoroSplit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Coroutines/CoroSplit.h"
#include "ABI.h"
#include "CoroInstr.h"
#include "CoroInternal.h"
#include "llvm/ADT/DenseMap.h"
Expand Down Expand Up @@ -1779,9 +1780,9 @@ CallInst *coro::createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
return TailCall;
}

static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
void coro::AsyncABI::splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
assert(Shape.ABI == coro::ABI::Async);
assert(Clones.empty());
// Reset various things that the optimizer might have decided it
Expand Down Expand Up @@ -1874,9 +1875,9 @@ static void splitAsyncCoroutine(Function &F, coro::Shape &Shape,
}
}

static void splitRetconCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
void coro::AnyRetconABI::splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
assert(Shape.ABI == coro::ABI::Retcon || Shape.ABI == coro::ABI::RetconOnce);
assert(Clones.empty());

Expand Down Expand Up @@ -2044,26 +2045,27 @@ static bool hasSafeElideCaller(Function &F) {
return false;
}

static coro::Shape
splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI, bool OptimizeFrame,
std::function<bool(Instruction &)> MaterializableCallback) {
PrettyStackTraceFunction prettyStackTrace(F);
void coro::SwitchABI::splitCoroutine(Function &F, coro::Shape &Shape,
SmallVectorImpl<Function *> &Clones,
TargetTransformInfo &TTI) {
SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
}

// The suspend-crossing algorithm in buildCoroutineFrame get tripped
// up by uses in unreachable blocks, so remove them as a first pass.
removeUnreachableBlocks(F);
static void doSplitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
coro::BaseABI &ABI, TargetTransformInfo &TTI) {
PrettyStackTraceFunction prettyStackTrace(F);

coro::Shape Shape(F, OptimizeFrame);
if (!Shape.CoroBegin)
return Shape;
auto &Shape = ABI.Shape;
assert(Shape.CoroBegin);

lowerAwaitSuspends(F, Shape);

simplifySuspendPoints(Shape);

normalizeCoroutine(F, Shape, TTI);
buildCoroutineFrame(F, Shape, MaterializableCallback);
ABI.buildCoroutineFrame();
replaceFrameSizeAndAlignment(Shape);

bool isNoSuspendCoroutine = Shape.CoroSuspends.empty();

bool shouldCreateNoAllocVariant = !isNoSuspendCoroutine &&
Expand All @@ -2075,18 +2077,7 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
if (isNoSuspendCoroutine) {
handleNoSuspendCoroutine(Shape);
} else {
switch (Shape.ABI) {
case coro::ABI::Switch:
SwitchCoroutineSplitter::split(F, Shape, Clones, TTI);
break;
case coro::ABI::Async:
splitAsyncCoroutine(F, Shape, Clones, TTI);
break;
case coro::ABI::Retcon:
case coro::ABI::RetconOnce:
splitRetconCoroutine(F, Shape, Clones, TTI);
break;
}
ABI.splitCoroutine(F, Shape, Clones, TTI);
}

// Replace all the swifterror operations in the original function.
Expand All @@ -2107,8 +2098,6 @@ splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,

if (shouldCreateNoAllocVariant)
SwitchCoroutineSplitter::createNoAllocVariant(F, Shape, Clones);

return Shape;
}

static LazyCallGraph::SCC &updateCallGraphAfterCoroutineSplit(
Expand Down Expand Up @@ -2207,8 +2196,53 @@ static void addPrepareFunction(const Module &M,
Fns.push_back(PrepareFn);
}

static coro::BaseABI *CreateNewABI(Function &F, coro::Shape &S) {
switch (S.ABI) {
case coro::ABI::Switch:
return new coro::SwitchABI(F, S);
case coro::ABI::Async:
return new coro::AsyncABI(F, S);
case coro::ABI::Retcon:
return new coro::AnyRetconABI(F, S);
case coro::ABI::RetconOnce:
return new coro::AnyRetconABI(F, S);
}
llvm_unreachable("Unknown ABI");
}

CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
: MaterializableCallback(coro::defaultMaterializable),
: CreateAndInitABI([](Function &F, coro::Shape &S) {
coro::BaseABI *ABI = CreateNewABI(F, S);
ABI->init();
return ABI;
}),
OptimizeFrame(OptimizeFrame) {}

static coro::BaseABI *
CreateNewABIIsMat(Function &F, coro::Shape &S,
std::function<bool(Instruction &)> IsMatCallback) {
switch (S.ABI) {
case coro::ABI::Switch:
return new coro::SwitchABI(F, S, IsMatCallback);
case coro::ABI::Async:
return new coro::AsyncABI(F, S, IsMatCallback);
case coro::ABI::Retcon:
return new coro::AnyRetconABI(F, S, IsMatCallback);
case coro::ABI::RetconOnce:
return new coro::AnyRetconABI(F, S, IsMatCallback);
}
llvm_unreachable("Unknown ABI");
}

// For back compatibility, constructor takes a materializable callback and
// creates a generator for an ABI with a modified materializable callback.
CoroSplitPass::CoroSplitPass(std::function<bool(Instruction &)> IsMatCallback,
bool OptimizeFrame)
: CreateAndInitABI([=](Function &F, coro::Shape &S) {
coro::BaseABI *ABI = CreateNewABIIsMat(F, S, IsMatCallback);
ABI->init();
return ABI;
}),
OptimizeFrame(OptimizeFrame) {}

PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
Expand Down Expand Up @@ -2241,12 +2275,23 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
Function &F = N->getFunction();
LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F.getName()
<< "\n");

// The suspend-crossing algorithm in buildCoroutineFrame gets tripped up
// by unreachable blocks, so remove them as a first pass. Remove the
// unreachable blocks before collecting intrinsics into Shape.
removeUnreachableBlocks(F);

coro::Shape Shape(F, OptimizeFrame);
if (!Shape.CoroBegin)
continue;

F.setSplittedCoroutine();

std::unique_ptr<coro::BaseABI> ABI(CreateAndInitABI(F, Shape));

SmallVector<Function *, 4> Clones;
coro::Shape Shape =
splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
OptimizeFrame, MaterializableCallback);
auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
doSplitCoroutine(F, Clones, *ABI, TTI);
CurrentSCC = &updateCallGraphAfterCoroutineSplit(
*N, Shape, Clones, *CurrentSCC, CG, AM, UR, FAM);

Expand Down
Loading

0 comments on commit 162db89

Please sign in to comment.