Skip to content

[AArch64] Add lowering for @llvm.experimental.vector.compress #101015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits on
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2412,11 +2412,64 @@ void DAGTypeLegalizer::SplitVecRes_VECTOR_COMPRESS(SDNode *N, SDValue &Lo,
SDValue &Hi) {
// This is not "trivial", as there is a dependency between the two subvectors.
// Depending on the number of 1s in the mask, the elements from the Hi vector
// NOTE(review): this file is a rendered diff — the next five lines are the
// *removed* (pre-patch) comment/code; the replacement version follows.
// need to be moved to the Lo vector. So we just perform this as one "big"
// operation and then extract the Lo and Hi vectors from that. This gets rid
// of VECTOR_COMPRESS and all other operands can be legalized later.
SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, SDLoc(N));
// NOTE(review): the added (post-patch) version of the function starts here.
// need to be moved to the Lo vector. Passthru values make this even harder.
// We try to use VECTOR_COMPRESS if the target has custom lowering with
// smaller types and passthru is undef, as it is most likely faster than the
// fully expand path. Otherwise, just do the full expansion as one "big"
// operation and then extract the Lo and Hi vectors from that. This gets
// rid of VECTOR_COMPRESS and all other operands can be legalized later.
SDLoc DL(N);
EVT VecVT = N->getValueType(0);

// Probe successively narrower vector types, starting at the split
// (half-width) type, for one with a legal or custom VECTOR_COMPRESS.
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(VecVT);
bool HasCustomLowering = false;
EVT CheckVT = LoVT;
while (CheckVT.getVectorMinNumElements() > 1) {
// TLI.isOperationLegalOrCustom requires a legal type, but we could have a
// custom lowering for illegal types. So we do the checks separately.
if (TLI.isOperationLegal(ISD::VECTOR_COMPRESS, CheckVT) ||
TLI.isOperationCustom(ISD::VECTOR_COMPRESS, CheckVT)) {
HasCustomLowering = true;
break;
}
CheckVT = CheckVT.getHalfNumVectorElementsVT(*DAG.getContext());
}

// Full scalar expansion when no custom lowering exists at any probed width,
// or when a (non-undef) passthru must be honored.
SDValue Passthru = N->getOperand(2);
if (!HasCustomLowering || !Passthru.isUndef()) {
SDValue Compressed = TLI.expandVECTOR_COMPRESS(N, DAG);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL, LoVT, HiVT);
return;
}

// Try to VECTOR_COMPRESS smaller vectors and combine via a stack store+load.
SDValue LoMask, HiMask;
std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
std::tie(LoMask, HiMask) = SplitMask(N->getOperand(1));

// Compress each half independently; passthru is known undef here.
SDValue UndefPassthru = DAG.getUNDEF(LoVT);
Lo = DAG.getNode(ISD::VECTOR_COMPRESS, DL, LoVT, Lo, LoMask, UndefPassthru);
Hi = DAG.getNode(ISD::VECTOR_COMPRESS, DL, HiVT, Hi, HiMask, UndefPassthru);

SDValue StackPtr = DAG.CreateStackTemporary(
VecVT.getStoreSize(), DAG.getReducedAlign(VecVT, /*UseABI=*/false));
MachineFunction &MF = DAG.getMachineFunction();
MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(
MF, cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex());

// We store LoVec and then insert HiVec starting at offset=|1s| in LoMask.
// VECREDUCE_ADD over the (zero-extended) 0/1 mask lanes is the popcount of
// LoMask, i.e. the number of valid elements the Lo compress produced.
SDValue WideMask =
DAG.getNode(ISD::ZERO_EXTEND, DL, LoMask.getValueType(), LoMask);
SDValue Offset = DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, WideMask);
Offset = TLI.getVectorElementPointer(DAG, StackPtr, VecVT, Offset);

// The Hi store's frame offset is only known at runtime (it depends on the
// mask), so it uses an unknown-stack MachinePointerInfo.
SDValue Chain = DAG.getEntryNode();
Chain = DAG.getStore(Chain, DL, Lo, StackPtr, PtrInfo);
Chain = DAG.getStore(Chain, DL, Hi, Offset,
MachinePointerInfo::getUnknownStack(MF));

