Skip to content

[VP][RISCV] Introduce vp.splat and RISC-V. #98731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22841,6 +22841,51 @@ Examples:
llvm.experimental.vp.splice(<A,B,C,D>, <E,F,G,H>, -2, 3, 2); ==> <B, C, poison, poison> trailing elements


.. _int_experimental_vp_splat:


'``llvm.experimental.vp.splat``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Syntax:
"""""""
This is an overloaded intrinsic.

::

declare <2 x double> @llvm.experimental.vp.splat.v2f64(<2 x double> %vec, <2 x i1> %mask, i32 %evl)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be something like (double %scalar, <2 x i1> %mask, i32 %evl)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Thank you for the finding.

declare <vscale x 4 x i32> @llvm.experimental.vp.splat.nxv4i32(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %mask, i32 %evl)

Overview:
"""""""""

The '``llvm.experimental.vp.splat.*``' intrinsic is to create a prdicated splat
with specific effective vector length.

Arguments:
""""""""""

The result is a vector and it is a splat of the second scalar operand. The
second argument ``mask`` is a vector mask and has the same number of elements as
the result. The third argument is the explicit vector length of the operation.

Semantics:
""""""""""

This intrinsic splats a vector with ``evl`` elements of a scalar operand.
The lanes in the result vector disabled by ``mask`` are ``poison``. The
elements past ``evl`` are poison.

Examples:
"""""""""

.. code-block:: llvm

%r = call <4 x float> @llvm.vp.splat.v4f32(float %a, <4 x i1> %mask, i32 %evl)
;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
%also.r = select <4 x i1> %mask, <4 x float> splat(float %a), <4 x float> poison


.. _int_experimental_vp_reverse:


Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,13 @@ def int_experimental_vp_reverse:
llvm_i32_ty],
[IntrNoMem]>;

def int_experimental_vp_splat:
DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[LLVMVectorElementType<0>,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
llvm_i32_ty],
[IntrNoMem]>;

