diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 5828cc156cc78..0399499a6224e 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -460,6 +460,9 @@ class ScalarEvolution { LoopComputable ///< The SCEV varies predictably with the loop. }; + bool AssumeLoopFinite = false; + void setAssumeLoopExits(); + /// An enum describing the relationship between a SCEV and a basic block. enum BlockDisposition { DoesNotDominateBlock, ///< The SCEV does not dominate the block. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 93f885c5d5ad8..2f8bdb3ee366d 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -509,6 +509,8 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) { return S; } +void ScalarEvolution::setAssumeLoopExits() { this->AssumeLoopFinite = true; } + const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC) { const SCEV *Res = getConstant(Ty, EC.getKnownMinValue()); if (EC.isScalable()) @@ -7422,7 +7424,8 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { // A mustprogress loop without side effects must be finite. // TODO: The check used here is very conservative. It's only *specific* // side effects which are well defined in infinite loops. - return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); + return AssumeLoopFinite || isFinite(L) || + (isMustProgress(L) && loopHasNoSideEffects(L)); } const SCEV *ScalarEvolution::createSCEVIter(Value *V) { diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index a7b3c5c404ab7..f74584636bd15 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -1085,6 +1085,35 @@ TEST_F(ScalarEvolutionsTest, SCEVComputeExpressionSize) { EXPECT_EQ(S2S->getExpressionSize(), 5u); } +TEST_F(ScalarEvolutionsTest, AssumeLoopExists) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = parseAssemblyString( + "define void @foo(i32 %N) { " + "entry: " + " %cmp3 = icmp sgt i32 %N, 0 " + " br i1 %cmp3, label %for.body, label %for.cond.cleanup " + "for.cond.cleanup: " + " ret void " + "for.body: " + " br label %for.body " + "} " + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + BasicBlock *L = F.begin()->getNextNode()->getNextNode(); + auto *Loop = LI.getLoopFor(L); + bool IsFinite = SE.loopIsFiniteByAssumption(Loop); + EXPECT_FALSE(IsFinite); + SE.setAssumeLoopExits(); + IsFinite = SE.loopIsFiniteByAssumption(Loop); + EXPECT_TRUE(IsFinite); + }); +} + TEST_F(ScalarEvolutionsTest, SCEVLoopDecIntrinsic) { LLVMContext C; SMDiagnostic Err;