
Commit e0ad56b

[AArch64] Add lowering for @llvm.experimental.vector.compress (#101015)
This is a follow-up to #92289 that adds custom lowering of the new `@llvm.experimental.vector.compress` intrinsic on AArch64 using SVE instructions. Some vector types can be lowered directly to the SVE `compact` instruction.
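For reference, the intrinsic takes the source vector, a mask, and a passthru vector: lanes where the mask is true are packed to the front of the result, and the remaining lanes come from the passthru. Below is a minimal C++ sketch of emitting it through IRBuilder; the `emitCompress` helper is hypothetical, and the exact `Intrinsic::experimental_vector_compress` ID is an assumption based on LLVM's usual name mangling rather than something this commit shows:

```cpp
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"

using namespace llvm;

// Hypothetical helper: pack the lanes of Vec selected by Mask to the front;
// lanes past the packed prefix are taken from Passthru (undefined if the
// passthru is poison). The intrinsic is overloaded on the vector type.
static Value *emitCompress(IRBuilder<> &Builder, Value *Vec, Value *Mask,
                           Value *Passthru) {
  return Builder.CreateIntrinsic(Intrinsic::experimental_vector_compress,
                                 {Vec->getType()}, {Vec, Mask, Passthru});
}
```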
1 parent: 306b9c7

4 files changed (+453 −6 lines)

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+59 −6)
```diff
@@ -2412,11 +2412,64 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
                                                    SDValue &Hi) {
   // This is not "trivial", as there is a dependency between the two subvectors.
   // Depending on the number of 1s in the mask, the elements from the Hi vector
-  // need to be moved to the Lo vector. So we just perform this as one "big"
-  // operation and then extract the Lo and Hi vectors from that. This gets rid
-  // of VECTOR_COMPRESS and all other operands can be legalized later.
-  SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
-  std::tie(Lo, Hi) = DAG.SplitVector(Compressed, SDLoc(N));
+  // need to be moved to the Lo vector. Passthru values make this even harder.
+  // We try to use VECTOR_COMPRESS if the target has custom lowering with
+  // smaller types and passthru is undef, as it is most likely faster than the
+  // fully expand path. Otherwise, just do the full expansion as one "big"
+  // operation and then extract the Lo and Hi vectors from that. This gets
+  // rid of VECTOR_COMPRESS and all other operands can be legalized later.
+  SDLoc DL(N);
+  EVT VecVT = N->getValueType(0);
+
+  auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VecVT);
+  bool HasCustomLowering = false;
+  EVT CheckVT = LoVT;
+  while (CheckVT.getVectorMinNumElements() > 1) {
+    // TLI.isOperationLegalOrCustom requires a legal type, but we could have a
+    // custom lowering for illegal types. So we do the checks separately.
+    if (TLI.isOperationLegal(ISD::VECTOR_COMPRESS, CheckVT) ||
+        TLI.isOperationCustom(ISD::VECTOR_COMPRESS, CheckVT)) {
+      HasCustomLowering = true;
+      break;
+    }
+    CheckVT = CheckVT.getHalfNumVectorElementsVT(*DAG.getContext());
+  }
+
+  SDValue Passthru = N->getOperand(2);
+  if (!HasCustomLowering || !Passthru.isUndef()) {
+    SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
+    std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL, LoVT, HiVT);
+    return;
+  }
+
+  // Try to VECTOR_COMPRESS smaller vectors and combine via a stack store+load.
+  SDValue LoMask, HiMask;
+  std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
+  std::tie(LoMask, HiMask) = SplitMask(N->getOperand(1));
+
+  SDValue UndefPassthru = DAG.getUNDEF(LoVT);
+  Lo = DAG.getNode(ISD::VECTOR_COMPRESS, DL, LoVT, Lo, LoMask, UndefPassthru);
+  Hi = DAG.getNode(ISD::VECTOR_COMPRESS, DL, HiVT, Hi, HiMask, UndefPassthru);
+
+  SDValue StackPtr = DAG.CreateStackTemporary(
+      VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
+      MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());
+
+  // We store LoVec and then insert HiVec starting at offset=|1s| in LoMask.
+  SDValue WideMask =
+      DAG.getNode(ISD::ZERO_EXTEND, DL, LoMask.getValueType(), LoMask);
+  SDValue Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, WideMask);
+  Offset = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, Offset);
+
+  SDValue Chain = DAG.getEntryNode();
+  Chain = DAG.getStore(Chain, DL, Lo, StackPtr, PtrInfo);
+  Chain = DAG.getStore(Chain, DL, Hi, Offset,
+                       MachinePointerInfo::getUnknownStack(MF));
+
+  SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
+  std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL);
 }
 
 void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
@@ -5790,7 +5843,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_VECTOR_COMPRESS(SDNode *N) {
       TLI.getTypeToTransformTo(*DAG.getContext(), Vec.getValueType());
   EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
                                     Mask.getValueType().getVectorElementType(),
-                                    WideVecVT.getVectorNumElements());
+                                    WideVecVT.getVectorElementCount());
 
   SDValue WideVec = ModifyToType(Vec, WideVecVT);
   SDValue WideMask = ModifyToType(Mask, WideMaskVT, /*FillWithZeroes=*/true);
```
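The splitting logic above is easier to see in scalar form. Here is a minimal sketch of the stack-based merge, using plain arrays in place of DAG nodes; `compressSplit` and its parameter names are illustrative, not from the patch:

```cpp
#include <array>
#include <bitset>
#include <cstddef>

// Scalar model of the stack-based merge in SplitVecRes_VECTOR_COMPRESS:
// compress each half with an undef passthru, store Lo at offset 0 and Hi at
// offset = popcount(LoMask), then reload the combined vector.
template <size_t N>
std::array<int, 2 * N> compressSplit(const std::array<int, N> &Lo,
                                     const std::array<int, N> &Hi,
                                     const std::bitset<N> &LoMask,
                                     const std::bitset<N> &HiMask) {
  std::array<int, 2 * N> Stack{}; // stands in for the stack temporary
  size_t Pos = 0;
  for (size_t I = 0; I < N; ++I) // VECTOR_COMPRESS on the Lo half
    if (LoMask[I])
      Stack[Pos++] = Lo[I];
  // At this point Pos == popcount(LoMask); the patch computes this offset
  // as VECREDUCE_ADD over the (widened) LoMask.
  for (size_t I = 0; I < N; ++I) // VECTOR_COMPRESS on the Hi half
    if (HiMask[I])
      Stack[Pos++] = Hi[I];
  return Stack; // reloaded as one wide vector, then re-split into Lo/Hi
}
```

With an undef passthru, lanes past the packed prefix of each half are don't-care values, which is why the two halves can be compressed independently before the merge.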

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+116 −0)
```diff
@@ -1775,6 +1775,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                      MVT::v2f32, MVT::v4f32, MVT::v2f64})
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
 
+    // We can lower types that have <vscale x {2|4}> elements to compact.
+    for (auto VT :
+         {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
+          MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
+      setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+
+    // If we have SVE, we can use SVE logic for legal (or smaller than legal)
+    // NEON vectors in the lowest bits of the SVE register.
+    for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
+                    MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
+      setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
+
     // Histcnt is SVE2 only
     if (Subtarget->hasSVE2()) {
       setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
@@ -6619,6 +6631,104 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
   return DAG.getMergeValues({Ext, Chain}, DL);
 }
 
+SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Vec = Op.getOperand(0);
+  SDValue Mask = Op.getOperand(1);
+  SDValue Passthru = Op.getOperand(2);
+  EVT VecVT = Vec.getValueType();
+  EVT MaskVT = Mask.getValueType();
+  EVT ElmtVT = VecVT.getVectorElementType();
+  const bool IsFixedLength = VecVT.isFixedLengthVector();
+  const bool HasPassthru = !Passthru.isUndef();
+  unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
+  EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+
+  assert(VecVT.isVector() && "Input to VECTOR_COMPRESS must be vector.");
+
+  if (!Subtarget->isSVEAvailable())
+    return SDValue();
+
+  if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
+    return SDValue();
+
+  // Only <vscale x {4|2} x {i32|i64}> supported for compact.
+  if (MinElmts != 2 && MinElmts != 4)
+    return SDValue();
+
+  // We can use the SVE register containing the NEON vector in its lowest bits.
+  if (IsFixedLength) {
+    EVT ScalableVecVT =
+        MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
+    EVT ScalableMaskVT = MVT::getScalableVectorVT(
+        MaskVT.getVectorElementType().getSimpleVT(), MinElmts);
+
+    Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                      DAG.getUNDEF(ScalableVecVT), Vec,
+                      DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
+                       DAG.getUNDEF(ScalableMaskVT), Mask,
+                       DAG.getConstant(0, DL, MVT::i64));
+    Mask = DAG.getNode(ISD::TRUNCATE, DL,
+                       ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
+    Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
+                           DAG.getUNDEF(ScalableVecVT), Passthru,
+                           DAG.getConstant(0, DL, MVT::i64));
+
+    VecVT = Vec.getValueType();
+    MaskVT = Mask.getValueType();
+  }
+
+  // Get legal type for compact instruction
+  EVT ContainerVT = getSVEContainerType(VecVT);
+  EVT CastVT = VecVT.changeVectorElementTypeToInteger();
+
+  // Convert to i32 or i64 for smaller types, as these are the only supported
+  // sizes for compact.
+  if (ContainerVT != VecVT) {
+    Vec = DAG.getBitcast(CastVT, Vec);
+    Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
+  }
+
+  SDValue Compressed = DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
+      DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
+
+  // compact fills with 0s, so if our passthru is all 0s, do nothing here.
+  if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
+    SDValue Offset = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
+        DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, Mask);
+
+    SDValue IndexMask = DAG.getNode(
+        ISD::INTRINSIC_WO_CHAIN, DL, MaskVT,
+        DAG.getConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64),
+        DAG.getConstant(0, DL, MVT::i64), Offset);
+
+    Compressed =
+        DAG.getNode(ISD::VSELECT, DL, VecVT, IndexMask, Compressed, Passthru);
+  }
+
+  // Extracting from a legal SVE type before truncating produces better code.
+  if (IsFixedLength) {
+    Compressed = DAG.getNode(
+        ISD::EXTRACT_SUBVECTOR, DL,
+        FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()),
+        Compressed, DAG.getConstant(0, DL, MVT::i64));
+    CastVT = FixedVecVT.changeVectorElementTypeToInteger();
+    VecVT = FixedVecVT;
+  }
+
+  // If we changed the element type before, we need to convert it back.
+  if (ContainerVT != VecVT) {
+    Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed);
+    Compressed = DAG.getBitcast(VecVT, Compressed);
+  }
+
+  return Compressed;
+}
+
 // Generate SUBS and CSEL for integer abs.
 SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
   MVT VT = Op.getSimpleValueType();
@@ -6999,6 +7109,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerDYNAMIC_STACKALLOC(Op, DAG);
   case ISD::VSCALE:
     return LowerVSCALE(Op, DAG);
+  case ISD::VECTOR_COMPRESS:
+    return LowerVECTOR_COMPRESS(Op, DAG);
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
@@ -26563,6 +26675,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
   case ISD::VECREDUCE_UMIN:
     Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
     return;
+  case ISD::VECTOR_COMPRESS:
+    if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
+      Results.push_back(Res);
+    return;
   case ISD::ADD:
   case ISD::FADD:
     ReplaceAddWithADDP(N, Results, DAG, Subtarget);
```
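Conceptually, the lowering compacts the active lanes to the front (zero-filling the tail, as SVE `compact` does), then, for a non-zero passthru, counts the packed lanes with `cntp` and builds a prefix predicate with `whilelo` for the final select. A scalar model of that sequence, with illustrative names not taken from the patch:

```cpp
#include <cstdint>
#include <vector>

// Scalar model of the compact-based lowering: pack active lanes to the
// front, then blend the passthru into the lanes past the packed prefix.
std::vector<int64_t> compactWithPassthru(const std::vector<int64_t> &Vec,
                                         const std::vector<bool> &Mask,
                                         const std::vector<int64_t> &Passthru) {
  std::vector<int64_t> Result(Vec.size(), 0); // compact zero-fills the tail
  size_t Pos = 0;
  for (size_t I = 0; I < Vec.size(); ++I) // COMPACT
    if (Mask[I])
      Result[Pos++] = Vec[I];
  // cntp(Mask) == Pos; whilelo(0, Pos) marks the packed prefix, so the
  // VSELECT keeps compacted lanes below Pos and passthru lanes from Pos on.
  for (size_t I = Pos; I < Vec.size(); ++I)
    Result[I] = Passthru[I];
  return Result;
}
```

Since `compact` itself zero-fills the inactive tail, the patch skips the `cntp`/`whilelo`/select step when the passthru is undef or a known all-zeros splat.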

llvm/lib/Target/AArch64/AArch64ISelLowering.h (+2 −0)
```diff
@@ -1075,6 +1075,8 @@ class AArch64TargetLowering : public TargetLowering {
 
   SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerVECTOR_COMPRESS(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINTRINSIC_VOID(SDValue Op, SelectionDAG &DAG) const;
```
