@@ -3535,7 +3535,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
 static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
                                    SDValue Lo, SDValue Hi, SDValue VL,
-                                   SelectionDAG &DAG) {
+                                   SelectionDAG &DAG, bool HasVendorXTHeadV) {
   if (!Passthru)
     Passthru = DAG.getUNDEF(VT);
   if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
@@ -3544,7 +3544,9 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
     // If Hi constant is all the same sign bit as Lo, lower this as a custom
     // node in order to try and match RVV vector/scalar instructions.
     if ((LoC >> 31) == HiC)
-      return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Lo, VL);
+      return DAG.getNode(HasVendorXTHeadV ? RISCVISD::TH_VMV_V_X_VL
+                                          : RISCVISD::VMV_V_X_VL,
+                         DL, VT, Passthru, Lo, VL);

     // If vl is equal to XLEN_MAX and Hi constant is equal to Lo, we could use
     // vmv.v.x whose EEW = 32 to lower it.
@@ -3553,8 +3555,8 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
       // TODO: if vl <= min(VLMAX), we can also do this. But we could not
       // access the subtarget here now.
       auto InterVec = DAG.getNode(
-          RISCVISD::VMV_V_X_VL, DL, InterVT, DAG.getUNDEF(InterVT), Lo,
-          DAG.getRegister(RISCV::X0, MVT::i32));
+          HasVendorXTHeadV ? RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
+          DL, InterVT, DAG.getUNDEF(InterVT), Lo, DAG.getRegister(RISCV::X0, MVT::i32));
       return DAG.getNode(ISD::BITCAST, DL, VT, InterVec);
     }
   }
@@ -3569,11 +3571,11 @@ static SDValue splatPartsI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
 // of the halves.
 static SDValue splatSplitI64WithVL(const SDLoc &DL, MVT VT, SDValue Passthru,
                                    SDValue Scalar, SDValue VL,
-                                   SelectionDAG &DAG) {
+                                   SelectionDAG &DAG, bool HasVendorXTHeadV) {
   assert(Scalar.getValueType() == MVT::i64 && "Unexpected VT!");
   SDValue Lo, Hi;
   std::tie(Lo, Hi) = DAG.SplitScalar(Scalar, DL, MVT::i32, MVT::i32);
-  return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG);
+  return splatPartsI64WithVL(DL, VT, Passthru, Lo, Hi, VL, DAG, HasVendorXTHeadV);
 }

 // This function lowers a splat of a scalar operand Splat with the vector
@@ -3609,7 +3611,9 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
     if (isOneConstant(VL) &&
         (!Const || isNullConstant(Scalar) || !isInt<5>(Const->getSExtValue())))
       return DAG.getNode(RISCVISD::VMV_S_X_VL, DL, VT, Passthru, Scalar, VL);
-    return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL);
+    return DAG.getNode(
+        Subtarget.hasVendorXTHeadV() ? RISCVISD::TH_VMV_V_X_VL : RISCVISD::VMV_V_X_VL,
+        DL, VT, Passthru, Scalar, VL);
   }

   assert(XLenVT == MVT::i32 && Scalar.getValueType() == MVT::i64 &&
@@ -3620,7 +3624,8 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
                        DAG.getConstant(0, DL, XLenVT), VL);

   // Otherwise use the more complicated splatting algorithm.
-  return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG);
+  return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL,
+                             DAG, Subtarget.hasVendorXTHeadV());
 }

 static MVT getLMUL1VT(MVT VT) {
@@ -6549,7 +6554,8 @@ SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
   auto VL = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).second;

   SDValue Res =
-      splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL, DAG);
+      splatPartsI64WithVL(DL, ContainerVT, SDValue(), Lo, Hi, VL,
+                          DAG, Subtarget.hasVendorXTHeadV());
   return convertFromScalableVector(VecVT, Res, DAG, Subtarget);
 }

@@ -7281,7 +7287,8 @@ static SDValue lowerVectorIntrinsicScalars(SDValue Op, SelectionDAG &DAG,
       // We need to convert the scalar to a splat vector.
       SDValue VL = getVLOperand(Op);
       assert(VL.getValueType() == XLenVT);
-      ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL, DAG);
+      ScalarOp = splatSplitI64WithVL(DL, VT, SDValue(), ScalarOp, VL,
+                                     DAG, Subtarget.hasVendorXTHeadV());
       return DAG.getNode(Op->getOpcode(), DL, Op->getVTList(), Operands);
     }

@@ -7395,6 +7402,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, Op.getValueType(),
                        Op.getOperand(1), DAG.getConstant(0, DL, XLenVT));
   case Intrinsic::riscv_vmv_v_x:
+  case Intrinsic::riscv_th_vmv_v_x:
     return lowerScalarSplat(Op.getOperand(1), Op.getOperand(2),
                             Op.getOperand(3), Op.getSimpleValueType(), DL, DAG,
                             Subtarget);
@@ -7431,7 +7439,8 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
     SDValue Vec = Op.getOperand(1);
     SDValue VL = getVLOperand(Op);

-    SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL, DAG);
+    SDValue SplattedVal = splatSplitI64WithVL(DL, VT, SDValue(), Scalar, VL,
+                                              DAG, Subtarget.hasVendorXTHeadV());
     if (Op.getOperand(1).isUndef())
       return SplattedVal;
     SDValue SplattedIdx =
@@ -16527,6 +16536,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(TH_SDD)
   NODE_NAME_CASE(VMV_V_V_VL)
   NODE_NAME_CASE(VMV_V_X_VL)
+  NODE_NAME_CASE(TH_VMV_V_X_VL)
   NODE_NAME_CASE(VFMV_V_F_VL)
   NODE_NAME_CASE(VMV_X_S)
   NODE_NAME_CASE(VMV_S_X_VL)
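
Note: every call site touched above threads the same decision through the splat helpers: when the subtarget reports the T-Head XTHeadV vendor extension, the lowering emits RISCVISD::TH_VMV_V_X_VL, otherwise it falls back to the standard RISCVISD::VMV_V_X_VL. The following is a minimal standalone sketch of that dispatch, assuming nothing from the LLVM tree; the SplatOpcode enum and chooseSplatOpcode helper are illustrative stand-ins for the inline ternaries and the HasVendorXTHeadV flag in the patch, not part of the in-tree API.

#include <cassert>

// Stand-in for the two RISCVISD splat opcodes involved in this patch
// (illustrative only; the real values live in RISCVISelLowering.h).
enum class SplatOpcode { VMV_V_X_VL, TH_VMV_V_X_VL };

// Mirrors the ternary used at each lowering site: prefer the T-Head vendor
// node only when the XTHeadV extension is available on the subtarget.
static SplatOpcode chooseSplatOpcode(bool HasVendorXTHeadV) {
  return HasVendorXTHeadV ? SplatOpcode::TH_VMV_V_X_VL
                          : SplatOpcode::VMV_V_X_VL;
}

int main() {
  assert(chooseSplatOpcode(true) == SplatOpcode::TH_VMV_V_X_VL);
  assert(chooseSplatOpcode(false) == SplatOpcode::VMV_V_X_VL);
  return 0;
}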