Skip to content

Commit 99ef8b1

Browse files
AinsleySnow authored and imkiva committed
[LLVM][XTHeadVector] Implement intrinsics for vmerge and vmv.v.x/i. (llvm#72)
* [LLVM][XTHeadVector] Define intrinsic functions for vmerge and vmv.v.{x,i}. * [LLVM][XTHeadVector] Define pseudos and pats for vmerge. * [LLVM][XTHeadVector] Add test cases for vmerge. * [LLVM][XTHeadVector] Define policy-free pseudo nodes for vmv.v.{v/x/i}. Define pats for vmv.v.v. * [LLVM][XTHeadVector] Define ISD node for vmv.v.x and map it to pseudo nodes. * [LLVM][XTHeadVector] Select vmv.v.x using logic shared with its 1.0 version. * [LLVM][XTHeadVector] Don't add policy for xthead pseudo nodes. * [LLVM][XTHeadVector] Add test cases for vmv.v.x. * [LLVM][XTHeadVector] Update test cases since now pseudo vmv do not accept policy fields any more. * [NFC][XTHeadVector] Update readme.
1 parent ea1dc1b commit 99ef8b1

12 files changed

+3613
-37
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ Any feature not listed below but present in the specification should be consider
5151
- (Done) `12.11. Vector Widening Integer Multiply Instructions`
5252
- (Done) `12.12. Vector Single-Width Integer Multiply-Add Instructions`
5353
- (Done) `12.13. Vector Widening Integer Multiply-Add Instructions`
54+
- (Done) `12.14. Vector Integer Merge and Move Instructions`
5455
- (WIP) Clang intrinsics related to the `XTHeadVector` extension:
5556
- (WIP) `6. Configuration-Setting and Utility`
5657
- (Done) `6.1. Set vl and vtype`

llvm/include/llvm/IR/IntrinsicsRISCVXTHeadV.td

+12-2
Original file line numberDiff line numberDiff line change
@@ -770,10 +770,10 @@ let TargetPrefix = "riscv" in {
770770
defm th_vwmacc : XVTernaryWide;
771771
defm th_vwmaccus : XVTernaryWide;
772772
defm th_vwmaccsu : XVTernaryWide;
773-
} // TargetPrefix = "riscv"
774773

775-
let TargetPrefix = "riscv" in {
776774
// 12.14. Vector Integer Merge and Move Instructions
775+
defm th_vmerge : RISCVBinaryWithV0;
776+
777777
// Output: (vector)
778778
// Input: (passthru, vector_in, vl)
779779
def int_riscv_th_vmv_v_v : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
@@ -783,4 +783,14 @@ let TargetPrefix = "riscv" in {
783783
[IntrNoMem]>, RISCVVIntrinsic {
784784
let VLOperand = 2;
785785
}
786+
// Output: (vector)
787+
// Input: (passthru, scalar, vl)
788+
def int_riscv_th_vmv_v_x : DefaultAttrsIntrinsic<[llvm_anyint_ty],
789+
[LLVMMatchType<0>,
790+
LLVMVectorElementType<0>,
791+
llvm_anyint_ty],
792+
[IntrNoMem]>, RISCVVIntrinsic {
793+
let ScalarOperand = 1;
794+
let VLOperand = 2;
795+
}
786796
} // TargetPrefix = "riscv"

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

+21-11
Original file line numberDiff line numberDiff line change
@@ -3554,7 +3554,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
35543554

35553555
static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
35563556
SDValue Lo, SDValue Hi, SDValue VL,
3557-
SelectionDAG &DAG) {
3557+
SelectionDAG &DAG, bool HasVendorXTHeadV) {
35583558
if (!Passthru)
35593559
Passthru = DAG.getUNDEF(VT);
35603560
if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
@@ -3563,7 +3563,9 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
35633563
// If Hi constant is all the same sign bit as Lo, lower this as a custom
35643564
// node in order to try and match RVV vector/scalar instructions.
35653565
if ((LoC >> 31) == HiC)
3566-
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL);
3566+
return DAG.getNode(HasVendorXTHeadV ?
3567+
RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
3568+
DL, VT, Passthru, Lo, VL);
35673569

35683570
// If vl is equal to XLEN_MAX and Hi constant is equal to Lo, we could use
35693571
// vmv.v.x whose EEW = 32 to lower it.
@@ -3572,8 +3574,8 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
35723574
// TODO: if vl <= min(VLMAX), we can also do this. But we could not
35733575
// access the subtarget here now.
35743576
auto InterVec = DAG.getNode(
3575-
RISCVISD::VMV_V_X_VL, DL, InterVT, DAG.getUNDEF(InterVT), Lo,
3576-
DAG.getRegister(RISCV::X0, MVT::i32));
3577+
HasVendorXTHeadV ? RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
3578+
DL, InterVT, DAG.getUNDEF(InterVT), Lo, DAG.getRegister(RISCV::X0, MVT::i32));
35773579
return DAG.getNode(ISD::BITCAST, DL, VT, InterVec);
35783580
}
35793581
}
@@ -3588,11 +3590,11 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
35883590
// of the halves.
35893591
static SDValue splatSplitI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
35903592
SDValue Scalar, SDValue VL,
3591-
SelectionDAG &DAG) {
3593+
SelectionDAG &DAG, bool HasVendorXTHeadV) {
35923594
assert(Scalar.getValueType() == MVT::i64 && "Unexpected VT!");
35933595
SDValue Lo, Hi;
35943596
std::tie(Lo, Hi) = DAG.SplitScalar(Scalar, DL, MVT::i32, MVT::i32);
3595-
return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG);
3597+
return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG, HasVendorXTHeadV);
35963598
}
35973599

