diff --git a/src/coreclr/jit/CMakeLists.txt b/src/coreclr/jit/CMakeLists.txt
index 7c867635b9c71..30e9ad972dab6 100644
--- a/src/coreclr/jit/CMakeLists.txt
+++ b/src/coreclr/jit/CMakeLists.txt
@@ -151,6 +151,7 @@ set( JIT_SOURCES
   objectalloc.cpp
   optcse.cpp
   optimizebools.cpp
+  switchrecognition.cpp
   optimizer.cpp
   patchpoint.cpp
   phase.cpp
diff --git a/src/coreclr/jit/block.cpp b/src/coreclr/jit/block.cpp
index d7eabe832ac98..8b5cef28a71a8 100644
--- a/src/coreclr/jit/block.cpp
+++ b/src/coreclr/jit/block.cpp
@@ -792,9 +792,6 @@ bool BasicBlock::IsLIR() const
 //------------------------------------------------------------------------
 // firstStmt: Returns the first statement in the block
 //
-// Arguments:
-//    None.
-//
 // Return Value:
 //    The first statement in the block's bbStmtList.
 //
@@ -804,10 +801,19 @@ Statement* BasicBlock::firstStmt() const
 }
 
 //------------------------------------------------------------------------
-// lastStmt: Returns the last statement in the block
+// hasSingleStmt: Checks whether the block holds exactly one statement
 //
-// Arguments:
-//    None.
+// Return Value:
+//    true when the block is non-empty and its first and last statements coincide, false otherwise
+//
+bool BasicBlock::hasSingleStmt() const
+{
+    Statement* const first = firstStmt();
+    return (first != nullptr) && (first == lastStmt());
+}
+
+//------------------------------------------------------------------------
+// lastStmt: Returns the last statement in the block
 //
 // Return Value:
 //    The last statement in the block's bbStmtList.
diff --git a/src/coreclr/jit/block.h b/src/coreclr/jit/block.h
index 676efcfc9485c..9c7953a12b9e5 100644
--- a/src/coreclr/jit/block.h
+++ b/src/coreclr/jit/block.h
@@ -1127,6 +1127,7 @@ struct BasicBlock : private LIR::Range
 
     Statement* firstStmt() const;
     Statement* lastStmt() const;
+    bool       hasSingleStmt() const;
 
     // Statements: convenience method for enabling range-based `for` iteration over the statement list, e.g.:
     //    for (Statement* const stmt : block->Statements())
diff --git a/src/coreclr/jit/compiler.cpp b/src/coreclr/jit/compiler.cpp
index 53b3352a06587..3ab6b41839e2e 100644
--- a/src/coreclr/jit/compiler.cpp
+++ b/src/coreclr/jit/compiler.cpp
@@ -5053,6 +5053,10 @@ void Compiler::compCompile(void** methodCodePtr, uint32_t* methodCodeSize, JitFl
         // Optimize block order
         //
         DoPhase(this, PHASE_OPTIMIZE_LAYOUT, &Compiler::optOptimizeLayout);
+
+        // Recognize chains of constant equality tests and convert them to switch blocks
+        //
+        DoPhase(this, PHASE_SWITCH_RECOGNITION, &Compiler::optSwitchRecognition);
     }
 
     // Determine start of cold region if we are hot/cold splitting
diff --git a/src/coreclr/jit/compiler.h b/src/coreclr/jit/compiler.h
index 1e845dd53b758..3c4badb6088b4 100644
--- a/src/coreclr/jit/compiler.h
+++ b/src/coreclr/jit/compiler.h
@@ -6325,8 +6325,10 @@ class Compiler
 
 public:
     PhaseStatus optOptimizeBools();
+    PhaseStatus optSwitchRecognition();
+    bool optSwitchConvert(BasicBlock* firstBlock, int testsCount, ssize_t* testValues, GenTree* nodeToTest);
+    bool optSwitchDetectAndConvert(BasicBlock* firstBlock);
 
-public:
     PhaseStatus optInvertLoops();    // Invert loops so they're entered at top and tested at bottom.
     PhaseStatus optOptimizeFlow();   // Simplify flow graph and do tail duplication
     PhaseStatus optOptimizeLayout(); // Optimize the BasicBlock layout of the method
