Skip to content

Commit 1d9b322

Browse files
committed
[VPlan] Implement VPWidenSelectRecipe::computeCost.
Implement VPlan-based cost computation for VPWidenSelectRecipe.
1 parent b5bcdb5 commit 1d9b322

File tree

4 files changed

+90
-6
lines changed

4 files changed

+90
-6
lines changed

llvm/lib/Transforms/Vectorize/VPlan.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -1701,3 +1701,11 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
17011701
Plan->print(O);
17021702
}
17031703
#endif
1704+
1705+
/// Query the TTI operand information for \p V.
///
/// Only live-in VPValues are inspected, via the IR value they wrap. For any
/// other VPValue a default-constructed OperandValueInfo is returned.
TargetTransformInfo::OperandValueInfo
VPCostContext::getOperandInfo(VPValue *V) const {
  if (V->isLiveIn())
    return TTI::getOperandInfo(V->getLiveInIRValue());

  // Not a live-in: fall back to the default operand info.
  return {};
}

llvm/lib/Transforms/Vectorize/VPlan.h

+8
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "llvm/Analysis/DomTreeUpdater.h"
3939
#include "llvm/Analysis/IVDescriptors.h"
4040
#include "llvm/Analysis/LoopInfo.h"
41+
#include "llvm/Analysis/TargetTransformInfo.h"
4142
#include "llvm/Analysis/VectorUtils.h"
4243
#include "llvm/IR/DebugLoc.h"
4344
#include "llvm/IR/FMF.h"
@@ -738,6 +739,9 @@ struct VPCostContext {
738739
/// Return true if the cost for \p UI shouldn't be computed, e.g. because it
739740
/// has already been pre-computed.
740741
bool skipCostComputation(Instruction *UI, bool IsVector) const;
742+
743+
/// Returns the OperandInfo for \p V if it is a live-in; otherwise returns a
/// default-constructed OperandValueInfo.
744+
TargetTransformInfo::OperandValueInfo getOperandInfo(VPValue *V) const;
741745
};
742746

743747
/// VPRecipeBase is a base class modeling a sequence of one or more output IR
@@ -1844,6 +1848,10 @@ struct VPWidenSelectRecipe : public VPSingleDefRecipe {
18441848
/// Produce a widened version of the select instruction.
18451849
void execute(VPTransformState &State) override;
18461850

1851+
/// Return the cost of this VPWidenSelectRecipe.
1852+
InstructionCost computeCost(ElementCount VF,
1853+
VPCostContext &Ctx) const override;
1854+
18471855
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
18481856
/// Print the recipe.
18491857
void print(raw_ostream &O, const Twine &Indent,

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

+32-6
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ template <unsigned BitWidth = 0> struct specific_intval {
7575
if (!CI)
7676
return false;
7777

78-
assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) &&
79-
"Trying the match constant with unexpected bitwidth.");
78+
if (BitWidth != 0 && CI->getBitWidth() != BitWidth)
79+
return false;
8080
return APInt::isSameValue(CI->getValue(), Val);
8181
}
8282
};
@@ -87,6 +87,8 @@ inline specific_intval<0> m_SpecificInt(uint64_t V) {
8787

8888
/// Build a matcher for the constant false: specific_intval<1> rejects
/// constants whose bit width is not 1 and compares the value against 0.
inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }
8989

90+
/// Build a matcher for the constant true: specific_intval<1> rejects
/// constants whose bit width is not 1 and compares the value against 1.
inline specific_intval<1> m_True() { return specific_intval<1>(APInt(64, 1)); }
91+
9092
/// Matching combinators
9193
template <typename LTy, typename RTy> struct match_combine_or {
9294
LTy L;
@@ -122,7 +124,8 @@ struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
122124
auto *DefR = dyn_cast<RecipeTy>(R);
123125
// Check for recipes that do not have opcodes.
124126
if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value ||
125-
std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value)
127+
std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value ||
128+
std::is_same<RecipeTy, VPWidenSelectRecipe>::value)
126129
return DefR;
127130
else
128131
return DefR && DefR->getOpcode() == Opcode;
@@ -322,10 +325,34 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
322325
return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
323326
}
324327

328+
template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode>
329+
using AllTernaryRecipe_match =
330+
Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, Opcode, false,
331+
VPReplicateRecipe, VPInstruction, VPWidenSelectRecipe>;
332+
333+
template <typename Op0_t, typename Op1_t, typename Op2_t>
334+
inline AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>
335+
m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
336+
return AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>(
337+
{Op0, Op1, Op2});
338+
}
339+
325340
template <typename Op0_t, typename Op1_t>
326-
inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>
341+
inline match_combine_or<
342+
BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>,
343+
AllTernaryRecipe_match<Op0_t, Op1_t, specific_intval<1>,
344+
Instruction::Select>>
327345
m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
328-
return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1);
346+
return m_CombineOr(
347+
m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1),
348+
m_Select(Op0, Op1, m_False()));
349+
}
350+
351+
template <typename Op0_t, typename Op1_t>
352+
inline AllTernaryRecipe_match<Op0_t, specific_intval<1>, Op1_t,
353+
Instruction::Select>
354+
m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) {
355+
return m_Select(Op0, m_True(), Op1);
329356
}
330357