35983600
// This function lowers a splat of a scalar operand Splat with the vector
@@ -3628,7 +3630,9 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
36283630
if (isOneConstant(VL) &&
36293631
(!Const || isNullConstant(Scalar) || !isInt<5>(Const->getSExtValue())))
36303632
return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, Scalar, VL);
3631-
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL);
3633+
return DAG.getNode(
3634+
Subtarget.hasVendorXTHeadV() ? RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
3635+
DL, VT, Passthru, Scalar, VL);
36323636
}
36333637

36343638
assert(XLenVT == MVT::i32 && Scalar.getValueType() == MVT::i64 &&
@@ -3639,7 +3643,8 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
36393643
DAG.getConstant(0, DL, XLenVT), VL);
36403644

36413645
// Otherwise use the more complicated splatting algorithm.
3642-
return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG);
3646+
return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL,
3647+
DAG, Subtarget.hasVendorXTHeadV());
36433648
}
36443649

36453650
static MVT getLMUL1VT(MVT VT) {
@@ -6637,7 +6642,8 @@ SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
66376642
auto VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second;
66386643

66396644
SDValue Res =
6640-
splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL, DAG);
6645+
splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL,
6646+
DAG, Subtarget.hasVendorXTHeadV());
66416647
return convertFromScalableVector(VecVT, Res, DAG, Subtarget);
66426648
}
66436649

@@ -7369,7 +7375,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
73697375
// We need to convert the scalar to a splat vector.
73707376
SDValue VL = getVLOperand(Op);
73717377
assert(VL.getValueType() == XLenVT);
7372-
ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL, DAG);
7378+
ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL,
7379+
DAG, Subtarget.hasVendorXTHeadV());
73737380
return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
73747381
}
73757382

@@ -7483,6 +7490,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
74837490
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(),
74847491
Op.getOperand(1), DAG.getConstant(0, DL, XLenVT));
74857492
case Intrinsic::riscv_vmv_v_x:
7493+
case Intrinsic::riscv_th_vmv_v_x:
74867494
return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2),
74877495
Op.getOperand(3), Op.getSimpleValueType(), DL, DAG,
74887496
Subtarget);
@@ -7519,7 +7527,8 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
75197527
SDValue Vec = Op.getOperand(1);
75207528
SDValue VL = getVLOperand(Op);
75217529

