Skip to content
This repository has been archived by the owner on Jan 20, 2024. It is now read-only.

Commit

Permalink
[OpenMP] Generate code for target worksharing loop
Browse files Browse the repository at this point in the history
  • Loading branch information
Dominik Adamski committed Sep 5, 2023
1 parent a0703ec commit 470dde0
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 76 deletions.
16 changes: 16 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,22 @@ class OpenMPIRBuilder {
Type *LlvmPtrTy, Constant *Addr);

private:
/// Modifies the canonical loop to be a statically-scheduled workshare loop
/// which is executed on the device
///
/// This takes a \p LoopInfo representing a canonical loop, such as the one
/// created by \p createCanonicalLoop and emits additional instructions to
/// turn it into a workshare loop. In particular, it calls to an OpenMP
/// runtime function in the preheader to call OpenMP device rtl function
/// which handles worksharing of loop body interations.
///
/// \param DL Debug location for instructions added for the
/// workshare-loop construct itself.
/// \param CLI A descriptor of the canonical loop to workshare.
///
/// \returns Point where to insert code after the workshare construct.
InsertPointTy applyWorkshareLoopDevice(DebugLoc DL, CanonicalLoopInfo *CLI);

/// Modifies the canonical loop to be a statically-scheduled workshare loop.
///
/// This takes a \p LoopInfo representing a canonical loop, such as the one
Expand Down
310 changes: 234 additions & 76 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,29 +523,7 @@ void OpenMPIRBuilder::finalizeFunction(Function *Fn) {
Function *OuterFn = OI.getFunction();
CodeExtractorAnalysisCache CEAC(*OuterFn);
Function *OutlinedFn = nullptr;
if (Config.isTargetDevice()) {
// Use extractor which does not aggregate args
CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
/* AggregateArgs */ false,
/* BlockFrequencyInfo */ nullptr,
/* BranchProbabilityInfo */ nullptr,
/* AssumptionCache */ nullptr,
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* AllocaBlock*/ OI.OuterAllocaBB,
/* Suffix */ ".omp_par");

LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
<< " Exit: " << OI.ExitBB->getName() << "\n");
assert(Extractor.isEligible() &&
"Expected OpenMP outlining to be possible!");

for (auto *V : OI.ExcludeArgsFromAggregate)
Extractor.excludeArgFromAggregate(V);

OutlinedFn = Extractor.extractCodeRegion(CEAC);
} else {
CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
/* AggregateArgs */ true,
/* BlockFrequencyInfo */ nullptr,
Expand All @@ -566,7 +544,6 @@ void OpenMPIRBuilder::finalizeFunction(Function *Fn) {
Extractor.excludeArgFromAggregate(V);

OutlinedFn = Extractor.extractCodeRegion(CEAC);
}
LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n");
assert(OutlinedFn->getReturnType()->isVoidTy() &&
Expand Down Expand Up @@ -2584,11 +2561,244 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
}

/// Determine which blocks in \p BBs are reachable from outside and remove the
/// ones that are not reachable from the function.
static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
for (Use &U : BB->uses()) {
auto *UseInst = dyn_cast<Instruction>(U.getUser());
if (!UseInst)
continue;
if (BBsToErase.count(UseInst->getParent()))
continue;
return true;
}
return false;
};

while (true) {
bool Changed = false;
for (BasicBlock *BB : make_early_inc_range(BBsToErase)) {
if (HasRemainingUses(BB)) {
BBsToErase.erase(BB);
Changed = true;
}
}
if (!Changed)
break;
}

SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
DeleteDeadBlocks(BBVec);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyWorkshareLoopDevice(DebugLoc DL, CanonicalLoopInfo *CLI) {
uint32_t SrcLocStrSize;
Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

OutlineInfo OI;
OI.OuterAllocaBB = CLI->getPreheader();
Function *OuterFn = CLI->getPreheader()->getParent();

// Instructions which need to be deleted at the end of code generation
SmallVector<Instruction *, 4> ToBeDeleted;

// Put additional allocas generated by extractor in loop preheader
OI.OuterAllocaBB = CLI->getPreheader();

// Mark the body loop as region which needs to be extracted
OI.EntryBB = CLI->getBody();
OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
"omp.prelatch", true);

// Prepare loop body for extraction
Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});