// Reload the merged vector and split it into the requested Lo/Hi halves.
SDValue Compressed = DAG.getLoad(VecVT, DL, Chain, StackPtr, PtrInfo);
std::tie(Lo, Hi) = DAG.SplitVector(Compressed, DL);
}

void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
Expand Down Expand Up @@ -5790,7 +5843,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_VECTOR_COMPRESS(SDNode *N) {
TLI.getTypeToTransformTo(*DAG.getContext(), Vec.getValueType());
EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(),
Mask.getValueType().getVectorElementType(),
WideVecVT.getVectorNumElements());
WideVecVT.getVectorElementCount());

SDValue WideVec = ModifyToType(Vec, WideVecVT);
SDValue WideMask = ModifyToType(Mask, WideMaskVT, /*FillWithZeroes=*/true);
Expand Down
116 changes: 116 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1775,6 +1775,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::v2f32, MVT::v4f32, MVT::v2f64})
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);

// We can lower types that have <vscale x {2|4}> elements to compact.
for (auto VT :
{MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv2f32,
MVT::nxv2f64, MVT::nxv4i8, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv4f32})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// If we have SVE, we can use SVE logic for legal (or smaller than legal)
// NEON vectors in the lowest bits of the SVE register.
for (auto VT : {MVT::v2i8, MVT::v2i16, MVT::v2i32, MVT::v2i64, MVT::v2f32,
MVT::v2f64, MVT::v4i8, MVT::v4i16, MVT::v4i32, MVT::v4f32})
setOperationAction(ISD::VECTOR_COMPRESS, VT, Custom);

// Histcnt is SVE2 only
if (Subtarget->hasSVE2())
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
Expand Down Expand Up @@ -6616,6 +6628,104 @@ SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
return DAG.getMergeValues({Ext, Chain}, DL);
}

/// Lower ISD::VECTOR_COMPRESS to the SVE `compact` instruction.
/// Returns an empty SDValue (falling back to generic expansion) when SVE is
/// unavailable, when a fixed-length input is wider than 128 bits, or when the
/// (minimum) element count is not 2 or 4 — the shapes handled here.
/// Fixed-length (NEON) inputs are processed in the low bits of an SVE
/// register and extracted again at the end.
SDValue AArch64TargetLowering::LowerVECTOR_COMPRESS(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
SDValue Vec = Op.getOperand(0);
SDValue Mask = Op.getOperand(1);
SDValue Passthru = Op.getOperand(2);
EVT VecVT = Vec.getValueType();
EVT MaskVT = Mask.getValueType();
EVT ElmtVT = VecVT.getVectorElementType();
const bool IsFixedLength = VecVT.isFixedLengthVector();
const bool HasPassthru = !Passthru.isUndef();
unsigned MinElmts = VecVT.getVectorElementCount().getKnownMinValue();
EVT FixedVecVT = MVT::getVectorVT(ElmtVT.getSimpleVT(), MinElmts);

assert(VecVT.isVector() && "Input to VECTOR_COMPRESS must be vector.");

if (!Subtarget->isSVEAvailable())
return SDValue();

// Fixed vectors wider than a single 128-bit NEON register are not handled.
if (IsFixedLength && VecVT.getSizeInBits().getFixedValue() > 128)
return SDValue();

// Only <vscale x {4|2} x {i32|i64}> supported for compact.
if (MinElmts != 2 && MinElmts != 4)
return SDValue();

// We can use the SVE register containing the NEON vector in its lowest bits.
if (IsFixedLength) {
EVT ScalableVecVT =
MVT::getScalableVectorVT(ElmtVT.getSimpleVT(), MinElmts);
EVT ScalableMaskVT = MVT::getScalableVectorVT(
MaskVT.getVectorElementType().getSimpleVT(), MinElmts);

// Place vector, mask, and passthru in the low elements of scalable
// registers; the upper elements are undef and ignored below.
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
DAG.getUNDEF(ScalableVecVT), Vec,
DAG.getConstant(0, DL, MVT::i64));
Mask = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableMaskVT,
DAG.getUNDEF(ScalableMaskVT), Mask,
DAG.getConstant(0, DL, MVT::i64));
// The NEON mask has integer elements; SVE predicates are i1 vectors.
Mask = DAG.getNode(ISD::TRUNCATE, DL,
ScalableMaskVT.changeVectorElementType(MVT::i1), Mask);
Passthru = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVecVT,
DAG.getUNDEF(ScalableVecVT), Passthru,
DAG.getConstant(0, DL, MVT::i64));