7522-
SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL, DAG);
7530+
SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL,
7531+
DAG, Subtarget.hasVendorXTHeadV());
75237532
if (Op.getOperand(1).isUndef())
75247533
return SplattedVal;
75257534
SDValue SplattedIdx =
@@ -16429,6 +16438,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
1642916438
NODE_NAME_CASE(TH_SDD)
1643016439
NODE_NAME_CASE(VMV_V_V_VL)
1643116440
NODE_NAME_CASE(VMV_V_X_VL)
16441+
NODE_NAME_CASE(TH_VMV_V_X_VL)
1643216442
NODE_NAME_CASE(VFMV_V_F_VL)
1643316443
NODE_NAME_CASE(VMV_X_S)
1643416444
NODE_NAME_CASE(VMV_S_X_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

+1
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ enum NodeType : unsigned {
151151
// for the VL value to be used for the operation. The first operand is
152152
// passthru operand.
153153
VMV_V_X_VL,
154+
TH_VMV_V_X_VL,
154155
// VFMV_V_F_VL matches the semantics of vfmv.v.f but includes an extra operand
155156
// for the VL value to be used for the operation. The first operand is
156157
// passthru operand.

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,8 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
488488
const MCInstrDesc &Desc = DefMBBI->getDesc();
489489
MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
490490
MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
491-
MIB.addImm(0); // tu, mu
491+
if (!XTHeadV)
492+
MIB.addImm(0); // tu, mu
492493
MIB.addReg(RISCV::VL, RegState::Implicit);
493494
MIB.addReg(RISCV::VTYPE, RegState::Implicit);
494495
}
@@ -522,7 +523,8 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
522523
const MCInstrDesc &Desc = DefMBBI->getDesc();
523524
MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
524525
MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
525-
MIB.addImm(0); // tu, mu
526+
if (!XTHeadV)
527+
MIB.addImm(0); // tu, mu
526528
MIB.addReg(RISCV::VL, RegState::Implicit);
527529
MIB.addReg(RISCV::VTYPE, RegState::Implicit);
528530
}

llvm/lib/Target/RISCV/RISCVInstrInfoXTHeadVPseudos.td

+68-14
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,19 @@ class XVPseudoBinaryMaskNoPolicy<VReg RetClass,
16011601
let HasSEWOp = 1;
16021602
}
16031603