// Insert new loop counter variable which will be used only in loop
// body.
AllocaInst *newLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
Instruction *newLoopCntLoad =
Builder.CreateLoad(CLI->getIndVarType(), newLoopCnt);
// New loop counter instructions are redundant in the loop preheader when
// code generation for workshare loop is finshed. That's why mark them as
// ready for deletion.
ToBeDeleted.push_back(newLoopCntLoad);
ToBeDeleted.push_back(newLoopCnt);

// Analyse loop body region. Find all input variables which are used inside
// loop body region.
SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
SmallVector<BasicBlock *, 32> Blocks;
OI.collectBlocks(ParallelRegionBlockSet, Blocks);
SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
ParallelRegionBlockSet.end());
;

CodeExtractorAnalysisCache CEAC(*OuterFn);
CodeExtractor Extractor(Blocks,
/* DominatorTree */ nullptr,
/* AggregateArgs */ true,
/* BlockFrequencyInfo */ nullptr,
/* BranchProbabilityInfo */ nullptr,
/* AssumptionCache */ nullptr,
/* AllowVarArgs */ true,
/* AllowAlloca */ true,
/* AllocationBlock */ CLI->getPreheader(),
/* Suffix */ ".omp_wsloop");

BasicBlock *CommonExit = nullptr;
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;

// Find allocas outside the loop body region which are used inside loop
// body
Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);

// Sink allocas used only inside loop body
// original code:
// %item_used_in_loop_body = alloca i32
// ;no more instructions which uses item_used_in_loop_body
// loopbody:
// use(%item_used_in_loop_body)
//
// After sinking:
// loopbody:
// %item_used_in_loop_body_moved_alloca = alloca i32
// use(%item_used_in_loop_body_moved_alloca)
//
// TODO: OMPIRBuilder should not be responsible for sinking allocas
// which are used only inside loop body region.
for (AllocaInst *AllocaItem : CEAC.getAllocas()) {
bool ReadyToMove = true;
for (User *AllocaUse : AllocaItem->users()) {
Instruction *Inst;
if ((Inst = dyn_cast<LoadInst>(AllocaUse)) &&
ParallelRegionBlockSet.count(Inst->getParent()))
continue;
if ((Inst = dyn_cast<StoreInst>(AllocaUse)) &&
ParallelRegionBlockSet.count(Inst->getParent()))
continue;
ReadyToMove = false;
break;
}
if (ReadyToMove) {
Builder.restoreIP({CLI->getBody(), CLI->getBody()->begin()});
AllocaInst *NewAlloca =
Builder.CreateAlloca(CLI->getIndVarType(), 0, "moved_alloca");
std::vector<User *> Users(AllocaItem->user_begin(),
AllocaItem->user_end());
for (User *use : Users) {
use->replaceUsesOfWith(AllocaItem, NewAlloca);
}
ToBeDeleted.push_back(AllocaItem);
}
}
// We need to model loop body region as the function f(cnt, loop_arg).
// That's why we replace loop induction variable by the new counter
// which will be one of loop body function argument
std::vector<User *> Users(CLI->getIndVar()->user_begin(),
CLI->getIndVar()->user_end());
for (User *use : Users) {
if (Instruction *inst = dyn_cast<Instruction>(use)) {
if (ParallelRegionBlockSet.count(inst->getParent())) {
inst->replaceUsesOfWith(CLI->getIndVar(), newLoopCntLoad);
}
}
}
Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
for (Value *Input : Inputs) {
// Make sure that loop counter variable is not merged into loop body
// function argument structure and it is passed as separate variable
if (Input == newLoopCntLoad)
OI.ExcludeArgsFromAggregate.push_back(Input);
}

// PostOutline CB is invoked when loop body function is outlined and
// loop body is replaced by call to outlined function. We need to add
// call to OpenMP device rtl inside loop preheader. OpenMP device rtl
// function will handle loop control logic.
//
OI.PostOutlineCB = [=](Function &OutlinedFn) {
BasicBlock *Preheader = CLI->getPreheader();
Value *TripCount = CLI->getTripCount();

// After loop body outling, the loop body contains only set up
// of loop body argument structure and the call to the outlined
// loop body function. Firstly, we need to move setup of loop body args
// into loop preheader.
Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
CLI->getBody()->begin(),
std::prev(CLI->getBody()->end()));

// The next step is to remove the whole loop. We do not it need anymore.
// That's why make an unconditional branch from loop preheader to loop
// exit block
Builder.restoreIP({Preheader, Preheader->end()});
Preheader->getTerminator()->eraseFromParent();
Builder.CreateBr(CLI->getExit());

