diff --git a/src/coreclr/jit/fgopt.cpp b/src/coreclr/jit/fgopt.cpp index 547688a959c4da..35a4346d540843 100644 --- a/src/coreclr/jit/fgopt.cpp +++ b/src/coreclr/jit/fgopt.cpp @@ -1793,6 +1793,171 @@ bool Compiler::fgOptimizeSwitchBranches(BasicBlock* block) return true; } + else if (block->GetSwitchTargets()->GetSuccCount() == 2 && block->GetSwitchTargets()->HasDefaultCase() && + (block->IsLIR() || fgNodeThreading == NodeThreading::AllTrees)) + { + // If all non-default cases jump to one target and the default jumps to a different target, + // replace the switch with an unsigned comparison against the max case index: + // GT_SWITCH(switchVal) -> GT_JTRUE(GT_LE/GT_GT(switchVal, caseCount - 2)) + // The comparison direction is chosen to favor fall-through to the next block. + // When both targets are simple return blocks, fgFoldCondToReturnBlock can further + // convert this into branchless codegen like "cmp; setbe" instead of a jump table. + + BBswtDesc* switchDesc = block->GetSwitchTargets(); + unsigned caseCount = switchDesc->GetCaseCount(); + + // For small switches with 2 or fewer non-default cases (caseCount <= 3 including the default), + // the backend's switch lowering already generates efficient comparison chains. Converting + // early can enable if-conversion (cmov) that produces worse code for these cases. + if (caseCount <= 3) + { + return modified; + } + + FlowEdge* defaultEdge = switchDesc->GetDefaultCase(); + BasicBlock* defaultDest = defaultEdge->getDestinationBlock(); + + // Check that all non-default cases share the same target, distinct from the default target. 
+ FlowEdge* firstCaseEdge = switchDesc->GetCase(0); + BasicBlock* caseDest = firstCaseEdge->getDestinationBlock(); + + if (caseDest != defaultDest) + { + bool allCasesSameTarget = true; + for (unsigned i = 1; i < caseCount - 1; i++) + { + if (switchDesc->GetCase(i)->getDestinationBlock() != caseDest) + { + allCasesSameTarget = false; + break; + } + } + + if (allCasesSameTarget) + { + GenTree* switchVal = switchTree->AsOp()->gtOp1; + noway_assert(genActualTypeIsIntOrI(switchVal->TypeGet())); + + // If we are in LIR, remove the jump table from the block. + if (block->IsLIR()) + { + GenTree* jumpTable = switchTree->AsOp()->gtOp2; + assert(jumpTable->OperIs(GT_JMPTABLE)); + blockRange->Remove(jumpTable); + } + + // The highest case index is caseCount - 2 (caseCount includes the default case). + // Using an unsigned GT comparison handles negative values correctly because they + // wrap to large unsigned values, making them greater than maxCaseIndex as expected. + const unsigned maxCaseIndex = caseCount - 2; + + JITDUMP("\nConverting a switch (" FMT_BB + ") where all non-default cases target the same block to a " + "conditional branch. Before:\n", + block->bbNum); + DISPNODE(switchTree); + + switchTree->ChangeOper(GT_JTRUE); + GenTree* maxCaseNode = gtNewIconNode(maxCaseIndex, genActualType(switchVal->TypeGet())); + + // Choose the comparison direction so the fall-through (false edge) + // targets the lexically next block when possible. + GenTree* condNode; + FlowEdge* trueEdge; + FlowEdge* falseEdge; + + if (block->NextIs(defaultDest)) + { + // defaultDest is next: use GT_LE so false (out of range) falls through + // to defaultDest + condNode = gtNewOperNode(GT_LE, TYP_INT, switchVal, maxCaseNode); + trueEdge = firstCaseEdge; + falseEdge = + (switchDesc->GetSucc(0) == firstCaseEdge) ? 
switchDesc->GetSucc(1) : switchDesc->GetSucc(0); + } + else + { + // caseDest is next, or neither is next: use GT_GT so false (in range) + // falls through to caseDest, which is typically the hotter path + condNode = gtNewOperNode(GT_GT, TYP_INT, switchVal, maxCaseNode); + trueEdge = + (switchDesc->GetSucc(0) == firstCaseEdge) ? switchDesc->GetSucc(1) : switchDesc->GetSucc(0); + falseEdge = firstCaseEdge; + } + + assert(trueEdge != nullptr); + assert(falseEdge != nullptr); + + condNode->SetUnsigned(); + switchTree->AsOp()->gtOp1 = condNode; + switchTree->AsOp()->gtOp1->gtFlags |= (GTF_RELOP_JMP_USED | GTF_DONT_CSE); + + if (block->IsLIR()) + { + blockRange->InsertAfter(switchVal, maxCaseNode, condNode); + LIR::ReadOnlyRange range(maxCaseNode, switchTree); + m_pLowering->LowerRange(block, range); + } + else if (fgNodeThreading == NodeThreading::AllTrees) + { + gtSetStmtInfo(switchStmt); + fgSetStmtSeq(switchStmt); + } + + // Fix up dup counts: multiple switch cases originally pointed to the same + // successor, but the conditional branch has exactly one edge per target. + const unsigned trueDupCount = trueEdge->getDupCount(); + const unsigned falseDupCount = falseEdge->getDupCount(); + + if (trueDupCount > 1) + { + trueEdge->decrementDupCount(trueDupCount - 1); + trueEdge->getDestinationBlock()->bbRefs -= (trueDupCount - 1); + } + if (falseDupCount > 1) + { + falseEdge->decrementDupCount(falseDupCount - 1); + falseEdge->getDestinationBlock()->bbRefs -= (falseDupCount - 1); + } + + block->SetCond(trueEdge, falseEdge); + + // The switch-to-cond conversion preserved edge likelihoods but + // successor block weights may be stale (they were set during import + // based on the original switch topology). Recompute them so that + // downstream passes like block compaction see correct weights. 
+ if (block->hasProfileWeight()) + { + if (caseDest->hasProfileWeight()) + { + weight_t oldWeight = caseDest->bbWeight; + caseDest->setBBProfileWeight(caseDest->computeIncomingWeight()); + JITDUMP("Updated " FMT_BB " (caseDest) profile weight from " FMT_WT " to " FMT_WT "\n", + caseDest->bbNum, oldWeight, caseDest->bbWeight); + } + if (defaultDest->hasProfileWeight()) + { + weight_t oldWeight = defaultDest->bbWeight; + defaultDest->setBBProfileWeight(defaultDest->computeIncomingWeight()); + JITDUMP("Updated " FMT_BB " (defaultDest) profile weight from " FMT_WT " to " FMT_WT "\n", + defaultDest->bbNum, oldWeight, defaultDest->bbWeight); + } + } + + JITDUMP("After:\n"); + DISPNODE(switchTree); + + if (fgFoldCondToReturnBlock(block)) + { + JITDUMP("Folded conditional return into branchless return. After:\n"); + DISPNODE(switchTree); + } + + return true; + } + } + } + return modified; } diff --git a/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.cs b/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.cs index 1292ad4088e5bb..beeabe279e387a 100644 --- a/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.cs +++ b/src/tests/JIT/opt/OptSwitchRecognition/optSwitchRecognition.cs @@ -139,4 +139,52 @@ private static int RecSwitchSkipBitTest(int arch) [InlineData(6, 4)] [InlineData(10, 1)] public static void TestRecSwitchSkipBitTest(int arg1, int expected) => Assert.Equal(expected, RecSwitchSkipBitTest(arg1)); + + // Test that consecutive equality comparisons (comparison chain) produce the same result + // as pattern matching. The switch recognition should convert the chain to a switch, and + // then fgOptimizeSwitchBranches should simplify it to an unsigned range comparison. 
+ [MethodImpl(MethodImplOptions.NoInlining)] + private static bool IsLetterCategoryCompare(int uc) + { + return uc == 0 + || uc == 1 + || uc == 2 + || uc == 3 + || uc == 4; + } + + [Theory] + [InlineData(-1, false)] + [InlineData(0, true)] + [InlineData(1, true)] + [InlineData(2, true)] + [InlineData(3, true)] + [InlineData(4, true)] + [InlineData(5, false)] + [InlineData(100, false)] + [InlineData(int.MinValue, false)] + [InlineData(int.MaxValue, false)] + public static void TestSwitchToRangeCheck(int arg1, bool expected) => Assert.Equal(expected, IsLetterCategoryCompare(arg1)); + + // Test with non-zero-based consecutive values (10 through 14 inclusive): the same five-way equality chain as IsLetterCategoryCompare, shifted off zero. + [MethodImpl(MethodImplOptions.NoInlining)] + private static bool IsInRange10To14(int val) + { + return val == 10 + || val == 11 + || val == 12 + || val == 13 + || val == 14; + } + + [Theory] + [InlineData(9, false)] + [InlineData(10, true)] + [InlineData(11, true)] + [InlineData(12, true)] + [InlineData(13, true)] + [InlineData(14, true)] + [InlineData(15, false)] + [InlineData(-1, false)] + public static void TestSwitchToRangeCheckNonZeroBased(int arg1, bool expected) => Assert.Equal(expected, IsInRange10To14(arg1)); }