diff --git a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp index b374371667b5e..5eb2f058f329f 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanConstruction.cpp @@ -386,33 +386,54 @@ std::unique_ptr VPlanTransforms::buildPlainCFG( /// Checks if \p HeaderVPB is a loop header block in the plain CFG; that is, it /// has exactly 2 predecessors (preheader and latch), where the block /// dominates the latch and the preheader dominates the block. If it is a -/// header block return true, making sure the preheader appears first and -/// the latch second. Otherwise return false. -static bool canonicalHeader(VPBlockBase *HeaderVPB, - const VPDominatorTree &VPDT) { +/// header block return true and canonicalize the predecessors of the header +/// (making sure the preheader appears first and the latch second) and the +/// successors of the latch (making sure the loop exit comes first). Otherwise +/// return false. +static bool canonicalHeaderAndLatch(VPBlockBase *HeaderVPB, + const VPDominatorTree &VPDT) { ArrayRef Preds = HeaderVPB->getPredecessors(); if (Preds.size() != 2) return false; auto *PreheaderVPBB = Preds[0]; auto *LatchVPBB = Preds[1]; - if (VPDT.dominates(PreheaderVPBB, HeaderVPB) && - VPDT.dominates(HeaderVPB, LatchVPBB)) - return true; + if (!VPDT.dominates(PreheaderVPBB, HeaderVPB) || + !VPDT.dominates(HeaderVPB, LatchVPBB)) { + std::swap(PreheaderVPBB, LatchVPBB); - std::swap(PreheaderVPBB, LatchVPBB); + if (!VPDT.dominates(PreheaderVPBB, HeaderVPB) || + !VPDT.dominates(HeaderVPB, LatchVPBB)) + return false; - if (VPDT.dominates(PreheaderVPBB, HeaderVPB) && - VPDT.dominates(HeaderVPB, LatchVPBB)) { - // Canonicalize predecessors of header so that preheader is first and latch - // second. + // Canonicalize predecessors of header so that preheader is first and + // latch second. HeaderVPB->swapPredecessors(); for (VPRecipeBase &R : cast(HeaderVPB)->phis()) R.swapOperands(); - return true; } - return false; + // The two successors of conditional branch match the condition, with the + // first successor corresponding to true and the second to false. We + // canonicalize the successors of the latch when introducing the region, such + // that the latch exits the region when its condition is true; invert the + // original condition if the original CFG branches to the header on true. + // Note that the exit edge is not yet connected for top-level loops. + if (LatchVPBB->getSingleSuccessor() || + LatchVPBB->getSuccessors()[0] != HeaderVPB) + return true; + + assert(LatchVPBB->getNumSuccessors() == 2 && "Must have 2 successors"); + auto *Term = cast(LatchVPBB)->getTerminator(); + assert(cast(Term)->getOpcode() == + VPInstruction::BranchOnCond && + "terminator must be a BranchOnCond"); + auto *Not = new VPInstruction(VPInstruction::Not, {Term->getOperand(0)}); + Not->insertBefore(Term); + Term->setOperand(0, Not); + LatchVPBB->swapSuccessors(); + + return true; } /// Create a new VPRegionBlock for the loop starting at \p HeaderVPB. @@ -447,7 +468,7 @@ void VPlanTransforms::createLoopRegions(VPlan &Plan, Type *InductionTy, VPDominatorTree VPDT; VPDT.recalculate(Plan); for (VPBlockBase *HeaderVPB : vp_depth_first_shallow(Plan.getEntry())) - if (canonicalHeader(HeaderVPB, VPDT)) + if (canonicalHeaderAndLatch(HeaderVPB, VPDT)) createLoopRegion(Plan, HeaderVPB); VPRegionBlock *TopRegion = Plan.getVectorLoopRegion(); diff --git a/llvm/test/Transforms/LoopVectorize/outer-loop-inner-latch-successors.ll b/llvm/test/Transforms/LoopVectorize/outer-loop-inner-latch-successors.ll index 388da8540646f..afd1308a2d24a 100644 --- a/llvm/test/Transforms/LoopVectorize/outer-loop-inner-latch-successors.ll +++ b/llvm/test/Transforms/LoopVectorize/outer-loop-inner-latch-successors.ll @@ -4,7 +4,6 @@ @A = common global [1024 x i64] zeroinitializer, align 16 @B = common global [1024 x i64] zeroinitializer, align 16 -; FIXME: The exit condition of the inner loop is incorrect when vectorizing. define void @inner_latch_header_first_successor(i64 %N, i32 %c, i64 %M) { ; CHECK-LABEL: define void @inner_latch_header_first_successor( ; CHECK-SAME: i64 [[N:%.*]], i32 [[C:%.*]], i64 [[M:%.*]]) { @@ -35,8 +34,9 @@ define void @inner_latch_header_first_successor(i64 %N, i32 %c, i64 %M) { ; CHECK-NEXT: [[TMP3]] = add nsw <4 x i64> [[TMP2]], [[VEC_PHI4]] ; CHECK-NEXT: [[TMP4]] = add nuw nsw <4 x i64> [[VEC_PHI]], splat (i64 1) ; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <4 x i64> [[TMP4]], [[BROADCAST_SPLAT2]] -; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i1> [[TMP5]], i32 0 -; CHECK-NEXT: br i1 [[TMP6]], label %[[VECTOR_LATCH]], label %[[INNER3]] +; CHECK-NEXT: [[TMP6:%.*]] = xor <4 x i1> [[TMP5]], splat (i1 true) +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <4 x i1> [[TMP6]], i32 0 +; CHECK-NEXT: br i1 [[TMP9]], label %[[VECTOR_LATCH]], label %[[INNER3]] ; CHECK: [[VECTOR_LATCH]]: ; CHECK-NEXT: [[VEC_PHI6:%.*]] = phi <4 x i64> [ [[TMP3]], %[[INNER3]] ] ; CHECK-NEXT: call void @llvm.masked.scatter.v4i64.v4p0(<4 x i64> [[VEC_PHI6]], <4 x ptr> [[TMP0]], i32 4, <4 x i1> splat (i1 true))