Skip to content

Commit 746cea3

Browse files
authored
[VP][RISCV] Introduce vp.splat and RISC-V. (#98731)
This patch introduces a vp intrinsic for splat. It's helpful for IR-level passes to create a splat with specific vector length.
1 parent b6c4ad7 commit 746cea3

File tree

12 files changed

+1051
-4
lines changed

12 files changed

+1051
-4
lines changed

llvm/docs/LangRef.rst

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22841,6 +22841,53 @@ Examples:
2284122841
llvm.experimental.vp.splice(<A,B,C,D>, <E,F,G,H>, -2, 3, 2); ==> <B, C, poison, poison> trailing elements
2284222842

2284322843

22844+
.. _int_experimental_vp_splat:
22845+
22846+
22847+
'``llvm.experimental.vp.splat``' Intrinsic
22848+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
22849+
22850+
Syntax:
22851+
"""""""
22852+
This is an overloaded intrinsic.
22853+
22854+
::
22855+
22856+
declare <2 x double> @llvm.experimental.vp.splat.v2f64(double %scalar, <2 x i1> %mask, i32 %evl)
22857+
declare <vscale x 4 x i32> @llvm.experimental.vp.splat.nxv4i32(i32 %scalar, <vscale x 4 x i1> %mask, i32 %evl)
22858+
22859+
Overview:
22860+
"""""""""
22861+
22862+
The '``llvm.experimental.vp.splat.*``' intrinsic is to create a predicated splat
22863+
with specific effective vector length.
22864+
22865+
Arguments:
22866+
""""""""""
22867+
22868+
The result is a vector and it is a splat of the first scalar argument. The
22869+
second argument ``mask`` is a vector mask and has the same number of elements as
22870+
the result. The third argument is the explicit vector length of the operation.
22871+
22872+
Semantics:
22873+
""""""""""
22874+
22875+
This intrinsic splats a vector with ``evl`` elements of a scalar argument.
22876+
The lanes in the result vector disabled by ``mask`` are ``poison``. The
22877+
elements past ``evl`` are poison.
22878+
22879+
Examples:
22880+
"""""""""
22881+
22882+
.. code-block:: llvm
22883+
22884+
%r = call <4 x float> @llvm.vp.splat.v4f32(float %a, <4 x i1> %mask, i32 %evl)
22885+
;; For all lanes below %evl, %r is lane-wise equivalent to %also.r
22886+
%e = insertelement <4 x float> poison, float %a, i32 0
22887+
%s = shufflevector <4 x float> %e, <4 x float> poison, <4 x i32> zeroinitializer
22888+
%also.r = select <4 x i1> %mask, <4 x float> %s, <4 x float> poison
22889+
22890+
2284422891
.. _int_experimental_vp_reverse:
2284522892

2284622893

llvm/include/llvm/IR/Intrinsics.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2342,6 +2342,13 @@ def int_experimental_vp_reverse:
23422342
llvm_i32_ty],
23432343
[IntrNoMem]>;
23442344

