@@ -1371,6 +1371,18 @@ class BoUpSLP {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
+  /// Returns the reduction bitwidth and signedness, if it does not match the
+  /// originally requested size.
+  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+    if (ReductionBitWidth == 0 ||
+        ReductionBitWidth ==
+            DL->getTypeSizeInBits(
+                VectorizableTree.front()->Scalars.front()->getType()))
+      return std::nullopt;
+    return std::make_pair(ReductionBitWidth,
+                          MinBWs.at(VectorizableTree.front().get()).second);
+  }
+
   /// Builds external uses of the vectorized scalars, i.e. the list of
   /// vectorized scalars to be extracted, their lanes and their scalar users. \p
   /// ExternallyUsedValues contains additional list of external uses to handle
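The accessor above is consumed by HorizontalReduction::getReductionCost later in this patch. A minimal caller-side sketch of how the optional is meant to be unpacked (Slp is a hypothetical BoUpSLP instance; the real call site is in the hunk at line 19785 below):

    // std::nullopt means the reduction already runs at the originally
    // requested bit width, so no narrowing needs to be modeled.
    std::optional<std::pair<unsigned, bool>> BWSign =
        Slp.getReductionBitWidthAndSign();
    auto [BitWidth, IsSigned] = BWSign.value_or(std::make_pair(0u, false));
    if (BitWidth == 1) {
      // The vector_reduce_add(ZExt(<n x i1>)) case that the rest of this
      // patch costs and emits as bitcast + ctpop.
    }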
@@ -17887,24 +17899,37 @@ void BoUpSLP::computeMinimumValueSizes() {
   // Add reduction ops sizes, if any.
   if (UserIgnoreList &&
       isa<IntegerType>(VectorizableTree.front()->Scalars.front()->getType())) {
-    for (Value *V : *UserIgnoreList) {
-      auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
-      auto NumTypeBits = DL->getTypeSizeInBits(V->getType());
-      unsigned BitWidth1 = NumTypeBits - NumSignBits;
-      if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
-        ++BitWidth1;
-      unsigned BitWidth2 = BitWidth1;
-      if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
-        auto Mask = DB->getDemandedBits(cast<Instruction>(V));
-        BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+    // Convert vector_reduce_add(ZExt(<n x i1>)) to
+    // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+    if (all_of(*UserIgnoreList,
+               [](Value *V) {
+                 return cast<Instruction>(V)->getOpcode() == Instruction::Add;
+               }) &&
+        VectorizableTree.front()->State == TreeEntry::Vectorize &&
+        VectorizableTree.front()->getOpcode() == Instruction::ZExt &&
+        cast<CastInst>(VectorizableTree.front()->getMainOp())->getSrcTy() ==
+            Builder.getInt1Ty()) {
+      ReductionBitWidth = 1;
+    } else {
+      for (Value *V : *UserIgnoreList) {
+        unsigned NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
+        TypeSize NumTypeBits = DL->getTypeSizeInBits(V->getType());
+        unsigned BitWidth1 = NumTypeBits - NumSignBits;
+        if (!isKnownNonNegative(V, SimplifyQuery(*DL)))
+          ++BitWidth1;
+        unsigned BitWidth2 = BitWidth1;
+        if (!RecurrenceDescriptor::isIntMinMaxRecurrenceKind(::getRdxKind(V))) {
+          APInt Mask = DB->getDemandedBits(cast<Instruction>(V));
+          BitWidth2 = Mask.getBitWidth() - Mask.countl_zero();
+        }
+        ReductionBitWidth =
+            std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
       }
-      ReductionBitWidth =
-          std::max(std::min(BitWidth1, BitWidth2), ReductionBitWidth);
-    }
-    if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
-      ReductionBitWidth = 8;
+      if (ReductionBitWidth < 8 && ReductionBitWidth > 1)
+        ReductionBitWidth = 8;
 
-    ReductionBitWidth = bit_ceil(ReductionBitWidth);
+      ReductionBitWidth = bit_ceil(ReductionBitWidth);
+    }
   }
   bool IsTopRoot = NodeIdx == 0;
   while (NodeIdx < VectorizableTree.size() &&
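The new early branch above records ReductionBitWidth = 1 whenever every reduced value is an integer add and the single vectorized root is a ZExt from i1, which is what later enables the ctpop lowering. A self-contained C++ model (illustrative, not LLVM code) of why a population count of the bitcast mask equals the add-reduction of the zext'ed lanes:

    #include <bitset>
    #include <cassert>
    #include <cstdint>

    // Scalar model of vector_reduce_add(zext(<8 x i1> %mask to <8 x i32>)).
    static unsigned reduceAddOfZExt(const bool (&Mask)[8]) {
      unsigned Sum = 0;
      for (bool B : Mask)
        Sum += B ? 1u : 0u; // each zext'ed lane contributes 0 or 1
      return Sum;
    }

    int main() {
      bool Mask[8] = {true, false, true, true, false, false, true, false};
      // Model of "bitcast <8 x i1> %mask to i8": pack the lanes into bits.
      uint8_t Bits = 0;
      for (unsigned I = 0; I != 8; ++I)
        Bits |= uint8_t(Mask[I]) << I;
      // ctpop(i8) counts exactly the lanes that contributed 1 to the sum.
      assert(reduceAddOfZExt(Mask) == std::bitset<8>(Bits).count());
    }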
@@ -19760,8 +19785,8 @@ class HorizontalReduction {
 
       // Estimate cost.
       InstructionCost TreeCost = V.getTreeCost(VL);
-      InstructionCost ReductionCost =
-          getReductionCost(TTI, VL, IsCmpSelMinMax, ReduxWidth, RdxFMF);
+      InstructionCost ReductionCost = getReductionCost(
+          TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
       InstructionCost Cost = TreeCost + ReductionCost;
       LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                         << " for reduction\n");
@@ -19866,10 +19891,12 @@ class HorizontalReduction {
               createStrideMask(I, ScalarTyNumElements, VL.size());
           Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
           ReducedSubTree = Builder.CreateInsertElement(
-              ReducedSubTree, emitReduction(Lane, Builder, TTI), I);
+              ReducedSubTree,
+              emitReduction(Lane, Builder, TTI, RdxRootInst->getType()), I);
         }
       } else {
-        ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI);
+        ReducedSubTree = emitReduction(VectorizedRoot, Builder, TTI,
+                                       RdxRootInst->getType());
       }
       if (ReducedSubTree->getType() != VL.front()->getType()) {
         assert(ReducedSubTree->getType() != VL.front()->getType() &&
@@ -20050,12 +20077,13 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(TargetTransformInfo *TTI,
-                                   ArrayRef<Value *> ReducedVals,
-                                   bool IsCmpSelMinMax, unsigned ReduxWidth,
-                                   FastMathFlags FMF) {
+  InstructionCost getReductionCost(
+      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
+      bool IsCmpSelMinMax, FastMathFlags FMF,
+      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
+    unsigned ReduxWidth = ReducedVals.size();
     FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
@@ -20114,8 +20142,22 @@ class HorizontalReduction {
             VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
             /*Extract*/ false, TTI::TCK_RecipThroughput);
       } else {
-        VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
-                                                     CostKind);
+        auto [Bitwidth, IsSigned] =
+            BitwidthAndSign.value_or(std::make_pair(0u, false));
+        if (RdxKind == RecurKind::Add && Bitwidth == 1) {
+          // Represent vector_reduce_add(ZExt(<n x i1>)) as
+          // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+          auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
+          IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
+          VectorCost =
+              TTI->getCastInstrCost(Instruction::BitCast, IntTy,
+                                    getWidenedType(ScalarTy, ReduxWidth),
+                                    TTI::CastContextHint::None, CostKind) +
+              TTI->getIntrinsicInstrCost(ICA, CostKind);
+        } else {
+          VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
+                                                       FMF, CostKind);
+        }
       }
     }
     ScalarCost = EvaluateScalarCost([&]() {
@@ -20152,11 +20194,22 @@ class HorizontalReduction {
 
   /// Emit a horizontal reduction of the vectorized value.
   Value *emitReduction(Value *VectorizedValue, IRBuilderBase &Builder,
-                       const TargetTransformInfo *TTI) {
+                       const TargetTransformInfo *TTI, Type *DestTy) {
     assert(VectorizedValue && "Need to have a vectorized tree node");
     assert(RdxKind != RecurKind::FMulAdd &&
            "A call to the llvm.fmuladd intrinsic is not handled yet");
 
+    auto *FTy = cast<FixedVectorType>(VectorizedValue->getType());
+    if (FTy->getScalarType() == Builder.getInt1Ty() &&
+        RdxKind == RecurKind::Add &&
+        DestTy->getScalarType() != FTy->getScalarType()) {
+      // Convert vector_reduce_add(ZExt(<n x i1>)) to
+      // ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)).
+      Value *V = Builder.CreateBitCast(
+          VectorizedValue, Builder.getIntNTy(FTy->getNumElements()));
+      ++NumVectorInstructions;
+      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, V);
+    }
     ++NumVectorInstructions;
     return createSimpleReduction(Builder, VectorizedValue, RdxKind);
   }
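For an 8-lane i1 mask reduced into a wider integer, the two builder calls in the branch above boil down to the following (a standalone illustration; Mask stands in for the vectorized <8 x i1> operand):

    // %bc  = bitcast <8 x i1> %mask to i8
    Value *BC = Builder.CreateBitCast(Mask, Builder.getIntNTy(8));
    // %cnt = call i8 @llvm.ctpop.i8(i8 %bc)
    Value *Cnt = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, BC);

The iN result is then extended or truncated to the destination type by the existing type fix-up in the caller (the hunk at line 19891 above), yielding the ZExtOrTrunc(ctpop(bitcast <n x i1> to iN)) shape named in the comments.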