331358
using VPCanonicalIVPHI_match =
@@ -344,7 +371,6 @@ inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,
344371
const Op1_t &Op1) {
345372
return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1);
346373
}
347-
348374
} // namespace VPlanPatternMatch
349375
} // namespace llvm
350376

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

+42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "VPlan.h"
1515
#include "VPlanAnalysis.h"
16+
#include "VPlanPatternMatch.h"
1617
#include "VPlanUtils.h"
1718
#include "llvm/ADT/STLExtras.h"
1819
#include "llvm/ADT/SmallVector.h"
@@ -23,6 +24,7 @@
2324
#include "llvm/IR/Instruction.h"
2425
#include "llvm/IR/Instructions.h"
2526
#include "llvm/IR/Intrinsics.h"
27+
#include "llvm/IR/PatternMatch.h"
2628
#include "llvm/IR/Type.h"
2729
#include "llvm/IR/Value.h"
2830
#include "llvm/IR/VectorBuilder.h"
@@ -1200,6 +1202,46 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) {
12001202
State.addMetadata(Sel, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
12011203
}
12021204

1205+
/// Compute the cost of this widened select for vectorization factor \p VF,
/// using the TTI and type-inference helpers available through \p Ctx.
///
/// A select with a 1-bit result whose condition is not defined outside the
/// loop regions, and which matches the logical-and / logical-or select idiom,
/// is costed as a vector And / Or via getArithmeticInstrCost. Every other
/// select is costed via getCmpSelInstrCost.
InstructionCost VPWidenSelectRecipe::computeCost(ElementCount VF,
                                                 VPCostContext &Ctx) const {
  auto *Sel = cast<SelectInst>(getUnderlyingValue());
  VPValue *CondOp = getOperand(0);
  const bool CondIsScalar = CondOp->isDefinedOutsideLoopRegions();
  Type *ResultScalarTy = Ctx.Types.inferScalarType(this);
  Type *ResultVectorTy = ToVectorTy(ResultScalarTy, VF);
  const TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;

  VPValue *LHS, *RHS;
  using namespace llvm::VPlanPatternMatch;
  if (!CondIsScalar && ResultScalarTy->getScalarSizeInBits() == 1 &&
      (match(this, m_LogicalAnd(m_VPValue(LHS), m_VPValue(RHS))) ||
       match(this, m_LogicalOr(m_VPValue(LHS), m_VPValue(RHS))))) {
    // select x, y, false --> x & y
    // select x, true, y --> x | y
    const auto [LHSKind, LHSProps] = Ctx.getOperandInfo(LHS);
    const auto [RHSKind, RHSProps] = Ctx.getOperandInfo(RHS);

    // The IR operands are only passed on when every VPValue operand has an
    // underlying IR value.
    SmallVector<const Value *, 2> IROperands;
    const bool AllOperandsHaveIRValue = all_of(
        operands(), [](VPValue *Op) { return Op->getUnderlyingValue(); });
    if (AllOperandsHaveIRValue)
      IROperands.append(Sel->op_begin(), Sel->op_end());

    // Re-run the logical-or matcher to decide which opcode to cost.
    unsigned Opcode = Instruction::And;
    if (match(this, m_LogicalOr(m_VPValue(LHS), m_VPValue(RHS))))
      Opcode = Instruction::Or;
    return Ctx.TTI.getArithmeticInstrCost(Opcode, ResultVectorTy, CostKind,
                                          {LHSKind, LHSProps},
                                          {RHSKind, RHSProps}, IROperands, Sel);
  }

  // Generic select: the condition type is widened unless the condition is
  // defined outside the loop regions.
  Type *CondTy = Ctx.Types.inferScalarType(CondOp);
  if (!CondIsScalar)
    CondTy = VectorType::get(CondTy, VF);

  // Forward the predicate of the underlying compare, if the IR condition is
  // one.
  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (auto *Cmp = dyn_cast<CmpInst>(Sel->getCondition()))
    Pred = Cmp->getPredicate();

  return Ctx.TTI.getCmpSelInstrCost(Instruction::Select, ResultVectorTy, CondTy,
                                    Pred, CostKind,
                                    {TTI::OK_AnyValue, TTI::OP_None},
                                    {TTI::OK_AnyValue, TTI::OP_None}, Sel);
}
1244+
12031245
VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
12041246
const FastMathFlags &FMF) {
12051247
AllowReassoc = FMF.allowReassoc();

0 commit comments

Comments
 (0)