
Commit 881becf

[LLVM][XTHeadVector] Implement intrinsics for vmerge and vmv.v.x/i. (llvm#72)
* [LLVM][XTHeadVector] Define intrinsic functions for vmerge and vmv.v.{x,i}.
* [LLVM][XTHeadVector] Define pseudos and pats for vmerge.
* [LLVM][XTHeadVector] Add test cases for vmerge.
* [LLVM][XTHeadVector] Define policy-free pseudo nodes for vmv.v.{v/x/i}. Define pats for vmv.v.v.
* [LLVM][XTHeadVector] Define ISD node for vmv.v.x and map it to pseudo nodes.
* [LLVM][XTHeadVector] Select vmv.v.x using logic shared with its 1.0 version.
* [LLVM][XTHeadVector] Don't add policy for xthead pseudo nodes.
* [LLVM][XTHeadVector] Add test cases for vmv.v.x.
* [LLVM][XTHeadVector] Update test cases since the pseudo vmv instructions no longer accept policy fields.
* [NFC][XTHeadVector] Update readme.
1 parent 236e3c3 commit 881becf

12 files changed: +3613 −37 lines

README.md

+1 −0

@@ -51,6 +51,7 @@ Any feature not listed below but present in the specification should be consider
 - (Done) `12.11. Vector Widening Integer Multiply Instructions`
 - (Done) `12.12. Vector Single-Width Integer Multiply-Add Instructions`
 - (Done) `12.13. Vector Widening Integer Multiply-Add Instructions`
+- (Done) `12.14. Vector Integer Merge and Move Instructions`
 - (WIP) Clang intrinsics related to the `XTHeadVector` extension:
 - (WIP) `6. Configuration-Setting and Utility`
 - (Done) `6.1. Set vl and vtype`

llvm/include/llvm/IR/IntrinsicsRISCVXTHeadV.td

+12 −2

@@ -770,10 +770,10 @@ let TargetPrefix = "riscv" in {
   defm th_vwmacc : XVTernaryWide;
   defm th_vwmaccus : XVTernaryWide;
   defm th_vwmaccsu : XVTernaryWide;
-} // TargetPrefix = "riscv"
 
-let TargetPrefix = "riscv" in {
   // 12.14. Vector Integer Merge and Move Instructions
+  defm th_vmerge : RISCVBinaryWithV0;
+
   // Output: (vector)
   // Input: (passthru, vector_in, vl)
   def int_riscv_th_vmv_v_v : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
@@ -783,4 +783,14 @@ let TargetPrefix = "riscv" in {
                                                    [IntrNoMem]>, RISCVVIntrinsic {
     let VLOperand = 2;
   }
+  // Output: (vector)
+  // Input: (passthru, scalar, vl)
+  def int_riscv_th_vmv_v_x : DefaultAttrsIntrinsic<[llvm_anyint_ty],
+                                                   [LLVMMatchType<0>,
+                                                    LLVMVectorElementType<0>,
+                                                    llvm_anyint_ty],
+                                                   [IntrNoMem]>, RISCVVIntrinsic {
+    let ScalarOperand = 1;
+    let VLOperand = 2;
+  }
 } // TargetPrefix = "riscv"
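
For orientation, a call to the new intrinsic from LLVM IR would look roughly like the sketch below. The concrete vector type (nxv4i32) and the mangled intrinsic name are illustrative assumptions rather than lines taken from the commit's test files; the operand order follows the (passthru, scalar, vl) signature defined above, with the destination tied to the passthru operand per the $rd = $merge constraint introduced later in this commit.

  ; Sketch only: splat the GPR value %x across the first %vl elements of the result.
  declare <vscale x 4 x i32> @llvm.riscv.th.vmv.v.x.nxv4i32.i64(
      <vscale x 4 x i32>, i32, i64)

  define <vscale x 4 x i32> @splat_gpr(<vscale x 4 x i32> %passthru, i32 %x, i64 %vl) {
    %v = call <vscale x 4 x i32> @llvm.riscv.th.vmv.v.x.nxv4i32.i64(
        <vscale x 4 x i32> %passthru, i32 %x, i64 %vl)
    ret <vscale x 4 x i32> %v
  }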

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

+21 −11

@@ -3535,7 +3535,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
 
 static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
                                    SDValue Lo, SDValue Hi, SDValue VL,
-                                   SelectionDAG &DAG) {
+                                   SelectionDAG &DAG, bool HasVendorXTHeadV) {
   if (!Passthru)
     Passthru = DAG.getUNDEF(VT);
   if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
@@ -3544,7 +3544,9 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
     // If Hi constant is all the same sign bit as Lo, lower this as a custom
     // node in order to try and match RVV vector/scalar instructions.
     if ((LoC >> 31) == HiC)
-      return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL);
+      return DAG.getNode(HasVendorXTHeadV ?
+                             RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
+                         DL, VT, Passthru, Lo, VL);
 
     // If vl is equal to XLEN_MAX and Hi constant is equal to Lo, we could use
     // vmv.v.x whose EEW = 32 to lower it.
@@ -3553,8 +3555,8 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
       // TODO: if vl <= min(VLMAX), we can also do this. But we could not
       // access the subtarget here now.
       auto InterVec = DAG.getNode(
-          RISCVISD::VMV_V_X_VL, DL, InterVT, DAG.getUNDEF(InterVT), Lo,
-          DAG.getRegister(RISCV::X0, MVT::i32));
+          HasVendorXTHeadV ? RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
+          DL, InterVT, DAG.getUNDEF(InterVT), Lo, DAG.getRegister(RISCV::X0, MVT::i32));
       return DAG.getNode(ISD::BITCAST, DL, VT, InterVec);
     }
   }
@@ -3569,11 +3571,11 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
 // of the halves.
 static SDValue splatSplitI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
                                    SDValue Scalar, SDValue VL,
-                                   SelectionDAG &DAG) {
+                                   SelectionDAG &DAG, bool HasVendorXTHeadV) {
   assert(Scalar.getValueType() == MVT::i64 && "Unexpected VT!");
   SDValue Lo, Hi;
   std::tie(Lo, Hi) = DAG.SplitScalar(Scalar, DL, MVT::i32, MVT::i32);
-  return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG);
+  return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG, HasVendorXTHeadV);
 }
 
 // This function lowers a splat of a scalar operand Splat with the vector
@@ -3609,7 +3611,9 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
     if (isOneConstant(VL) &&
        (!Const || isNullConstant(Scalar) || !isInt<5>(Const->getSExtValue())))
       return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, Scalar, VL);
-    return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL);
+    return DAG.getNode(
+        Subtarget.hasVendorXTHeadV() ? RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
+        DL, VT, Passthru, Scalar, VL);
   }
 
   assert(XLenVT == MVT::i32 && Scalar.getValueType() == MVT::i64 &&
@@ -3620,7 +3624,8 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
                        DAG.getConstant(0, DL, XLenVT), VL);
 
   // Otherwise use the more complicated splatting algorithm.
-  return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG);
+  return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL,
+                             DAG, Subtarget.hasVendorXTHeadV());
 }
 
 static MVT getLMUL1VT(MVT VT) {
@@ -6549,7 +6554,8 @@ SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
   auto VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second;
 
   SDValue Res =
-      splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL, DAG);
+      splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL,
+                          DAG, Subtarget.hasVendorXTHeadV());
   return convertFromScalableVector(VecVT, Res, DAG, Subtarget);
 }
 
@@ -7281,7 +7287,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
       // We need to convert the scalar to a splat vector.
       SDValue VL = getVLOperand(Op);
       assert(VL.getValueType() == XLenVT);
-      ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL, DAG);
+      ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL,
+                                     DAG, Subtarget.hasVendorXTHeadV());
       return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
     }
 
@@ -7395,6 +7402,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(),
                        Op.getOperand(1), DAG.getConstant(0, DL, XLenVT));
   case Intrinsic::riscv_vmv_v_x:
+  case Intrinsic::riscv_th_vmv_v_x:
    return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2),
                            Op.getOperand(3), Op.getSimpleValueType(), DL, DAG,
                            Subtarget);
@@ -7431,7 +7439,8 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     SDValue Vec = Op.getOperand(1);
     SDValue VL = getVLOperand(Op);
 
-    SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL, DAG);
+    SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL,
+                                              DAG, Subtarget.hasVendorXTHeadV());
     if (Op.getOperand(1).isUndef())
       return SplattedVal;
     SDValue SplattedIdx =
@@ -16527,6 +16536,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(TH_SDD)
   NODE_NAME_CASE(VMV_V_V_VL)
   NODE_NAME_CASE(VMV_V_X_VL)
+  NODE_NAME_CASE(TH_VMV_V_X_VL)
   NODE_NAME_CASE(VFMV_V_F_VL)
   NODE_NAME_CASE(VMV_X_S)
   NODE_NAME_CASE(VMV_S_X_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

+1 −0

@@ -144,6 +144,7 @@ enum NodeType : unsigned {
   // for the VL value to be used for the operation. The first operand is
   // passthru operand.
   VMV_V_X_VL,
+  TH_VMV_V_X_VL,
   // VFMV_V_F_VL matches the semantics of vfmv.v.f but includes an extra operand
   // for the VL value to be used for the operation. The first operand is
   // passthru operand.

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

+4 −2

@@ -479,7 +479,8 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     const MCInstrDesc &Desc = DefMBBI->getDesc();
     MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
     MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
-    MIB.addImm(0); // tu, mu
+    if (!XTHeadV)
+      MIB.addImm(0); // tu, mu
     MIB.addReg(RISCV::VL, RegState::Implicit);
     MIB.addReg(RISCV::VTYPE, RegState::Implicit);
   }
@@ -513,7 +514,8 @@ void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     const MCInstrDesc &Desc = DefMBBI->getDesc();
     MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc))); // AVL
     MIB.add(DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc))); // SEW
-    MIB.addImm(0); // tu, mu
+    if (!XTHeadV)
+      MIB.addImm(0); // tu, mu
     MIB.addReg(RISCV::VL, RegState::Implicit);
     MIB.addReg(RISCV::VTYPE, RegState::Implicit);
   }

llvm/lib/Target/RISCV/RISCVInstrInfoXTHeadVPseudos.td

+68 −14

@@ -1601,6 +1601,19 @@ class XVPseudoBinaryMaskNoPolicy<VReg RetClass,
   let HasSEWOp = 1;
 }
 
+class XVPseudoUnaryNoMask<DAGOperand RetClass, DAGOperand OpClass,
+                          string Constraint = ""> :
+      Pseudo<(outs RetClass:$rd),
+             (ins RetClass:$merge, OpClass:$rs2, AVL:$vl, ixlenimm:$sew), []>,
+      RISCVVPseudo {
+  let mayLoad = 0;
+  let mayStore = 0;
+  let hasSideEffects = 0;
+  let Constraints = !interleave([Constraint, "$rd = $merge"], ",");
+  let HasVLOp = 1;
+  let HasSEWOp = 1;
+}
+
 multiclass XVPseudoBinary<VReg RetClass,
                           VReg Op1Class,
                           DAGOperand Op2Class,
@@ -2907,6 +2920,30 @@ let Predicates = [HasVendorXTHeadV] in {
 //===----------------------------------------------------------------------===//
 // 12.14. Vector Integer Merge and Move Instructions
 //===----------------------------------------------------------------------===//
+multiclass XVPseudoVMRG_VM_XM_IM {
+  foreach m = MxListXTHeadV in {
+    defvar mx = m.MX;
+    defvar WriteVIMergeV_MX = !cast<SchedWrite>("WriteVIMergeV_" # mx);
+    defvar WriteVIMergeX_MX = !cast<SchedWrite>("WriteVIMergeX_" # mx);
+    defvar WriteVIMergeI_MX = !cast<SchedWrite>("WriteVIMergeI_" # mx);
+    defvar ReadVIMergeV_MX = !cast<SchedRead>("ReadVIMergeV_" # mx);
+    defvar ReadVIMergeX_MX = !cast<SchedRead>("ReadVIMergeX_" # mx);
+
+    def "_VVM" # "_" # m.MX:
+      VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
+                               m.vrclass, m.vrclass, m, 1, "">,
+      Sched<[WriteVIMergeV_MX, ReadVIMergeV_MX, ReadVIMergeV_MX, ReadVMask]>;
+    def "_VXM" # "_" # m.MX:
+      VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
+                               m.vrclass, GPR, m, 1, "">,
+      Sched<[WriteVIMergeX_MX, ReadVIMergeV_MX, ReadVIMergeX_MX, ReadVMask]>;
+    def "_VIM" # "_" # m.MX:
+      VPseudoTiedBinaryCarryIn<GetVRegNoV0<m.vrclass>.R,
+                               m.vrclass, simm5, m, 1, "">,
+      Sched<[WriteVIMergeI_MX, ReadVIMergeV_MX, ReadVMask]>;
+  }
+}
+
 multiclass XVPseudoUnaryVMV_V_X_I {
   foreach m = MxListXTHeadV in {
     let VLMul = m.value in {
@@ -2918,34 +2955,49 @@ multiclass XVPseudoUnaryVMV_V_X_I {
       defvar ReadVIMovX_MX = !cast<SchedRead>("ReadVIMovX_" # mx);
 
       let VLMul = m.value in {
-        def "_V_" # mx : VPseudoUnaryNoMask<m.vrclass, m.vrclass>,
+        def "_V_" # mx : XVPseudoUnaryNoMask<m.vrclass, m.vrclass>,
                          Sched<[WriteVIMovV_MX, ReadVIMovV_MX]>;
-        def "_X_" # mx : VPseudoUnaryNoMask<m.vrclass, GPR>,
+        def "_X_" # mx : XVPseudoUnaryNoMask<m.vrclass, GPR>,
                          Sched<[WriteVIMovX_MX, ReadVIMovX_MX]>;
-        def "_I_" # mx : VPseudoUnaryNoMask<m.vrclass, simm5>,
+        def "_I_" # mx : XVPseudoUnaryNoMask<m.vrclass, simm5>,
                          Sched<[WriteVIMovI_MX]>;
       }
     }
   }
 }
 
 let Predicates = [HasVendorXTHeadV] in {
+  defm PseudoTH_VMERGE : XVPseudoVMRG_VM_XM_IM;
   defm PseudoTH_VMV_V : XVPseudoUnaryVMV_V_X_I;
 } // Predicates = [HasVendorXTHeadV]
 
-// Patterns for `int_riscv_vmv_v_v` -> `PseudoTH_VMV_V_V_<LMUL>`
-foreach vti = AllXVectors in {
-  let Predicates = GetXVTypePredicates<vti>.Predicates in {
-    // vmv.v.v
-    def : Pat<(vti.Vector (int_riscv_th_vmv_v_v (vti.Vector vti.RegClass:$passthru),
-                                                (vti.Vector vti.RegClass:$rs1),
-                                                VLOpFrag)),
-              (!cast<Instruction>("PseudoTH_VMV_V_V_"#vti.LMul.MX)
-                  $passthru, $rs1, GPR:$vl, vti.Log2SEW, TU_MU)>;
+let Predicates = [HasVendorXTHeadV] in {
+  defm : XVPatBinaryV_VM_XM_IM<"int_riscv_th_vmerge", "PseudoTH_VMERGE">;
+  // Define patterns for vmerge intrinsics with float-point arguments.
+  foreach vti = AllFloatXVectors in {
+    let Predicates = GetXVTypePredicates<vti>.Predicates in {
+      defm : VPatBinaryCarryInTAIL<"int_riscv_th_vmerge", "PseudoTH_VMERGE", "VVM",
+                                   vti.Vector,
+                                   vti.Vector, vti.Vector, vti.Mask,
+                                   vti.Log2SEW, vti.LMul, vti.RegClass,
+                                   vti.RegClass, vti.RegClass>;
+    }
+  }
 
-    // TODO: vmv.v.x, vmv.v.i
+  // Patterns for `int_riscv_vmv_v_v` -> `PseudoTH_VMV_V_V_<LMUL>`
+  foreach vti = AllXVectors in {
+    let Predicates = GetXVTypePredicates<vti>.Predicates in {
+      // vmv.v.v
+      def : Pat<(vti.Vector (int_riscv_th_vmv_v_v (vti.Vector vti.RegClass:$passthru),
+                                                  (vti.Vector vti.RegClass:$rs1),
+                                                  VLOpFrag)),
+                (!cast<Instruction>("PseudoTH_VMV_V_V_"#vti.LMul.MX)
+                    $passthru, $rs1, GPR:$vl, vti.Log2SEW)>;
+      // Patterns for vmv.v.x and vmv.v.i are defined
+      // in RISCVInstrInfoXTHeadVVLPatterns.td
+    }
   }
-}
+} // Predicates = [HasVendorXTHeadV]
 
 //===----------------------------------------------------------------------===//
 // 12.14. Vector Integer Merge and Move Instructions
@@ -2967,3 +3019,5 @@ let Predicates = [HasVendorXTHeadV] in {
   def PseudoTH_VMV8R_V : XVPseudoWholeMove<TH_VMV_V_V, V_M8, VRM8>;
 }
 } // Predicates = [HasVendorXTHeadV]
+
+include "RISCVInstrInfoXTHeadVVLPatterns.td"

llvm/lib/Target/RISCV/RISCVInstrInfoXTHeadVVLPatterns.td (new file)

+32 −0

@@ -0,0 +1,32 @@
+//===-- RISCVInstrInfoXTHeadVVLPatterns.td - RVV VL patterns -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------------===//
+///
+/// This file contains the required infrastructure and VL patterns to support
+/// code generation for the standard 'V' (Vector) extension, version 0.7.1
+///
+/// This file is included from RISCVInstrInfoXTHeadVPseudos.td
+//===---------------------------------------------------------------------------===//
+
+def riscv_th_vmv_v_x_vl : SDNode<"RISCVISD::TH_VMV_V_X_VL",
+                                 SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisInt<0>,
+                                                      SDTCisSameAs<0, 1>,
+                                                      SDTCisVT<2, XLenVT>,
+                                                      SDTCisVT<3, XLenVT>]>>;
+
+foreach vti = AllXVectors in {
+  foreach vti = AllIntegerXVectors in {
+    def : Pat<(vti.Vector (riscv_th_vmv_v_x_vl vti.RegClass:$passthru, GPR:$rs2, VLOpFrag)),
+              (!cast<Instruction>("PseudoTH_VMV_V_X_"#vti.LMul.MX)
+                  vti.RegClass:$passthru, GPR:$rs2, GPR:$vl, vti.Log2SEW)>;
+    defvar ImmPat = !cast<ComplexPattern>("sew"#vti.SEW#"simm5");
+    def : Pat<(vti.Vector (riscv_th_vmv_v_x_vl vti.RegClass:$passthru, (ImmPat simm5:$imm5),
+                                                VLOpFrag)),
+              (!cast<Instruction>("PseudoTH_VMV_V_I_"#vti.LMul.MX)
+                  vti.RegClass:$passthru, simm5:$imm5, GPR:$vl, vti.Log2SEW)>;
+  }
+}
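
The second Pat above folds splats of small immediates into th.vmv.v.i: when the splatted scalar is a constant accepted by the SEW-aware simm5 ComplexPattern, selection can pick PseudoTH_VMV_V_I_<LMUL> instead of first materializing the constant in a GPR. In IR terms (same assumed types and mangling as the earlier sketch), a call such as the following would be expected to select the immediate form:

  declare <vscale x 4 x i32> @llvm.riscv.th.vmv.v.x.nxv4i32.i64(
      <vscale x 4 x i32>, i32, i64)

  define <vscale x 4 x i32> @splat_imm(<vscale x 4 x i32> %passthru, i64 %vl) {
    ; 7 fits in simm5, so the TH_VMV_V_X_VL node carries a constant operand and the
    ; immediate pattern above can match it (th.vmv.v.i rather than th.vmv.v.x).
    %v = call <vscale x 4 x i32> @llvm.riscv.th.vmv.v.x.nxv4i32.i64(
        <vscale x 4 x i32> %passthru, i32 7, i64 %vl)
    ret <vscale x 4 x i32> %v
  }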