diff --git a/src/coreclr/jit/compphases.h b/src/coreclr/jit/compphases.h
index 9ee73c7d69ad6..36eb647f04eba 100644
--- a/src/coreclr/jit/compphases.h
+++ b/src/coreclr/jit/compphases.h
@@ -72,6 +72,7 @@ CompPhaseNameMacro(PHASE_MORPH_MDARR, "Morph array ops",
 CompPhaseNameMacro(PHASE_HOIST_LOOP_CODE,            "Hoist loop code",                false, -1, false)
 CompPhaseNameMacro(PHASE_MARK_LOCAL_VARS,            "Mark local vars",                false, -1, false)
 CompPhaseNameMacro(PHASE_OPTIMIZE_BOOLS,             "Optimize bools",                 false, -1, false)
+CompPhaseNameMacro(PHASE_SWITCH_RECOGNITION,         "Recognize Switch",               false, -1, false)
 CompPhaseNameMacro(PHASE_FIND_OPER_ORDER,            "Find oper order",                false, -1, false)
 CompPhaseNameMacro(PHASE_SET_BLOCK_ORDER,            "Set block order",                false, -1, true)
 CompPhaseNameMacro(PHASE_BUILD_SSA,                  "Build SSA representation",       true,  -1, false)
diff --git a/src/coreclr/jit/switchrecognition.cpp b/src/coreclr/jit/switchrecognition.cpp
new file mode 100644
index 0000000000000..a31de6a97bce7
--- /dev/null
+++ b/src/coreclr/jit/switchrecognition.cpp
@@ -0,0 +1,394 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#include "jitpch.h"
+#ifdef _MSC_VER
+#pragma hdrstop
+#endif
+
+// We mainly rely on TryLowerSwitchToBitTest in these heuristics, but jump tables can be useful
+// even without conversion to a bitmap test.
+#define SWITCH_MAX_DISTANCE ((TARGET_POINTER_SIZE * BITS_IN_BYTE) - 1)
+#define SWITCH_MIN_TESTS 3
+
+//-----------------------------------------------------------------------------
+// optSwitchRecognition: Optimize range check for `x == cns1 || x == cns2 || x == cns3 ...`
+//      pattern and convert it to Switch block (jump table) which then *might* be converted
+//      to a bitmap test via TryLowerSwitchToBitTest.
+//      TODO: recognize general jump table patterns.
+//
+// Return Value:
+//    MODIFIED_EVERYTHING if the optimization was applied.
+//
+PhaseStatus Compiler::optSwitchRecognition()
+{
+// Limit to XARCH, ARM is already doing a great job with such comparisons using
+// a series of ccmp instruction (see ifConvert phase).
+#ifdef TARGET_XARCH
+    bool modified = false;
+    for (BasicBlock* block = fgFirstBB; block != nullptr; block = block->bbNext)
+    {
+        // block->KindIs(BBJ_COND) check is for better throughput.
+        if (block->KindIs(BBJ_COND) && !block->isRunRarely() && optSwitchDetectAndConvert(block))
+        {
+            JITDUMP("Converted block " FMT_BB " to switch\n", block->bbNum)
+            modified = true;
+        }
+    }
+
+    if (modified)
+    {
+        fgUpdateChangedFlowGraph(FlowGraphUpdates::COMPUTE_BASICS);
+        return PhaseStatus::MODIFIED_EVERYTHING;
+    }
+#endif
+    return PhaseStatus::MODIFIED_NOTHING;
+}
+
+//------------------------------------------------------------------------------
+// IsConstantTestCondBlock : Does the given block represent a simple BBJ_COND
+//    constant test? e.g. JTRUE(EQ/NE(X, CNS)).
+//
+// Arguments:
+//    block        - The block to check
+//    blockIfTrue  - [out] The block that will be jumped to if X == CNS
+//    blockIfFalse - [out] The block that will be jumped to if X != CNS
+//    isReversed   - [out] True if the condition is reversed (GT_NE)
+//    variableNode - [out] The variable node (X in the example above)
+//    cns          - [out] The constant value (CNS in the example above)
+//
+// Return Value:
+//    True if the block represents a constant test, false otherwise
+//
+bool IsConstantTestCondBlock(const BasicBlock* block,
+                             BasicBlock**      blockIfTrue,
+                             BasicBlock**      blockIfFalse,
+                             bool*             isReversed,
+                             GenTree**         variableNode = nullptr,
+                             ssize_t*          cns          = nullptr)
+{
+    // NOTE: caller is expected to check that a block has multiple statements or not
+    if (block->KindIs(BBJ_COND) && (block->lastStmt() != nullptr) && ((block->bbFlags & BBF_DONT_REMOVE) == 0))
+    {
+        const GenTree* rootNode = block->lastStmt()->GetRootNode();
+        assert(rootNode->OperIs(GT_JTRUE));
+
+        // It has to be JTRUE(GT_EQ or GT_NE)
+        if (rootNode->gtGetOp1()->OperIs(GT_EQ, GT_NE))
+        {
+            GenTree* op1 = rootNode->gtGetOp1()->gtGetOp1();
+            GenTree* op2 = rootNode->gtGetOp1()->gtGetOp2();
+
+            if (!varTypeIsIntegral(op1) || !varTypeIsIntegral(op2))
+            {
+                // Only integral types are supported
+                return false;
+            }
+
+            // We're looking for "X EQ/NE CNS" or "CNS EQ/NE X" pattern
+            if (op1->IsCnsIntOrI() ^ op2->IsCnsIntOrI())
+            {
+                // TODO: relax this to support any side-effect free expression
+                if (!op1->OperIs(GT_LCL_VAR) && !op2->OperIs(GT_LCL_VAR))
+                {
+                    return false;
+                }
+
+                *isReversed   = rootNode->gtGetOp1()->OperIs(GT_NE);
+                *blockIfTrue  = *isReversed ? block->bbNext : block->bbJumpDest;
+                *blockIfFalse = *isReversed ? block->bbJumpDest : block->bbNext;
+
+                if ((block->bbNext == block->bbJumpDest) || (block->bbJumpDest == block))
+                {
+                    // Ignoring weird cases like a condition jumping to itself
+                    return false;
+                }
+
+                if ((variableNode != nullptr) && (cns != nullptr))
+                {
+                    if (op1->IsCnsIntOrI())
+                    {
+                        *cns          = op1->AsIntCon()->IconValue();
+                        *variableNode = op2;
+                    }
+                    else
+                    {
+                        *cns          = op2->AsIntCon()->IconValue();
+                        *variableNode = op1;
+                    }
+                }
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+//------------------------------------------------------------------------------
+// optSwitchDetectAndConvert : Try to detect a series of conditional blocks which
+//    can be converted into a switch (jump-table) construct. See optSwitchConvert
+//    for more details.
+//
+// Arguments:
+//    firstBlock - A block to start the search from
+//
+// Return Value:
+//    True if the conversion was successful, false otherwise
+//
+bool Compiler::optSwitchDetectAndConvert(BasicBlock* firstBlock)
+{
+    assert(firstBlock->KindIs(BBJ_COND));
+
+    GenTree*    variableNode = nullptr;
+    ssize_t     cns          = 0;
+    BasicBlock* blockIfTrue  = nullptr;
+    BasicBlock* blockIfFalse = nullptr;
+
+    // The algorithm is simple - we check that the given block is a constant test block
+    // and then try to accumulate as many constant test blocks as possible. Once we hit
+    // a block that doesn't match the pattern, we start processing the accumulated blocks.
+    bool isReversed = false;
+    if (IsConstantTestCondBlock(firstBlock, &blockIfTrue, &blockIfFalse, &isReversed, &variableNode, &cns))
+    {
+        if (isReversed)
+        {
+            // First block uses NE - we don't support this yet. We currently expect all blocks to use EQ
+            // and allow NE for the last one (because it's what Roslyn usually emits).
+            // TODO: make it more flexible and support cases like "x != cns1 && x != cns2 && ..."
+            return false;
+        }
+
+        // No more than SWITCH_MAX_DISTANCE test values are collected (arbitrary limit in this context)
+        int     testValueIndex                  = 0;
+        ssize_t testValues[SWITCH_MAX_DISTANCE] = {};
+        testValues[testValueIndex++]            = cns;
+
+        const BasicBlock* prevBlock = firstBlock;
+
+        // Now walk the next blocks and see if they are basically the same type of test
+        for (const BasicBlock* currBb = firstBlock->bbNext; currBb != nullptr; currBb = currBb->bbNext)
+        {
+            GenTree*    currVariableNode = nullptr;
+            ssize_t     currCns          = 0;
+            BasicBlock* currBlockIfTrue  = nullptr;
+            BasicBlock* currBlockIfFalse = nullptr;
+
+            if (!currBb->hasSingleStmt())
+            {
+                // Only the first conditional block can have multiple statements.
+                // Stop searching and process what we already have.
+                return optSwitchConvert(firstBlock, testValueIndex, testValues, variableNode);
+            }
+
+            // Inspect secondary blocks
+            if (IsConstantTestCondBlock(currBb, &currBlockIfTrue, &currBlockIfFalse, &isReversed, &currVariableNode,
+                                        &currCns))
+            {
+                if (currBlockIfTrue != blockIfTrue)
+                {
+                    // This block jumps to a different target, stop searching and process what we already have.
+                    return optSwitchConvert(firstBlock, testValueIndex, testValues, variableNode);
+                }
+
+                if (!GenTree::Compare(currVariableNode, variableNode))
+                {
+                    // A different variable node is used, stop searching and process what we already have.
+                    return optSwitchConvert(firstBlock, testValueIndex, testValues, variableNode);
+                }
+
+                if (currBb->GetUniquePred(this) != prevBlock)
+                {
+                    // Multiple preds in a secondary block, stop searching and process what we already have.
+                    return optSwitchConvert(firstBlock, testValueIndex, testValues, variableNode);
+                }
+
+                if (!BasicBlock::sameEHRegion(prevBlock, currBb))
+                {
+                    // Current block is in a different EH region, stop searching and process what we already have.
+                    return optSwitchConvert(firstBlock, testValueIndex, testValues, variableNode);
+                }
+
+                // Ok we can work with that, add the test value to the list
+                testValues[testValueIndex++] = currCns;
+                if (testValueIndex == SWITCH_MAX_DISTANCE)
+                {
+                    // Too many suitable tests found - stop and process what we already have.
+                    return optSwitchConvert(firstBlock, testValueIndex, testValues, variableNode);
+                }
+
+                if (isReversed)
+                {
+                    // We only support reversed test (GT_NE) for the last block.
+                    return optSwitchConvert(firstBlock, testValueIndex, testValues, variableNode);
+                }
+
+                prevBlock = currBb;
+            }
+            else
+            {
+                // Current block is not a suitable test, stop searching and process what we already have.
+                return optSwitchConvert(firstBlock, testValueIndex, testValues, variableNode);
+            }
+        }
+    }
+
+    return false;
+}
+
+//------------------------------------------------------------------------------
+// optSwitchConvert : Convert a series of conditional blocks into a switch block
+//    conditional blocks are blocks that have a single statement that is a GT_EQ
+//    or GT_NE node. The blocks are expected to jump to the same target and test
+//    the same variable against different constants
+//
+// Arguments:
+//    firstBlock - First conditional block in the chain
+//    testsCount - Number of conditional blocks in the chain
+//    testValues - Array of constants that are tested against the variable
+//    nodeToTest - Variable node that is tested against the constants
+//
+// Return Value:
+//    True if the conversion was successful, false otherwise
+//
+bool Compiler::optSwitchConvert(BasicBlock* firstBlock, int testsCount, ssize_t* testValues, GenTree* nodeToTest)
+{
+    assert(firstBlock->KindIs(BBJ_COND));
+
+    if (testsCount < SWITCH_MIN_TESTS)
+    {
+        // Early out - short chains.
+        return false;
+    }
+
+    static_assert_no_msg(SWITCH_MIN_TESTS > 0);
+
+    // Find max and min values in the testValues array
+    // At this point we have at least SWITCH_MIN_TESTS values in the array
+    ssize_t minValue = testValues[0];
+    ssize_t maxValue = testValues[0];
+
+    int testIdx = 0;
+    for (; testIdx < testsCount; testIdx++)
+    {
+        ssize_t testValue = testValues[testIdx];
+        if (testValue < 0)
+        {
+            // We don't support negative values
+            break;
+        }
+
+        const ssize_t newMinValue = min(minValue, testValue);
+        const ssize_t newMaxValue = max(maxValue, testValue);
+        assert(newMaxValue >= newMinValue);
+        if ((newMaxValue - newMinValue) > SWITCH_MAX_DISTANCE)
+        {
+            // Stop here, the distance between min and max is too big
+            break;
+        }
+        minValue = newMinValue;
+        maxValue = newMaxValue;
+    }
+
+    assert(testIdx <= testsCount);
+    if (testIdx < SWITCH_MIN_TESTS)
+    {
+        // Make sure we still have at least SWITCH_MIN_TESTS values after we filtered out some of them
+        return false;
+    }
+
+    // Only the first testIdx values passed the filter above. Narrow testsCount so that the
+    // rejected tail (a negative or too distant value and everything after it, whose blocks are
+    // not removed below) can never be folded into the jump-table bit vector.
+    testsCount = testIdx;
+
+    // if MaxValue is less than SWITCH_MAX_DISTANCE then don't bother with SUB(val, minValue)
+    if (maxValue <= SWITCH_MAX_DISTANCE)
+    {
+        minValue = 0;
+    }
+
+    // Find the last block in the chain
+    const BasicBlock* lastBlock = firstBlock;
+    for (int i = 0; i < testIdx - 1; i++)
+    {
+        lastBlock = lastBlock->bbNext;
+    }
+
+    BasicBlock* blockIfTrue  = nullptr;
+    BasicBlock* blockIfFalse = nullptr;
+    bool        isReversed   = false;
+    const bool  isTest       = IsConstantTestCondBlock(lastBlock, &blockIfTrue, &blockIfFalse, &isReversed);
+    assert(isTest);
+
+    // Convert firstBlock to a switch block
+    firstBlock->bbJumpKind    = BBJ_SWITCH;
+    firstBlock->bbJumpDest    = nullptr;
+    firstBlock->bbCodeOffsEnd = lastBlock->bbCodeOffsEnd;
+    firstBlock->lastStmt()->GetRootNode()->ChangeOper(GT_SWITCH);
+
+    // The root node is now SUB(nodeToTest, minValue) if minValue != 0
+    GenTree* switchValue = gtCloneExpr(nodeToTest);
+    if (minValue != 0)
+    {
+        switchValue =
+            gtNewOperNode(GT_SUB, nodeToTest->TypeGet(), switchValue, gtNewIconNode(minValue, nodeToTest->TypeGet()));
+    }
+
+    firstBlock->lastStmt()->GetRootNode()->AsOp()->gtOp1 = switchValue;
+    gtSetStmtInfo(firstBlock->lastStmt());
+    fgSetStmtSeq(firstBlock->lastStmt());
+    gtUpdateStmtSideEffects(firstBlock->lastStmt());
+
+    // Unlink and remove the whole chain of conditional blocks
+    BasicBlock* blockToRemove = firstBlock->bbNext;
+    fgRemoveRefPred(blockToRemove, firstBlock);
+    while (blockToRemove != lastBlock->bbNext)
+    {
+        BasicBlock* nextBlock = blockToRemove->bbNext;
+        fgRemoveBlock(blockToRemove, true);
+        blockToRemove = nextBlock;
+    }
+
+    const auto jumpCount = static_cast<unsigned>(maxValue - minValue + 1);
+    assert((jumpCount > 0) && (jumpCount <= SWITCH_MAX_DISTANCE + 1));
+    const auto jmpTab = new (this, CMK_BasicBlock) BasicBlock*[jumpCount + 1 /*default case*/];
+
+    firstBlock->bbJumpSwt                = new (this, CMK_BasicBlock) BBswtDesc;
+    firstBlock->bbJumpSwt->bbsCount      = jumpCount + 1;
+    firstBlock->bbJumpSwt->bbsHasDefault = true;
+    firstBlock->bbJumpSwt->bbsDstTab     = jmpTab;
+    firstBlock->bbNext                   = isReversed ? blockIfTrue : blockIfFalse;
+    fgHasSwitch                          = true;
+
+    // Splitting doesn't work well with jump-tables currently
+    opts.compProcedureSplitting = false;
+
+    // Compose a bit vector of all the accepted values from the testValues array
+    // to quickly check if a value is in the array
+    ssize_t bitVector = 0;
+    for (testIdx = 0; testIdx < testsCount; testIdx++)
+    {
+        assert(testIdx <= (int)((sizeof(ssize_t) * BITS_PER_BYTE) - 1));
+        const ssize_t bitIndex = testValues[testIdx] - minValue;
+        assert((bitIndex >= 0) && (bitIndex <= SWITCH_MAX_DISTANCE));
+        bitVector |= (ssize_t)(1ULL << static_cast<unsigned>(bitIndex));
+    }
+
+    // Unlink blockIfTrue from firstBlock, we're going to link it again in the loop below.
+    fgRemoveRefPred(blockIfTrue, firstBlock);
+
+    for (unsigned i = 0; i < jumpCount; i++)
+    {
+        // value exists in the testValues array (via bitVector) - 'true' case.
+        const bool isTrue = (bitVector & static_cast<ssize_t>(1ULL << i)) != 0;
+        jmpTab[i]         = isTrue ? blockIfTrue : blockIfFalse;
+
+        fgAddRefPred(jmpTab[i], firstBlock);
+    }
+
+    // Link the 'default' case
+    jmpTab[jumpCount] = blockIfFalse;
+    fgAddRefPred(blockIfFalse, firstBlock);
+
+    return true;
+}
diff --git a/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.cs b/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.cs
new file mode 100644
index 0000000000000..1292ad4088e5b
--- /dev/null
+++ b/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.cs
@@ -0,0 +1,161 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+// unit test for Switch recognition optimization
+
+using System.Runtime.CompilerServices;
+using Xunit;
+
+namespace optSwitchRecognition;
+
+public class CSwitchRecognitionTest
+{
+    // Test sorted char cases
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static bool RecSwitchSortedChar(char c)
+    {
+        return (c == 'a' || c == 'b' || c == 'd' || c == 'f') ? true : false;
+    }
+
+    [Theory]
+    [InlineData('a', true)]
+    [InlineData('b', true)]
+    [InlineData('c', false)]
+    [InlineData('d', true)]
+    [InlineData('e', false)]
+    [InlineData('f', true)]
+    [InlineData('z', false)]
+    [InlineData('A', false)]
+    [InlineData('Z', false)]
+    [InlineData('?', false)]
+    public static void TestRecSwitchSortedChar(char arg1, bool expected) => Assert.Equal(expected, RecSwitchSortedChar(arg1));
+
+    // Test unsorted char cases
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static bool RecSwitchUnsortedChar(char c)
+    {
+        return (c == 'd' || c == 'f' || c == 'a' || c == 'b') ? true : false;
+    }
+
+    [Theory]
+    [InlineData('a', true)]
+    [InlineData('b', true)]
+    [InlineData('c', false)]
+    [InlineData('d', true)]
+    [InlineData('e', false)]
+    [InlineData('f', true)]
+    [InlineData('z', false)]
+    [InlineData('A', false)]
+    [InlineData('Z', false)]
+    [InlineData('?', false)]
+    public static void TestRecSwitchUnsortedChar(char arg1, bool expected) => Assert.Equal(expected, RecSwitchUnsortedChar(arg1));
+
+    // Test sorted int cases
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static bool RecSwitchSortedInt(int i)
+    {
+        return (i == -10 || i == -20 || i == 30 || i == 40) ? true : false;
+    }
+
+    [Theory]
+    [InlineData(-100, false)]
+    [InlineData(-10, true)]
+    [InlineData(-20, true)]
+    [InlineData(0, false)]
+    [InlineData(30, true)]
+    [InlineData(35, false)]
+    [InlineData(40, true)]
+    [InlineData(70, false)]
+    [InlineData(100, false)]
+    public static void TestRecSwitchSortedInt(int arg1, bool expected) => Assert.Equal(expected, RecSwitchSortedInt(arg1));
+
+    // Test unsorted int cases
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static bool RecSwitchUnsortedInt(int i)
+    {
+        return (i == 30 || i == 40 || i == -10 || i == -20) ? true : false;
+    }
+
+    [Theory]
+    [InlineData(-100, false)]
+    [InlineData(-10, true)]
+    [InlineData(-20, true)]
+    [InlineData(0, false)]
+    [InlineData(30, true)]
+    [InlineData(35, false)]
+    [InlineData(40, true)]
+    [InlineData(70, false)]
+    [InlineData(100, false)]
+    public static void TestRecSwitchUnsortedInt(int arg1, bool expected) => Assert.Equal(expected, RecSwitchUnsortedInt(arg1));
+
+    // Test <= 64 switch cases
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static bool RecSwitch64JumpTables(int i)
+    {
+        return (i == 0 || i == 4 || i == 6 || i == 63) ? true : false;
+    }
+
+    [Theory]
+    [InlineData(-63, false)]
+    [InlineData(60, false)]
+    [InlineData(63, true)]
+    [InlineData(64, false)]
+    public static void TestRecSwitch64JumpTables(int arg1, bool expected) => Assert.Equal(expected, RecSwitch64JumpTables(arg1));
+
+    //
+    // Skip optimization
+    //
+
+    // Test > 64 Switch cases (should skip Switch Recognition optimization)
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static bool RecSwitch128JumpTables(int i)
+    {
+        return (i == 0 || i == 4 || i == 6 || i == 127);
+    }
+
+    [Theory]
+    [InlineData(-127, false)]
+    [InlineData(6, true)]
+    [InlineData(127, true)]
+    [InlineData(128, false)]
+    public static void TestRecSwitch128JumpTables(int arg1, bool expected) => Assert.Equal(expected, RecSwitch128JumpTables(arg1));
+
+    // Skips `bit test` conversion because Switch jump targets are > 2 (should skip Switch Recognition optimization)
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static int RecSwitchSkipBitTest(int arch)
+    {
+        if (arch == 1)
+            return 2;
+        else if (arch == 2 || arch == 6)
+            return 4;
+        else
+            return 1;
+    }
+
+    [Theory]
+    [InlineData(0, 1)]
+    [InlineData(1, 2)]
+    [InlineData(2, 4)]
+    [InlineData(6, 4)]
+    [InlineData(10, 1)]
+    public static void TestRecSwitchSkipBitTest(int arg1, int expected) => Assert.Equal(expected, RecSwitchSkipBitTest(arg1));
+
+    // Test that a value which is too far from the others to share their jump table (67 vs 0..10)
+    // does not alias one of the table slots of the remaining values (67 % 64 == 3)
+    [MethodImpl(MethodImplOptions.NoInlining)]
+    private static bool RecSwitchFilteredTail(int i)
+    {
+        if (i == 0 || i == 5 || i == 10 || i == 67)
+            return true;
+        return false;
+    }
+
+    [Theory]
+    [InlineData(0, true)]
+    [InlineData(3, false)]
+    [InlineData(5, true)]
+    [InlineData(10, true)]
+    [InlineData(67, true)]
+    [InlineData(68, false)]
+    public static void TestRecSwitchFilteredTail(int arg1, bool expected) => Assert.Equal(expected, RecSwitchFilteredTail(arg1));
+}
diff --git a/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.csproj b/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.csproj
new file mode 100644
index 0000000000000..de6d5e08882e8
--- /dev/null
+++ b/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.csproj
@@ -0,0 +1,8 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <Optimize>True</Optimize>
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="$(MSBuildProjectName).cs" />
+  </ItemGroup>
+</Project>