1604+
class XVPseudoUnaryNoMask<DAGOperand RetClass, DAGOperand OpClass,
1605+
string Constraint = ""> :
1606+
Pseudo<(outs RetClass:$rd),
1607+
(ins RetClass:$merge, OpClass:$rs2, AVL:$vl, ixlenimm:$sew), []>,
1608+
RISCVVPseudo {
1609+
let mayLoad = 0;
1610+
let mayStore = 0;
1611+
let hasSideEffects = 0;
1612+
let Constraints = !interleave([Constraint, "$rd = $merge"], ",");
1613+
let HasVLOp = 1;
1614+
let HasSEWOp = 1;
1615+
}
1616+
16041617
multiclass XVPseudoBinary<VReg RetClass,
16051618
VReg Op1Class,
16061619
DAGOperand Op2Class,
@@ -2907,6 +2920,30 @@ let Predicates = [HasVendorXTHeadV] in {
29072920
//===----------------------------------------------------------------------===//
29082921
// 12.14. Vector Integer Merge and Move Instructions
29092922
//===----------------------------------------------------------------------===//
2923+
multiclass XVPseudoVMRG_VM_XM_IM {
2924+
foreach m = MxListXTHeadV in {
2925+
defvar mx = m.MX;
2926+
defvar WriteVIMergeV_MX = !cast<SchedWrite>("WriteVIMergeV_" # mx);
2927+
defvar WriteVIMergeX_MX = !cast<SchedWrite>("WriteVIMergeX_" # mx);
2928+
defvar WriteVIMergeI_MX = !cast<SchedWrite>("WriteVIMergeI_" # mx);
2929+
defvar ReadVIMergeV_MX = !cast<SchedRead>("ReadVIMergeV_" # mx);
2930+
defvar ReadVIMergeX_MX = !cast<SchedRead>("ReadVIMergeX_" # mx);
2931+
2932+
def "_VVM" # "_" # m.MX:
2933+
VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
2934+
m.vrclass, m.vrclass, m, 1, "">,
2935+
Sched<[WriteVIMergeV_MX, ReadVIMergeV_MX, ReadVIMergeV_MX, ReadVMask]>;
2936+
def "_VXM" # "_" # m.MX:
2937+
VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
2938+
m.vrclass, GPR, m, 1, "">,
2939+
Sched<[WriteVIMergeX_MX, ReadVIMergeV_MX, ReadVIMergeX_MX, ReadVMask]>;
2940+
def "_VIM" # "_" # m.MX:
2941+
VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
2942+
m.vrclass, simm5, m, 1, "">,
2943+
Sched<[WriteVIMergeI_MX, ReadVIMergeV_MX, ReadVMask]>;
2944+
}
2945+
}
2946+
29102947
multiclass XVPseudoUnaryVMV_V_X_I {
29112948
foreach m = MxListXTHeadV in {
29122949
let VLMul = m.value in {
@@ -2918,34 +2955,49 @@ multiclass XVPseudoUnaryVMV_V_X_I {
29182955
defvar ReadVIMovX_MX = !cast<SchedRead>("ReadVIMovX_" # mx);
29192956

29202957
let VLMul = m.value in {
2921-
def "_V_" # mx : VPseudoUnaryNoMask<m.vrclass, m.vrclass>,
2958+
def "_V_" # mx : XVPseudoUnaryNoMask<m.vrclass, m.vrclass>,
29222959
Sched<[WriteVIMovV_MX, ReadVIMovV_MX]>;
2923-
def "_X_" # mx : VPseudoUnaryNoMask<m.vrclass, GPR>,
2960+
def "_X_" # mx : XVPseudoUnaryNoMask<m.vrclass, GPR>,
29242961
Sched<[WriteVIMovX_MX, ReadVIMovX_MX]>;
2925-
def "_I_" # mx : VPseudoUnaryNoMask<m.vrclass, simm5>,
2962+
def "_I_" # mx : XVPseudoUnaryNoMask<m.vrclass, simm5>,
29262963
Sched<[WriteVIMovI_MX]>;
29272964
}
29282965
}
29292966
}
29302967
}
29312968

29322969
let Predicates = [HasVendorXTHeadV] in {
2970+
defm PseudoTH_VMERGE : XVPseudoVMRG_VM_XM_IM;
29332971
defm PseudoTH_VMV_V : XVPseudoUnaryVMV_V_X_I;
29342972
} // Predicates = [HasVendorXTHeadV]
29352973

2936-
// Patterns for `int_riscv_vmv_v_v` -> `PseudoTH_VMV_V_V_<LMUL>`
2937-
foreach vti = AllXVectors in {
2938-
let Predicates = GetXVTypePredicates<vti>.Predicates in {
2939-
// vmv.v.v
2940-
def : Pat<(vti.Vector (int_riscv_th_vmv_v_v (vti.Vector vti.RegClass:$passthru),
2941-
(vti.Vector vti.RegClass:$rs1),
2942-
VLOpFrag)),
2943-
(!cast<Instruction>("PseudoTH_VMV_V_V_"#vti.LMul.MX)
2944-
$passthru, $rs1, GPR:$vl, vti.Log2SEW, TU_MU)>;
2974+
let Predicates = [HasVendorXTHeadV] in {
2975+
defm : XVPatBinaryV_VM_XM_IM<"int_riscv_th_vmerge", "PseudoTH_VMERGE">;
2976+
// Define patterns for vmerge intrinsics with float-point arguments.
2977+
foreach vti = AllFloatXVectors in {
2978+
let Predicates = GetXVTypePredicates<vti>.Predicates in {
2979+
defm : VPatBinaryCarryInTAIL<"int_riscv_th_vmerge", "PseudoTH_VMERGE", "VVM",
2980+
vti.Vector,
2981+
vti.Vector, vti.Vector, vti.Mask,
2982+
vti.Log2SEW, vti.LMul, vti.RegClass,
2983+
vti.RegClass, vti.RegClass>;
2984+
}
2985+
}
29452986

2946-
// TODO: vmv.v.x, vmv.v.i
2987+
// Patterns for `int_riscv_vmv_v_v` -> `PseudoTH_VMV_V_V_<LMUL>`
2988+
foreach vti = AllXVectors in {
2989+
let Predicates = GetXVTypePredicates<vti>.Predicates in {
2990+
// vmv.v.v
2991+
def : Pat<(vti.Vector (int_riscv_th_vmv_v_v (vti.Vector vti.RegClass:$passthru),
2992+
(vti.Vector vti.RegClass:$rs1),
2993+
VLOpFrag)),
2994+
(!cast<Instruction>("PseudoTH_VMV_V_V_"#vti.LMul.MX)
2995+
$passthru, $rs1, GPR:$vl, vti.Log2SEW)>;
2996+
// Patterns for vmv.v.x and vmv.v.i are defined
2997+
// in RISCVInstrInfoXTHeadVVLPatterns.td
2998+
}
29472999
}
2948-
}
3000+
} // Predicates = [HasVendorXTHeadV]
29493001

29503002
//===----------------------------------------------------------------------===//
29513003
// 12.14. Vector Integer Merge and Move Instructions
@@ -2967,3 +3019,5 @@ let Predicates = [HasVendorXTHeadV] in {
29673019
def PseudoTH_VMV8R_V : XVPseudoWholeMove<TH_VMV_V_V, V_M8, VRM8>;
29683020
}
29693021
} // Predicates = [HasVendorXTHeadV]
3022+
3023+
include "RISCVInstrInfoXTHeadVVLPatterns.td"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===-- RISCVInstrInfoXTHeadVVLPatterns.td - RVV VL patterns -----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------------===//
8+
///
9+
/// This file contains the required infrastructure and VL patterns to support
10+
/// code generation for the standard 'V' (Vector) extension, version 0.7.1
11+
///
12+
/// This file is included from RISCVInstrInfoXTHeadVPseudos.td
13+
//===---------------------------------------------------------------------------===//
14+
15+
// TH_VMV_V_X_VL node: splat of a GPR scalar into a vector, with explicit
// passthru and VL operands, mirroring RISCVISD::VMV_V_X_VL for the
// XTHeadVector extension.
//   Result 0: the splatted vector (integer vector, same type as operand 1).
//   Operand 1: passthru vector (tied to the result type via SDTCisSameAs).
//   Operand 2: scalar to splat (XLenVT).
//   Operand 3: VL (XLenVT).
def riscv_th_vmv_v_x_vl : SDNode<"RISCVISD::TH_VMV_V_X_VL",
                                 SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisInt<0>,
                                                      SDTCisSameAs<0, 1>,
                                                      SDTCisVT<2, XLenVT>,
                                                      SDTCisVT<3, XLenVT>]>>;

// Select TH_VMV_V_X_VL: a GPR scalar goes to PseudoTH_VMV_V_X_<LMUL>; a
// sign-extended 5-bit immediate (matched through the per-SEW "sewNsimm5"
// ComplexPattern) goes to PseudoTH_VMV_V_I_<LMUL>.
//
// NOTE(review): the original wrapped this in a redundant outer
// `foreach vti = AllXVectors` whose iterator was immediately shadowed by
// the inner `foreach vti = AllIntegerXVectors`, re-emitting the same
// anonymous patterns once per element of AllXVectors (duplicate DAGISel
// patterns). Splats are integer-only here, so the single loop over
// AllIntegerXVectors is both sufficient and correct.
foreach vti = AllIntegerXVectors in {
  def : Pat<(vti.Vector (riscv_th_vmv_v_x_vl vti.RegClass:$passthru, GPR:$rs2, VLOpFrag)),
            (!cast<Instruction>("PseudoTH_VMV_V_X_"#vti.LMul.MX)
              vti.RegClass:$passthru, GPR:$rs2, GPR:$vl, vti.Log2SEW)>;
  defvar ImmPat = !cast<ComplexPattern>("sew"#vti.SEW#"simm5");
  def : Pat<(vti.Vector (riscv_th_vmv_v_x_vl vti.RegClass:$passthru, (ImmPat simm5:$imm5),
                                             VLOpFrag)),
            (!cast<Instruction>("PseudoTH_VMV_V_I_"#vti.LMul.MX)
              vti.RegClass:$passthru, simm5:$imm5, GPR:$vl, vti.Log2SEW)>;
}

0 commit comments

Comments
 (0)