From 470dde01b9bef37f8daea3240edd073eb9d0326b Mon Sep 17 00:00:00 2001 From: Dominik Adamski Date: Tue, 5 Sep 2023 05:29:51 -0400 Subject: [PATCH] [OpenMP] Generate code for target worksharing loop --- .../llvm/Frontend/OpenMP/OMPIRBuilder.h | 16 + llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 310 +++++++++++++----- 2 files changed, 250 insertions(+), 76 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 6f3c9d528ee0..56b29fd6077a 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -874,6 +874,22 @@ class OpenMPIRBuilder { Type *LlvmPtrTy, Constant *Addr); private: + /// Modifies the canonical loop to be a statically-scheduled workshare loop + /// which is executed on the device. + /// + /// This takes a \p LoopInfo representing a canonical loop, such as the one + /// created by \p createCanonicalLoop and emits additional instructions to + /// turn it into a workshare loop. In particular, it emits a call in the + /// preheader to the OpenMP device runtime library (RTL) function + /// which handles worksharing of loop body iterations. + /// + /// \param DL Debug location for instructions added for the + /// workshare-loop construct itself. + /// \param CLI A descriptor of the canonical loop to workshare. + /// + /// \returns Point where to insert code after the workshare construct. + InsertPointTy applyWorkshareLoopDevice(DebugLoc DL, CanonicalLoopInfo *CLI); + /// Modifies the canonical loop to be a statically-scheduled workshare loop. 
/// /// This takes a \p LoopInfo representing a canonical loop, such as the one diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 12d0a7c7eef4..c4f35e379fa4 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -523,29 +523,7 @@ void OpenMPIRBuilder::finalizeFunction(Function *Fn) { Function *OuterFn = OI.getFunction(); CodeExtractorAnalysisCache CEAC(*OuterFn); Function *OutlinedFn = nullptr; - if (Config.isTargetDevice()) { // Use extractor which does not aggregate args - CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr, - /* AggregateArgs */ false, - /* BlockFrequencyInfo */ nullptr, - /* BranchProbabilityInfo */ nullptr, - /* AssumptionCache */ nullptr, - /* AllowVarArgs */ true, - /* AllowAlloca */ true, - /* AllocaBlock*/ OI.OuterAllocaBB, - /* Suffix */ ".omp_par"); - - LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n"); - LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName() - << " Exit: " << OI.ExitBB->getName() << "\n"); - assert(Extractor.isEligible() && - "Expected OpenMP outlining to be possible!"); - - for (auto *V : OI.ExcludeArgsFromAggregate) - Extractor.excludeArgFromAggregate(V); - - OutlinedFn = Extractor.extractCodeRegion(CEAC); - } else { CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr, /* AggregateArgs */ true, /* BlockFrequencyInfo */ nullptr, @@ -566,7 +544,6 @@ void OpenMPIRBuilder::finalizeFunction(Function *Fn) { Extractor.excludeArgFromAggregate(V); OutlinedFn = Extractor.extractCodeRegion(CEAC); - } LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n"); LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n"); assert(OutlinedFn->getReturnType()->isVoidTy() && @@ -2584,11 +2561,244 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop( return {DispatchAfter, DispatchAfter->getFirstInsertionPt()}; } +/// Determine which blocks in \p BBs are reachable 
from outside and remove the +/// ones that are not reachable from the function. +static void removeUnusedBlocksFromParent(ArrayRef BBs) { + SmallPtrSet BBsToErase{BBs.begin(), BBs.end()}; + auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) { + for (Use &U : BB->uses()) { + auto *UseInst = dyn_cast(U.getUser()); + if (!UseInst) + continue; + if (BBsToErase.count(UseInst->getParent())) + continue; + return true; + } + return false; + }; + + while (true) { + bool Changed = false; + for (BasicBlock *BB : make_early_inc_range(BBsToErase)) { + if (HasRemainingUses(BB)) { + BBsToErase.erase(BB); + Changed = true; + } + } + if (!Changed) + break; + } + + SmallVector BBVec(BBsToErase.begin(), BBsToErase.end()); + DeleteDeadBlocks(BBVec); +} + +OpenMPIRBuilder::InsertPointTy +OpenMPIRBuilder::applyWorkshareLoopDevice(DebugLoc DL, CanonicalLoopInfo *CLI) { + uint32_t SrcLocStrSize; + Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize); + Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize); + + OutlineInfo OI; + OI.OuterAllocaBB = CLI->getPreheader(); + Function *OuterFn = CLI->getPreheader()->getParent(); + + // Instructions which need to be deleted at the end of code generation + SmallVector ToBeDeleted; + + // Put additional allocas generated by extractor in loop preheader + OI.OuterAllocaBB = CLI->getPreheader(); + + // Mark the body loop as region which needs to be extracted + OI.EntryBB = CLI->getBody(); + OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(), + "omp.prelatch", true); + + // Prepare loop body for extraction + Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()}); + + // Insert new loop counter variable which will be used only in loop + // body. 
+ AllocaInst *newLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, ""); + Instruction *newLoopCntLoad = + Builder.CreateLoad(CLI->getIndVarType(), newLoopCnt); + // New loop counter instructions are redundant in the loop preheader when + // code generation for workshare loop is finished. That's why we mark them as + // ready for deletion. + ToBeDeleted.push_back(newLoopCntLoad); + ToBeDeleted.push_back(newLoopCnt); + + // Analyse loop body region. Find all input variables which are used inside + // loop body region. + SmallPtrSet ParallelRegionBlockSet; + SmallVector Blocks; + OI.collectBlocks(ParallelRegionBlockSet, Blocks); + SmallVector BlocksT(ParallelRegionBlockSet.begin(), + ParallelRegionBlockSet.end()); + ; + + CodeExtractorAnalysisCache CEAC(*OuterFn); + CodeExtractor Extractor(Blocks, + /* DominatorTree */ nullptr, + /* AggregateArgs */ true, + /* BlockFrequencyInfo */ nullptr, + /* BranchProbabilityInfo */ nullptr, + /* AssumptionCache */ nullptr, + /* AllowVarArgs */ true, + /* AllowAlloca */ true, + /* AllocationBlock */ CLI->getPreheader(), + /* Suffix */ ".omp_wsloop"); + + BasicBlock *CommonExit = nullptr; + SetVector Inputs, Outputs, SinkingCands, HoistingCands; + + // Find allocas outside the loop body region which are used inside loop + // body + Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); + + // Sink allocas used only inside loop body + // original code: + // %item_used_in_loop_body = alloca i32 + // ;no more instructions which use item_used_in_loop_body + // loopbody: + // use(%item_used_in_loop_body) + // + // After sinking: + // loopbody: + // %item_used_in_loop_body_moved_alloca = alloca i32 + // use(%item_used_in_loop_body_moved_alloca) + // + // TODO: OMPIRBuilder should not be responsible for sinking allocas + // which are used only inside loop body region. NOTE(review): the replacement alloca created in the sinking loop below uses CLI->getIndVarType() rather than the sunk alloca's own allocated type -- confirm this is intended. 
+ for (AllocaInst *AllocaItem : CEAC.getAllocas()) { + bool ReadyToMove = true; + for (User *AllocaUse : AllocaItem->users()) { + Instruction *Inst; + if ((Inst = dyn_cast(AllocaUse)) && + ParallelRegionBlockSet.count(Inst->getParent())) + continue; + if ((Inst = dyn_cast(AllocaUse)) && + ParallelRegionBlockSet.count(Inst->getParent())) + continue; + ReadyToMove = false; + break; + } + if (ReadyToMove) { + Builder.restoreIP({CLI->getBody(), CLI->getBody()->begin()}); + AllocaInst *NewAlloca = + Builder.CreateAlloca(CLI->getIndVarType(), 0, "moved_alloca"); + std::vector Users(AllocaItem->user_begin(), + AllocaItem->user_end()); + for (User *use : Users) { + use->replaceUsesOfWith(AllocaItem, NewAlloca); + } + ToBeDeleted.push_back(AllocaItem); + } + } + // We need to model loop body region as the function f(cnt, loop_arg). + // That's why we replace loop induction variable by the new counter + // which will be one of the loop body function's arguments + std::vector Users(CLI->getIndVar()->user_begin(), + CLI->getIndVar()->user_end()); + for (User *use : Users) { + if (Instruction *inst = dyn_cast(use)) { + if (ParallelRegionBlockSet.count(inst->getParent())) { + inst->replaceUsesOfWith(CLI->getIndVar(), newLoopCntLoad); + } + } + } + Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands); + for (Value *Input : Inputs) { + // Make sure that loop counter variable is not merged into loop body + // function argument structure and it is passed as separate variable + if (Input == newLoopCntLoad) + OI.ExcludeArgsFromAggregate.push_back(Input); + } + + // PostOutline CB is invoked when loop body function is outlined and + // loop body is replaced by call to outlined function. We need to add + // call to OpenMP device rtl inside loop preheader. OpenMP device rtl + // function will handle loop control logic. 
+ // + OI.PostOutlineCB = [=](Function &OutlinedFn) { + BasicBlock *Preheader = CLI->getPreheader(); + Value *TripCount = CLI->getTripCount(); + + // After loop body outlining, the loop body contains only the setup + // of loop body argument structure and the call to the outlined + // loop body function. Firstly, we need to move setup of loop body args + // into loop preheader. + Preheader->splice(std::prev(Preheader->end()), CLI->getBody(), + CLI->getBody()->begin(), + std::prev(CLI->getBody()->end())); + + // The next step is to remove the whole loop. We do not need it anymore. + // That's why we make an unconditional branch from loop preheader to loop + // exit block + Builder.restoreIP({Preheader, Preheader->end()}); + Preheader->getTerminator()->eraseFromParent(); + Builder.CreateBr(CLI->getExit()); + + // Delete dead loop blocks + OutlineInfo CleanUpInfo; + SmallPtrSet RegionBlockSet; + SmallVector BlocksToBeRemoved; + CleanUpInfo.EntryBB = CLI->getHeader(); + CleanUpInfo.ExitBB = CLI->getExit(); + CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved); + DeleteDeadBlocks(BlocksToBeRemoved); + + // Find the instruction which corresponds to loop body argument structure + // and remove the call to loop body function instruction. NOTE(review): LoopBodyArg has no initializer and is assigned only when a call to OutlinedFn is found in the preheader -- consider initializing it to nullptr and asserting it was found. + Value *LoopBodyArg; + for (auto instIt = Preheader->begin(); instIt != Preheader->end(); + ++instIt) { + if (CallInst *CallInstruction = dyn_cast(instIt)) { + if (CallInstruction->getCalledFunction() == &OutlinedFn) { + LoopBodyArg = CallInstruction->getArgOperand(1); + CallInstruction->eraseFromParent(); + break; + } + } + } + + // Create call to the OpenMP runtime function. 
+ // TODO: Provide mechanism of choosing the right OpenMP device function + FunctionCallee RTLFn = + getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_for_static_loop_4); + FunctionCallee RTLNumThreads = + getOrCreateRuntimeFunctionPtr(OMPRTL_omp_get_num_threads); + Builder.restoreIP({Preheader, std::prev(Preheader->end())}); + Value *NumThreads = Builder.CreateCall(RTLNumThreads, {}); + Value *ForStaticLoopCallArgs[] = { + /*identifier*/ Ident, + /*loop body func*/ Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr), + /*loop body args*/ LoopBodyArg, + /*num of iters*/ TripCount, + /*num of threads*/ NumThreads, + /*block chunk*/ Builder.getInt32(1)}; + + SmallVector RealArgs; + RealArgs.append(std::begin(ForStaticLoopCallArgs), + std::end(ForStaticLoopCallArgs)); + + Builder.CreateCall(RTLFn, RealArgs); + + for (auto &ToBeDeletedItem : ToBeDeleted) + ToBeDeletedItem->eraseFromParent(); + CLI->invalidate(); + }; + addOutlineInfo(std::move(OI)); + return CLI->getAfterIP(); +} + OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop( DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP, bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind, llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier, bool HasNonmonotonicModifier, bool HasOrderedClause) { + if (Config.isTargetDevice()) + return applyWorkshareLoopDevice(DL, CLI); OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType( SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier, HasNonmonotonicModifier, HasOrderedClause); @@ -2817,37 +3027,6 @@ static void redirectAllPredecessorsTo(BasicBlock *OldTarget, redirectTo(Pred, NewTarget, DL); } -/// Determine which blocks in \p BBs are reachable from outside and remove the -/// ones that are not reachable from the function. 
-static void removeUnusedBlocksFromParent(ArrayRef BBs) { - SmallPtrSet BBsToErase{BBs.begin(), BBs.end()}; - auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) { - for (Use &U : BB->uses()) { - auto *UseInst = dyn_cast(U.getUser()); - if (!UseInst) - continue; - if (BBsToErase.count(UseInst->getParent())) - continue; - return true; - } - return false; - }; - - while (true) { - bool Changed = false; - for (BasicBlock *BB : make_early_inc_range(BBsToErase)) { - if (HasRemainingUses(BB)) { - BBsToErase.erase(BB); - Changed = true; - } - } - if (!Changed) - break; - } - - SmallVector BBVec(BBsToErase.begin(), BBsToErase.end()); - DeleteDeadBlocks(BBVec); -} CanonicalLoopInfo * OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef Loops, @@ -4523,20 +4702,6 @@ FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize, return getOrCreateRuntimeFunction(M, Name); } -// Copy input from pointer or i64 to the expected argument type. -static Value *copyInput(IRBuilderBase &Builder, unsigned AddrSpace, - Value *Input, Argument &Arg) { - auto Addr = Builder.CreateAlloca(Arg.getType()->isPointerTy() - ? Arg.getType() - : Type::getInt64Ty(Builder.getContext()), - AddrSpace); - auto AddrAscast = - Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType()); - Builder.CreateStore(&Arg, AddrAscast); - auto Copy = Builder.CreateLoad(Arg.getType(), AddrAscast); - - return Copy; -} static void emitUsed(StringRef Name, std::vector &List, Type *Int8PtrTy, Module &M) { @@ -4636,15 +4801,8 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, Value *Input = std::get<0>(InArg); Argument &Arg = std::get<1>(InArg); - Value *InputCopy = - OMPBuilder.Config.isTargetDevice() - ? 
copyInput(Builder, - OMPBuilder.M.getDataLayout().getAllocaAddrSpace(), - Input, Arg) - : &Arg; - + Value *InputCopy = &Arg; // Collect all the instructions - assert(InputCopy->getType()->isPointerTy() && "Not Pointer Type"); for (User *User : make_early_inc_range(Input->users())) if (auto Instr = dyn_cast(User)) if (Instr->getFunction() == Func)