// Delete dead loop blocks
OutlineInfo CleanUpInfo;
SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
CleanUpInfo.EntryBB = CLI->getHeader();
CleanUpInfo.ExitBB = CLI->getExit();
CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
DeleteDeadBlocks(BlocksToBeRemoved);

// Find the instruction which corresponds to loop body argument structure
// and remove the call to loop body function instruction.
Value *LoopBodyArg;
for (auto instIt = Preheader->begin(); instIt != Preheader->end();
++instIt) {
if (CallInst *CallInstruction = dyn_cast<CallInst>(instIt)) {
if (CallInstruction->getCalledFunction() == &OutlinedFn) {
LoopBodyArg = CallInstruction->getArgOperand(1);
CallInstruction->eraseFromParent();
break;
}
}
}

// Create call to the OpenMP runtime function.
// TODO: Provide mechanism of choosing the right OpenMP device function
FunctionCallee RTLFn =
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_for_static_loop_4);
FunctionCallee RTLNumThreads =
getOrCreateRuntimeFunctionPtr(OMPRTL_omp_get_num_threads);
Builder.restoreIP({Preheader, std::prev(Preheader->end())});
Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
Value *ForStaticLoopCallArgs[] = {
/*identifier*/ Ident,
/*loop body func*/ Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr),
/*loop body args*/ LoopBodyArg,
/*num of iters*/ TripCount,
/*num of threads*/ NumThreads,
/*block chunk*/ Builder.getInt32(1)};

SmallVector<Value *, 8> RealArgs;
RealArgs.append(std::begin(ForStaticLoopCallArgs),
std::end(ForStaticLoopCallArgs));

Builder.CreateCall(RTLFn, RealArgs);

for (auto &ToBeDeletedItem : ToBeDeleted)
ToBeDeletedItem->eraseFromParent();
CLI->invalidate();
};
addOutlineInfo(std::move(OI));
return CLI->getAfterIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
bool HasNonmonotonicModifier, bool HasOrderedClause) {
if (Config.isTargetDevice())
return applyWorkshareLoopDevice(DL, CLI);
OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
HasNonmonotonicModifier, HasOrderedClause);
Expand Down Expand Up @@ -2817,37 +3027,6 @@ static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
redirectTo(Pred, NewTarget, DL);
}

/// Determine which blocks in \p BBs are reachable from outside and remove the
/// ones that are not reachable from the function.
static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
for (Use &U : BB->uses()) {
auto *UseInst = dyn_cast<Instruction>(U.getUser());
if (!UseInst)
continue;
if (BBsToErase.count(UseInst->getParent()))
continue;
return true;
}
return false;
};

while (true) {
bool Changed = false;
for (BasicBlock *BB : make_early_inc_range(BBsToErase)) {
if (HasRemainingUses(BB)) {
BBsToErase.erase(BB);
Changed = true;
}
}
if (!Changed)
break;
}

SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
DeleteDeadBlocks(BBVec);
}

CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
Expand Down Expand Up @@ -4523,20 +4702,6 @@ FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
return getOrCreateRuntimeFunction(M, Name);
}

// Copy input from pointer or i64 to the expected argument type.
static Value *copyInput(IRBuilderBase &Builder, unsigned AddrSpace,
Value *Input, Argument &Arg) {
auto Addr = Builder.CreateAlloca(Arg.getType()->isPointerTy()
? Arg.getType()
: Type::getInt64Ty(Builder.getContext()),
AddrSpace);
auto AddrAscast =
Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
Builder.CreateStore(&Arg, AddrAscast);
auto Copy = Builder.CreateLoad(Arg.getType(), AddrAscast);

return Copy;
}

static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
Type *Int8PtrTy, Module &M) {
Expand Down Expand Up @@ -4636,15 +4801,8 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
Value *Input = std::get<0>(InArg);
Argument &Arg = std::get<1>(InArg);

Value *InputCopy =
OMPBuilder.Config.isTargetDevice()
? copyInput(Builder,
OMPBuilder.M.getDataLayout().getAllocaAddrSpace(),
Input, Arg)
: &Arg;

Value *InputCopy = &Arg;
// Collect all the instructions
assert(InputCopy->getType()->isPointerTy() && "Not Pointer Type");
for (User *User : make_early_inc_range(Input->users()))
if (auto Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Expand Down

0 comments on commit 470dde0

Please sign in to comment.