2345+
def int_experimental_vp_splat:
2346+
DefaultAttrsIntrinsic<[llvm_anyvector_ty],
2347+
[LLVMVectorElementType<0>,
2348+
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
2349+
llvm_i32_ty],
2350+
[IntrNoMem]>;
2351+
23452352
def int_vp_is_fpclass:
23462353
DefaultAttrsIntrinsic<[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
23472354
[ llvm_anyvector_ty,

llvm/include/llvm/IR/VPIntrinsics.def

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,13 @@ END_REGISTER_VP(experimental_vp_reverse, EXPERIMENTAL_VP_REVERSE)
777777

778778
///// } Shuffles
779779

780+
// llvm.vp.splat(val,mask,vlen)
781+
BEGIN_REGISTER_VP_INTRINSIC(experimental_vp_splat, 1, 2)
782+
BEGIN_REGISTER_VP_SDNODE(EXPERIMENTAL_VP_SPLAT, -1, experimental_vp_splat, 1, 2)
783+
VP_PROPERTY_NO_FUNCTIONAL
784+
HELPER_MAP_VPID_TO_VPSD(experimental_vp_splat, EXPERIMENTAL_VP_SPLAT)
785+
END_REGISTER_VP(experimental_vp_splat, EXPERIMENTAL_VP_SPLAT)
786+
780787
#undef BEGIN_REGISTER_VP
781788
#undef BEGIN_REGISTER_VP_INTRINSIC
782789
#undef BEGIN_REGISTER_VP_SDNODE

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
137137
break;
138138
case ISD::SPLAT_VECTOR:
139139
case ISD::SCALAR_TO_VECTOR:
140+
case ISD::EXPERIMENTAL_VP_SPLAT:
140141
Res = PromoteIntRes_ScalarOp(N);
141142
break;
142143
case ISD::STEP_VECTOR: Res = PromoteIntRes_STEP_VECTOR(N); break;
@@ -1920,6 +1921,7 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
19201921
break;
19211922
case ISD::SPLAT_VECTOR:
19221923
case ISD::SCALAR_TO_VECTOR:
1924+
case ISD::EXPERIMENTAL_VP_SPLAT:
19231925
Res = PromoteIntOp_ScalarOp(N);
19241926
break;
19251927
case ISD::VSELECT:
@@ -2215,10 +2217,14 @@ SDValue DAGTypeLegalizer::PromoteIntOp_INSERT_VECTOR_ELT(SDNode *N,
22152217
}
22162218

22172219
SDValue DAGTypeLegalizer::PromoteIntOp_ScalarOp(SDNode *N) {
2220+
SDValue Op = GetPromotedInteger(N->getOperand(0));
2221+
if (N->getOpcode() == ISD::EXPERIMENTAL_VP_SPLAT)
2222+
return SDValue(
2223+
DAG.UpdateNodeOperands(N, Op, N->getOperand(1), N->getOperand(2)), 0);
2224+
22182225
// Integer SPLAT_VECTOR/SCALAR_TO_VECTOR operands are implicitly truncated,
22192226
// so just promote the operand in place.
2220-
return SDValue(DAG.UpdateNodeOperands(N,
2221-
GetPromotedInteger(N->getOperand(0))), 0);
2227+
return SDValue(DAG.UpdateNodeOperands(N, Op), 0);
22222228
}
22232229

22242230
SDValue DAGTypeLegalizer::PromoteIntOp_SELECT(SDNode *N, unsigned OpNo) {
@@ -5235,6 +5241,7 @@ bool DAGTypeLegalizer::ExpandIntegerOperand(SDNode *N, unsigned OpNo) {
52355241
case ISD::EXTRACT_ELEMENT: Res = ExpandOp_EXTRACT_ELEMENT(N); break;
52365242
case ISD::INSERT_VECTOR_ELT: Res = ExpandOp_INSERT_VECTOR_ELT(N); break;
52375243
case ISD::SCALAR_TO_VECTOR: Res = ExpandOp_SCALAR_TO_VECTOR(N); break;
5244+
case ISD::EXPERIMENTAL_VP_SPLAT:
52385245
case ISD::SPLAT_VECTOR: Res = ExpandIntOp_SPLAT_VECTOR(N); break;
52395246
case ISD::SELECT_CC: Res = ExpandIntOp_SELECT_CC(N); break;
52405247
case ISD::SETCC: Res = ExpandIntOp_SETCC(N); break;
@@ -5863,6 +5870,9 @@ SDValue DAGTypeLegalizer::PromoteIntRes_ScalarOp(SDNode *N) {
58635870
EVT NOutElemVT = NOutVT.getVectorElementType();
58645871

58655872
SDValue Op = DAG.getNode(ISD::ANY_EXTEND, dl, NOutElemVT, N->getOperand(0));
5873+
if (N->isVPOpcode())
5874+
return DAG.getNode(N->getOpcode(), dl, NOutVT, Op, N->getOperand(1),
5875+
N->getOperand(2));
58665876

58675877
return DAG.getNode(N->getOpcode(), dl, NOutVT, Op);
58685878
}

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
928928
void SplitVecRes_Gather(MemSDNode *VPGT, SDValue &Lo, SDValue &Hi,
929929
bool SplitSETCC = false);
930930
void SplitVecRes_ScalarOp(SDNode *N, SDValue &Lo, SDValue &Hi);
931+
void SplitVecRes_VP_SPLAT(SDNode *N, SDValue &Lo, SDValue &Hi);
931932
void SplitVecRes_STEP_VECTOR(SDNode *N, SDValue &Lo, SDValue &Hi);
932933
void SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi);
933934
void SplitVecRes_VECTOR_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
@@ -1065,6 +1066,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
10651066
SDValue WidenVecOp_MGATHER(SDNode* N, unsigned OpNo);
10661067
SDValue WidenVecOp_MSCATTER(SDNode* N, unsigned OpNo);
10671068
SDValue WidenVecOp_VP_SCATTER(SDNode* N, unsigned OpNo);
1069+
SDValue WidenVecOp_VP_SPLAT(SDNode *N, unsigned OpNo);
10681070
SDValue WidenVecOp_SETCC(SDNode* N);
10691071
SDValue WidenVecOp_STRICT_FSETCC(SDNode* N);
10701072
SDValue WidenVecOp_VSELECT(SDNode *N);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,6 +1085,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
10851085
case ISD::FCOPYSIGN: SplitVecRes_FPOp_MultiType(N, Lo, Hi); break;
10861086
case ISD::IS_FPCLASS: SplitVecRes_IS_FPCLASS(N, Lo, Hi); break;
10871087
case ISD::INSERT_VECTOR_ELT: SplitVecRes_INSERT_VECTOR_ELT(N, Lo, Hi); break;
1088+
case ISD::EXPERIMENTAL_VP_SPLAT: SplitVecRes_VP_SPLAT(N, Lo, Hi); break;
10881089
case ISD::SPLAT_VECTOR:
10891090
case ISD::SCALAR_TO_VECTOR:
10901091
SplitVecRes_ScalarOp(N, Lo, Hi);
@@ -2007,6 +2008,16 @@ void DAGTypeLegalizer::SplitVecRes_ScalarOp(SDNode *N, SDValue &Lo,
20072008
}
20082009
}
20092010

2011+
void DAGTypeLegalizer::SplitVecRes_VP_SPLAT(SDNode *N, SDValue &Lo,
2012+
SDValue &Hi) {
2013+
SDLoc dl(N);
2014+
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(N->getValueType(0));
2015+
auto [MaskLo, MaskHi] = SplitMask(N->getOperand(1));
2016+
auto [EVLLo, EVLHi] = DAG.SplitEVL(N->getOperand(2), N->getValueType(0), dl);
2017+
Lo = DAG.getNode(N->getOpcode(), dl, LoVT, N->getOperand(0), MaskLo, EVLLo);
2018+
Hi = DAG.getNode(N->getOpcode(), dl, HiVT, N->getOperand(0), MaskHi, EVLHi);
2019+
}
2020+
20102021
void DAGTypeLegalizer::SplitVecRes_LOAD(LoadSDNode *LD, SDValue &Lo,
20112022
SDValue &Hi) {
20122023
assert(ISD::isUNINDEXEDLoad(LD) && "Indexed load during type legalization!");
@@ -4299,6 +4310,7 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
42994310
case ISD::STEP_VECTOR:
43004311
case ISD::SPLAT_VECTOR:
43014312
case ISD::SCALAR_TO_VECTOR:
4313+
case ISD::EXPERIMENTAL_VP_SPLAT:
43024314
Res = WidenVecRes_ScalarOp(N);
43034315
break;
43044316
case ISD::SIGN_EXTEND_INREG: Res = WidenVecRes_InregOp(N); break;
@@ -5835,6 +5847,9 @@ SDValue DAGTypeLegalizer::WidenVecRes_VP_GATHER(VPGatherSDNode *N) {
58355847

58365848
SDValue DAGTypeLegalizer::WidenVecRes_ScalarOp(SDNode *N) {
58375849
EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
5850+
if (N->isVPOpcode())
5851+
return DAG.getNode(N->getOpcode(), SDLoc(N), WidenVT, N->getOperand(0),
5852+
N->getOperand(1), N->getOperand(2));
58385853
return DAG.getNode(N->getOpcode(), SDLoc(N), WidenVT, N->getOperand(0));
58395854
}
58405855

@@ -6374,6 +6389,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
63746389
Res = WidenVecOp_FP_TO_XINT_SAT(N);
63756390
break;
63766391

6392+
case ISD::EXPERIMENTAL_VP_SPLAT:
6393+
Res = WidenVecOp_VP_SPLAT(N, OpNo);
6394+
break;
6395+
63776396
case ISD::VECREDUCE_FADD:
63786397
case ISD::VECREDUCE_FMUL:
63796398
case ISD::VECREDUCE_ADD:
@@ -6834,6 +6853,13 @@ SDValue DAGTypeLegalizer::WidenVecOp_STORE(SDNode *N) {
68346853
report_fatal_error("Unable to widen vector store");
68356854
}
68366855

6856+
SDValue DAGTypeLegalizer::WidenVecOp_VP_SPLAT(SDNode *N, unsigned OpNo) {
6857+
assert(OpNo == 1 && "Can widen only mask operand of vp_splat");
6858+
return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0),
6859+
N->getOperand(0), GetWidenedVector(N->getOperand(1)),
6860+
N->getOperand(2));
6861+
}
6862+
68376863
SDValue DAGTypeLegalizer::WidenVecOp_VP_STORE(SDNode *N, unsigned OpNo) {
68386864
assert((OpNo == 1 || OpNo == 3) &&
68396865
"Can widen only data or mask operand of vp_store");

llvm/lib/IR/IntrinsicInst.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,6 +699,9 @@ Function *VPIntrinsic::getDeclarationForParams(Module *M, Intrinsic::ID VPID,
699699
VPFunc = Intrinsic::getDeclaration(
700700
M, VPID, {Params[0]->getType(), Params[1]->getType()});
701701
break;
702+
case Intrinsic::experimental_vp_splat:
703+
VPFunc = Intrinsic::getDeclaration(M, VPID, ReturnType);
704+
break;
702705
}
703706
assert(VPFunc && "Could not declare VP intrinsic");
704707
return VPFunc;

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
699699
ISD::VP_SMAX, ISD::VP_UMIN, ISD::VP_UMAX,
700700
ISD::VP_ABS, ISD::EXPERIMENTAL_VP_REVERSE, ISD::EXPERIMENTAL_VP_SPLICE,
701701
ISD::VP_SADDSAT, ISD::VP_UADDSAT, ISD::VP_SSUBSAT,
702-
ISD::VP_USUBSAT, ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF};
702+
ISD::VP_USUBSAT, ISD::VP_CTTZ_ELTS, ISD::VP_CTTZ_ELTS_ZERO_UNDEF,
703+
ISD::EXPERIMENTAL_VP_SPLAT};
703704

704705
static const unsigned FloatingPointVPOps[] = {
705706
ISD::VP_FADD, ISD::VP_FSUB, ISD::VP_FMUL,
@@ -715,7 +716,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
715716
ISD::VP_FMINIMUM, ISD::VP_FMAXIMUM, ISD::VP_LRINT,
716717
ISD::VP_LLRINT, ISD::EXPERIMENTAL_VP_REVERSE,
717718
ISD::EXPERIMENTAL_VP_SPLICE, ISD::VP_REDUCE_FMINIMUM,
718-
ISD::VP_REDUCE_FMAXIMUM};
719+
ISD::VP_REDUCE_FMAXIMUM, ISD::EXPERIMENTAL_VP_SPLAT};
719720