VecVT = Vec.getValueType();
MaskVT = Mask.getValueType();
}

// Get legal type for compact instruction
EVT ContainerVT = getSVEContainerType(VecVT);
EVT CastVT = VecVT.changeVectorElementTypeToInteger();

// Convert to i32 or i64 for smaller types, as these are the only supported
// sizes for compact.
if (ContainerVT != VecVT) {
Vec = DAG.getBitcast(CastVT, Vec);
Vec = DAG.getNode(ISD::ANY_EXTEND, DL, ContainerVT, Vec);
}

SDValue Compressed = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, Vec.getValueType(),
DAG.getConstant(Intrinsic::aarch64_sve_compact, DL, MVT::i64), Mask, Vec);

// compact fills with 0s, so if our passthru is all 0s, do nothing here.
if (HasPassthru && !ISD::isConstantSplatVectorAllZeros(Passthru.getNode())) {
// cntp counts the active mask lanes, i.e. how many elements compact
// produced — the index where passthru elements must start.
SDValue Offset = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
DAG.getConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64), Mask, Mask);

// whilelo(0, Offset) builds a predicate selecting the first |Offset|
// lanes: compacted results there, passthru everywhere else.
SDValue IndexMask = DAG.getNode(
ISD::INTRINSIC_WO_CHAIN, DL, MaskVT,
DAG.getConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64),
DAG.getConstant(0, DL, MVT::i64), Offset);

// NOTE(review): when ContainerVT != VecVT, Compressed has the extended
// ContainerVT while Passthru still has VecVT — confirm the VSELECT
// operand/result types agree for sub-word element types with a passthru.
Compressed =
DAG.getNode(ISD::VSELECT, DL, VecVT, IndexMask, Compressed, Passthru);
}

// Extracting from a legal SVE type before truncating produces better code.
if (IsFixedLength) {
Compressed = DAG.getNode(
ISD::EXTRACT_SUBVECTOR, DL,
FixedVecVT.changeVectorElementType(ContainerVT.getVectorElementType()),
Compressed, DAG.getConstant(0, DL, MVT::i64));
CastVT = FixedVecVT.changeVectorElementTypeToInteger();
VecVT = FixedVecVT;
}

// If we changed the element type before, we need to convert it back.
if (ContainerVT != VecVT) {
Compressed = DAG.getNode(ISD::TRUNCATE, DL, CastVT, Compressed);
Compressed = DAG.getBitcast(VecVT, Compressed);
}

return Compressed;
}

// Generate SUBS and CSEL for integer abs.
SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
MVT VT = Op.getSimpleValueType();
Expand Down Expand Up @@ -6996,6 +7106,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerDYNAMIC_STACKALLOC(Op, DAG);
case ISD::VSCALE:
return LowerVSCALE(Op, DAG);
case ISD::VECTOR_COMPRESS:
return LowerVECTOR_COMPRESS(Op, DAG);
case ISD::ANY_EXTEND:
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
Expand Down Expand Up @@ -26372,6 +26484,10 @@ void AArch64TargetLowering::ReplaceNodeResults(
case ISD::VECREDUCE_UMIN:
Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
return;
case ISD::VECTOR_COMPRESS:
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
case ISD::ADD:
case ISD::FADD:
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,8 @@ class AArch64TargetLowering : public TargetLowering {

SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerVECTOR_COMPRESS(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINTRINSIC_VOID(SDValue Op, SelectionDAG &DAG) const;
Expand Down
Loading