@@ -1775,6 +1775,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1775
1775
MVT::v2f32, MVT::v4f32, MVT::v2f64})
1776
1776
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
1777
1777
1778
+ // We can lower types that have <vscale x {2|4}> elements to compact.
1779
+ for (auto VT :
1780
+ {MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
1781
+ MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
1782
+ setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
1783
+
1784
+ // If we have SVE, we can use SVE logic for legal (or smaller than legal)
1785
+ // NEON vectors in the lowest bits of the SVE register.
1786
+ for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
1787
+ MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
1788
+ setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);
1789
+
1778
1790
// Histcnt is SVE2 only
1779
1791
if (Subtarget->hasSVE2()) {
1780
1792
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
@@ -6619,6 +6631,104 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
6619
6631
return DAG.getMergeValues({Ext, Chain}, DL);
6620
6632
}
6621
6633
6634
+ SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
6635
+ SelectionDAG &DAG) const {
6636
+ SDLoc DL(Op);
6637
+ SDValue Vec = Op.getOperand(0);
6638
+ SDValue Mask = Op.getOperand(1);
6639
+ SDValue Passthru = Op.getOperand(2);
6640
+ EVT VecVT = Vec.getValueType();
6641
+ EVT MaskVT = Mask.getValueType();
6642
+ EVT ElmtVT = VecVT.getVectorElementType();
6643
+ const bool IsFixedLength = VecVT.isFixedLengthVector();
6644
+ const bool HasPassthru = !Passthru.isUndef();
6645
+ unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
6646
+ EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);
6647
+
6648
+ assert(VecVT.isVector() && "Input to VECTOR_COMPRESS must be vector.");
6649
+
6650
+ if (!Subtarget->isSVEAvailable())
6651
+ return SDValue();
6652
+
6653
+ if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
6654
+ return SDValue();
6655
+
6656
+ // Only <vscale x {4|2} x {i32|i64}> supported for compact.
6657
+ if (MinElmts != 2 && MinElmts != 4)
6658
+ return SDValue();
6659
+
6660
+ // We can use the SVE register containing the NEON vector in its lowest bits.
6661
+ if (IsFixedLength) {
6662
+ EVT ScalableVecVT =
6663
+ MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
6664
+ EVT ScalableMaskVT = MVT::getScalableVectorVT(
6665
+ MaskVT.getVectorElementType().getSimpleVT(), MinElmts);
6666
+
6667
+ Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
6668
+ DAG.getUNDEF(ScalableVecVT), Vec,
6669
+ DAG.getConstant(0, DL, MVT::i64));
6670
+ Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
6671
+ DAG.getUNDEF(ScalableMaskVT), Mask,
6672
+ DAG.getConstant(0, DL, MVT::i64));
6673
+ Mask = DAG.getNode(ISD::TRUNCATE, DL,
6674
+ ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
6675
+ Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
6676
+ DAG.getUNDEF(ScalableVecVT), Passthru,
6677
+ DAG.getConstant(0, DL, MVT::i64));
6678
+
6679
+ VecVT = Vec.getValueType();
6680
+ MaskVT = Mask.getValueType();
6681
+ }
6682
+
6683
+ // Get legal type for compact instruction
6684
+ EVT ContainerVT = getSVEContainerType(VecVT);
6685
+ EVT CastVT = VecVT.changeVectorElementTypeToInteger();
6686
+
6687
+ // Convert to i32 or i64 for smaller types, as these are the only supported
6688
+ // sizes for compact.
6689
+ if (ContainerVT != VecVT) {
6690
+ Vec = DAG.getBitcast(CastVT, Vec);
6691
+ Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
6692
+ }
6693
+
6694
+ SDValue Compressed = DAG.getNode(
6695
+ ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
6696
+ DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);
6697
+
6698
+ // compact fills with 0s, so if our passthru is all 0s, do nothing here.
6699
+ if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
6700
+ SDValue Offset = DAG.getNode(
6701
+ ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
6702
+ DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, Mask);
6703
+
6704
+ SDValue IndexMask = DAG.getNode(
6705
+ ISD::INTRINSIC_WO_CHAIN, DL, MaskVT,
6706
+ DAG.getConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64),
6707
+ DAG.getConstant(0, DL, MVT::i64), Offset);
6708
+
6709
+ Compressed =
6710
+ DAG.getNode(ISD::VSELECT, DL, VecVT, IndexMask, Compressed, Passthru);
6711
+ }
6712
+
6713
+ // Extracting from a legal SVE type before truncating produces better code.
6714
+ if (IsFixedLength) {
6715
+ Compressed = DAG.getNode(
6716
+ ISD::EXTRACT_SUBVECTOR, DL,
6717
+ FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()),
6718
+ Compressed, DAG.getConstant(0, DL, MVT::i64));
6719
+ CastVT = FixedVecVT.changeVectorElementTypeToInteger();
6720
+ VecVT = FixedVecVT;
6721
+ }
6722
+
6723
+ // If we changed the element type before, we need to convert it back.
6724
+ if (ContainerVT != VecVT) {
6725
+ Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed);
6726
+ Compressed = DAG.getBitcast(VecVT, Compressed);
6727
+ }
6728
+
6729
+ return Compressed;
6730
+ }
6731
+
6622
6732
// Generate SUBS and CSEL for integer abs.
6623
6733
SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
6624
6734
MVT VT = Op.getSimpleValueType();
@@ -6999,6 +7109,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
6999
7109
return LowerDYNAMIC_STACKALLOC(Op, DAG);
7000
7110
case ISD::VSCALE:
7001
7111
return LowerVSCALE(Op, DAG);
7112
+ case ISD::VECTOR_COMPRESS:
7113
+ return LowerVECTOR_COMPRESS(Op, DAG);
7002
7114
case ISD::ANY_EXTEND:
7003
7115
case ISD::SIGN_EXTEND:
7004
7116
case ISD::ZERO_EXTEND:
@@ -26563,6 +26675,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
26563
26675
case ISD::VECREDUCE_UMIN:
26564
26676
Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
26565
26677
return;
26678
+ case ISD::VECTOR_COMPRESS:
26679
+ if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
26680
+ Results.push_back(Res);
26681
+ return;
26566
26682
case ISD::ADD:
26567
26683
case ISD::FADD:
26568
26684
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
0 commit comments