720721
static const unsigned IntegerVecReduceOps[] = {
721722
ISD::VECREDUCE_ADD, ISD::VECREDUCE_AND, ISD::VECREDUCE_OR,
@@ -7252,6 +7253,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
72527253
return lowerVPSpliceExperimental(Op, DAG);
72537254
case ISD::EXPERIMENTAL_VP_REVERSE:
72547255
return lowerVPReverseExperimental(Op, DAG);
7256+
case ISD::EXPERIMENTAL_VP_SPLAT:
7257+
return lowerVPSplatExperimental(Op, DAG);
72557258
case ISD::CLEAR_CACHE: {
72567259
assert(getTargetMachine().getTargetTriple().isOSLinux() &&
72577260
"llvm.clear_cache only needs custom lower on Linux targets");
@@ -11614,6 +11617,29 @@ RISCVTargetLowering::lowerVPSpliceExperimental(SDValue Op,
1161411617
return convertFromScalableVector(VT, Result, DAG, Subtarget);
1161511618
}
1161611619

11620+
SDValue RISCVTargetLowering::lowerVPSplatExperimental(SDValue Op,
11621+
SelectionDAG &DAG) const {
11622+
SDLoc DL(Op);
11623+
SDValue Val = Op.getOperand(0);
11624+
SDValue Mask = Op.getOperand(1);
11625+
SDValue VL = Op.getOperand(2);
11626+
MVT VT = Op.getSimpleValueType();
11627+
11628+
MVT ContainerVT = VT;
11629+
if (VT.isFixedLengthVector()) {
11630+
ContainerVT = getContainerForFixedLengthVector(VT);
11631+
MVT MaskVT = getMaskTypeFor(ContainerVT);
11632+
Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget);
11633+
}
11634+
11635+
SDValue Result =
11636+
lowerScalarSplat(SDValue(), Val, VL, ContainerVT, DL, DAG, Subtarget);
11637+
11638+
if (!VT.isFixedLengthVector())
11639+
return Result;
11640+
return convertFromScalableVector(VT, Result, DAG, Subtarget);
11641+
}
11642+
1161711643
SDValue
1161811644
RISCVTargetLowering::lowerVPReverseExperimental(SDValue Op,
1161911645
SelectionDAG &DAG) const {

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,7 @@ class RISCVTargetLowering : public TargetLowering {
973973
SDValue lowerLogicVPOp(SDValue Op, SelectionDAG &DAG) const;
974974
SDValue lowerVPExtMaskOp(SDValue Op, SelectionDAG &DAG) const;
975975
SDValue lowerVPSetCCMaskOp(SDValue Op, SelectionDAG &DAG) const;
976+
SDValue lowerVPSplatExperimental(SDValue Op, SelectionDAG &DAG) const;
976977
SDValue lowerVPSpliceExperimental(SDValue Op, SelectionDAG &DAG) const;
977978
SDValue lowerVPReverseExperimental(SDValue Op, SelectionDAG &DAG) const;
978979
SDValue lowerVPFPIntConvOp(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)