def int_vp_is_fpclass:
DefaultAttrsIntrinsic<[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
[ llvm_anyvector_ty,
Expand Down
7 changes: 7 additions & 0 deletions llvm/include/llvm/IR/VPIntrinsics.def
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,13 @@ END_REGISTER_VP(experimental_vp_reverse, EXPERIMENTAL_VP_REVERSE)

///// } Shuffles

// llvm.vp.splat(ptr,val,mask,vlen)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm.experimental.vp.splat(val,mask,vlen) or x or something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

BEGIN_REGISTER_VP_INTRINSIC(experimental_vp_splat, 1, 2)
BEGIN_REGISTER_VP_SDNODE(EXPERIMENTAL_VP_SPLAT, -1, experimental_vp_splat, 1, 2)
VP_PROPERTY_NO_FUNCTIONAL
HELPER_MAP_VPID_TO_VPSD(experimental_vp_splat, EXPERIMENTAL_VP_SPLAT)
END_REGISTER_VP(experimental_vp_splat, EXPERIMENTAL_VP_SPLAT)

#undef BEGIN_REGISTER_VP
#undef BEGIN_REGISTER_VP_INTRINSIC
#undef BEGIN_REGISTER_VP_SDNODE
Expand Down
17 changes: 14 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
break;
case ISD::SPLAT_VECTOR:
case ISD::SCALAR_TO_VECTOR:
case ISD::EXPERIMENTAL_VP_SPLAT:
Res = PromoteIntRes_ScalarOp(N);
break;
case ISD::STEP_VECTOR: Res = PromoteIntRes_STEP_VECTOR(N); break;
Expand Down Expand Up @@ -1916,6 +1917,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
break;
case ISD::SPLAT_VECTOR:
case ISD::SCALAR_TO_VECTOR:
case ISD::EXPERIMENTAL_VP_SPLAT:
Res = PromoteIntOp_ScalarOp(N);
break;
case ISD::VSELECT:
Expand Down Expand Up @@ -2211,10 +2213,14 @@ SDValue DAGTypeLegalizer::PromoteIntOp_INSERT_VECTOR_ELT(SDNode *N,
}

SDValue DAGTypeLegalizer::PromoteIntOp_ScalarOp(SDNode *N) {
SDValue Op = GetPromotedInteger(N->getOperand(0));
if (N->getOpcode() == ISD::EXPERIMENTAL_VP_SPLAT)
return DAG.getNode(ISD::EXPERIMENTAL_VP_SPLAT, SDLoc(N), N->getValueType(0),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use UpdateNodeOperands?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Op, N->getOperand(1), N->getOperand(2));

// Integer SPLAT_VECTOR/SCALAR_TO_VECTOR operands are implicitly truncated,
// so just promote the operand in place.
return SDValue(DAG.UpdateNodeOperands(N,
GetPromotedInteger(N->getOperand(0))), 0);
return SDValue(DAG.UpdateNodeOperands(N, Op), 0);
}

SDValue DAGTypeLegalizer::PromoteIntOp_SELECT(SDNode *N, unsigned OpNo) {
Expand Down Expand Up @@ -5231,6 +5237,7 @@ bool DAGTypeLegalizer::ExpandIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::EXTRACT_ELEMENT: Res = ExpandOp_EXTRACT_ELEMENT(N); break;
case ISD::INSERT_VECTOR_ELT: Res = ExpandOp_INSERT_VECTOR_ELT(N); break;
case ISD::SCALAR_TO_VECTOR: Res = ExpandOp_SCALAR_TO_VECTOR(N); break;
case ISD::EXPERIMENTAL_VP_SPLAT:
case ISD::SPLAT_VECTOR: Res = ExpandIntOp_SPLAT_VECTOR(N); break;
case ISD::SELECT_CC: Res = ExpandIntOp_SELECT_CC(N); break;
case ISD::SETCC: Res = ExpandIntOp_SETCC(N); break;
Expand Down Expand Up @@ -5859,7 +5866,11 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ScalarOp(SDNode *N) {
EVT NOutElemVT = NOutVT.getVectorElementType();

SDValue Op = DAG.getNode(ISD::ANY_EXTEND, dl, NOutElemVT, N->getOperand(0));

if (N->isVPOpcode()) {
SDValue Mask = N->getOperand(1);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we pass N->getOperand(1) and N->getOperand(2) directly to getNode without the temporaries?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

SDValue VL = N->getOperand(2);
return DAG.getNode(N->getOpcode(), dl, NOutVT, Op, Mask, VL);
}
return DAG.getNode(N->getOpcode(), dl, NOutVT, Op);
}

Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SplitVecRes_Gather(MemSDNode *VPGT, SDValue &Lo, SDValue &Hi,
bool SplitSETCC = false);
void SplitVecRes_ScalarOp(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_SPLAT(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_STEP_VECTOR(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VECTOR_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
Expand Down Expand Up @@ -1052,6 +1053,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecOp_MGATHER(SDNode* N, unsigned OpNo);
SDValue WidenVecOp_MSCATTER(SDNode* N, unsigned OpNo);
SDValue WidenVecOp_VP_SCATTER(SDNode* N, unsigned OpNo);
SDValue WidenVecOp_VP_SPLAT(SDNode *N, unsigned OpNo);
SDValue WidenVecOp_SETCC(SDNode* N);
SDValue WidenVecOp_STRICT_FSETCC(SDNode* N);
SDValue WidenVecOp_VSELECT(SDNode *N);
Expand Down
26 changes: 26 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FCOPYSIGN: SplitVecRes_FPOp_MultiType(N, Lo, Hi); break;
case ISD::IS_FPCLASS: SplitVecRes_IS_FPCLASS(N, Lo, Hi); break;
case ISD::INSERT_VECTOR_ELT: SplitVecRes_INSERT_VECTOR_ELT(N, Lo, Hi); break;
case ISD::EXPERIMENTAL_VP_SPLAT: SplitVecRes_VP_SPLAT(N, Lo, Hi); break;
case ISD::SPLAT_VECTOR:
case ISD::SCALAR_TO_VECTOR:
SplitVecRes_ScalarOp(N, Lo, Hi);
Expand Down Expand Up @@ -1992,6 +1993,16 @@ void DAGTypeLegalizer::SplitVecRes_ScalarOp(SDNode *N, SDValue &Lo,
}
}

void DAGTypeLegalizer::SplitVecRes_VP_SPLAT(SDNode *N, SDValue &Lo,
SDValue &Hi) {
SDLoc dl(N);
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(N->getValueType(0));
auto [MaskLo, MaskHi] = SplitMask(N->getOperand(1));
auto [EVLLo, EVLHi] = DAG.SplitEVL(N->getOperand(2), N->getValueType(0), dl);
Lo = DAG.getNode(N->getOpcode(), dl, LoVT, N->getOperand(0), MaskLo, EVLLo);
Hi = DAG.getNode(N->getOpcode(), dl, HiVT, N->getOperand(0), MaskHi, EVLHi);
}

void DAGTypeLegalizer::SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo,
SDValue &Hi) {
assert(ISD::isUNINDEXEDLoad(LD) && "Indexed load during type legalization!");
Expand Down Expand Up @@ -4284,6 +4295,7 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
case ISD::STEP_VECTOR:
case ISD::SPLAT_VECTOR:
case ISD::SCALAR_TO_VECTOR:
case ISD::EXPERIMENTAL_VP_SPLAT:
Res = WidenVecRes_ScalarOp(N);
break;
case ISD::SIGN_EXTEND_INREG: Res = WidenVecRes_InregOp(N); break;
Expand Down Expand Up @@ -5814,6 +5826,9 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_GATHER(VPGatherSDNode *N) {

SDValue DAGTypeLegalizer::WidenVecRes_ScalarOp(SDNode *N) {
EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
if (N->isVPOpcode())
return DAG.getNode(N->getOpcode(), SDLoc(N), WidenVT, N->getOperand(0),
N->getOperand(1), N->getOperand(2));
return DAG.getNode(N->getOpcode(), SDLoc(N), WidenVT, N->getOperand(0));
}

Expand Down Expand Up @@ -6353,6 +6368,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
Res = WidenVecOp_FP_TO_XINT_SAT(N);
break;

case ISD::EXPERIMENTAL_VP_SPLAT:
Res = WidenVecOp_VP_SPLAT(N, OpNo);
break;

case ISD::VECREDUCE_FADD:
case ISD::VECREDUCE_FMUL:
case ISD::VECREDUCE_ADD:
Expand Down Expand Up @@ -6813,6 +6832,13 @@ SDValue DAGTypeLegalizer::WidenVecOp_STORE(SDNode *N) {
report_fatal_error("Unable to widen vector store");
}

SDValue DAGTypeLegalizer::WidenVecOp_VP_SPLAT(SDNode *N, unsigned OpNo) {
assert(OpNo == 1 && "Can widen only mask operand of vp_splat");
return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0),
N->getOperand(0), GetWidenedVector(N->getOperand(1)),
N->getOperand(2));
}

SDValue DAGTypeLegalizer::WidenVecOp_VP_STORE(SDNode *N, unsigned OpNo) {
assert((OpNo == 1 || OpNo == 3) &&
"Can widen only data or mask operand of vp_store");
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/IR/IntrinsicInst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,9 @@ Function *VPIntrinsic::getDeclarationForParams(Module *M, Intrinsic::ID VPID,
VPFunc = Intrinsic::getDeclaration(
M, VPID, {Params[0]->getType(), Params[1]->getType()});
break;
case Intrinsic::experimental_vp_splat:
VPFunc = Intrinsic::getDeclaration(M, VPID, ReturnType);
break;
}
assert(VPFunc && "Could not declare VP intrinsic");
return VPFunc;
Expand Down
30 changes: 28 additions & 2 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::VP_SMAX, ISD::VP_UMIN, ISD::VP_UMAX,
ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
ISD::VP_SADDSAT, ISD::VP_UADDSAT, ISD::VP_SSUBSAT,
ISD::VP_USUBSAT, ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF};
ISD::VP_USUBSAT, ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF,
ISD::EXPERIMENTAL_VP_SPLAT};

static const unsigned FloatingPointVPOps[] = {
ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
Expand All @@ -721,7 +722,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::VP_FMINIMUM, ISD::VP_FMAXIMUM, ISD::VP_LRINT,
ISD::VP_LLRINT, ISD::EXPERIMENTAL_VP_REVERSE,
ISD::EXPERIMENTAL_VP_SPLICE, ISD::VP_REDUCE_FMINIMUM,
ISD::VP_REDUCE_FMAXIMUM};
ISD::VP_REDUCE_FMAXIMUM, ISD::EXPERIMENTAL_VP_SPLAT};

static const unsigned IntegerVecReduceOps[] = {
ISD::VECREDUCE_ADD, ISD::VECREDUCE_AND, ISD::VECREDUCE_OR,
Expand Down Expand Up @@ -7268,6 +7269,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return lowerVPSpliceExperimental(Op, DAG);
case ISD::EXPERIMENTAL_VP_REVERSE:
return lowerVPReverseExperimental(Op, DAG);
case ISD::EXPERIMENTAL_VP_SPLAT:
return lowerVPSplatExperimental(Op, DAG);
case ISD::CLEAR_CACHE: {
assert(getTargetMachine().getTargetTriple().isOSLinux() &&
"llvm.clear_cache only needs custom lower on Linux targets");
Expand Down Expand Up @@ -11630,6 +11633,29 @@ RISCVTargetLowering::lowerVPSpliceExperimental(SDValue Op,
return convertFromScalableVector(VT, Result, DAG, Subtarget);
}

SDValue RISCVTargetLowering::lowerVPSplatExperimental(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
SDValue Val = Op.getOperand(0);
SDValue Mask = Op.getOperand(1);
SDValue VL = Op.getOperand(2);
MVT VT = Op.getSimpleValueType();

MVT ContainerVT = VT;
if (VT.isFixedLengthVector()) {
ContainerVT = getContainerForFixedLengthVector(VT);
MVT MaskVT = getMaskTypeFor(ContainerVT);
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
}

SDValue Result =
lowerScalarSplat(SDValue(), Val, VL, ContainerVT, DL, DAG, Subtarget);

if (!VT.isFixedLengthVector())
return Result;
return convertFromScalableVector(VT, Result, DAG, Subtarget);
}

SDValue
RISCVTargetLowering::lowerVPReverseExperimental(SDValue Op,
SelectionDAG &DAG) const {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,7 @@ class RISCVTargetLowering : public TargetLowering {
SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPExtMaskOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPSetCCMaskOp(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPSplatExperimental(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPSpliceExperimental(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPReverseExperimental(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG) const;
Expand Down
Loading
Loading