Skip to content

Commit 754dc9f

Browse files
authored
[RISCV] Move exact VLEN VLMAX transform to RISCVVectorPeephole (#100551)
We can teach RISCVVectorPeephole to detect when an AVL is equal to the VLMAX when the exact VLEN is known and use the VLMAX sentinel instead, and in doing so remove the need for getVLOp in RISCVISelLowering. This keeps all the VLMAX logic in one place.
1 parent 7432ad6 commit 754dc9f

File tree

4 files changed

+52
-47
lines changed

4 files changed

+52
-47
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2758,19 +2758,6 @@ static SDValue getAllOnesMask(MVT VecVT, SDValue VL, const SDLoc &DL,
27582758
return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
27592759
}
27602760

2761-
static SDValue getVLOp(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
2762-
SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
2763-
// If we know the exact VLEN, and our VL is exactly equal to VLMAX,
2764-
// canonicalize the representation. InsertVSETVLI will pick the immediate
2765-
// encoding later if profitable.
2766-
const auto [MinVLMAX, MaxVLMAX] =
2767-
RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
2768-
if (MinVLMAX == MaxVLMAX && NumElts == MinVLMAX)
2769-
return DAG.getRegister(RISCV::X0, Subtarget.getXLenVT());
2770-
2771-
return DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
2772-
}
2773-
27742761
static std::pair<SDValue, SDValue>
27752762
getDefaultScalableVLOps(MVT VecVT, const SDLoc &DL, SelectionDAG &DAG,
27762763
const RISCVSubtarget &Subtarget) {
@@ -2784,7 +2771,7 @@ static std::pair<SDValue, SDValue>
27842771
getDefaultVLOps(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
27852772
SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
27862773
assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
2787-
SDValue VL = getVLOp(NumElts, ContainerVT, DL, DAG, Subtarget);
2774+
SDValue VL = DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
27882775
SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
27892776
return {Mask, VL};
27902777
}
@@ -9427,8 +9414,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
94279414
MVT VT = Op->getSimpleValueType(0);
94289415
MVT ContainerVT = getContainerForFixedLengthVector(VT);
94299416

9430-
SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
9431-
Subtarget);
9417+
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
94329418
SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
94339419
auto *Load = cast<MemIntrinsicSDNode>(Op);
94349420
SmallVector<EVT, 9> ContainerVTs(NF, ContainerVT);
@@ -9507,8 +9493,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
95079493
MVT VT = Op->getOperand(2).getSimpleValueType();
95089494
MVT ContainerVT = getContainerForFixedLengthVector(VT);
95099495

9510-
SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
9511-
Subtarget);
9496+
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
95129497
SDValue IntID = DAG.getTargetConstant(VssegInts[NF - 2], DL, XLenVT);
95139498
SDValue Ptr = Op->getOperand(NF + 2);
95149499

@@ -9974,7 +9959,7 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
99749959
// Set the vector length to only the number of elements we care about. Note
99759960
// that for slideup this includes the offset.
99769961
unsigned EndIndex = OrigIdx + SubVecVT.getVectorNumElements();
9977-
SDValue VL = getVLOp(EndIndex, ContainerVT, DL, DAG, Subtarget);
9962+
SDValue VL = DAG.getConstant(EndIndex, DL, XLenVT);
99789963

99799964
// Use tail agnostic policy if we're inserting over Vec's tail.
99809965
unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
@@ -10211,8 +10196,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
1021110196
getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
1021210197
// Set the vector length to only the number of elements we care about. This
1021310198
// avoids sliding down elements we're going to discard straight away.
10214-
SDValue VL = getVLOp(SubVecVT.getVectorNumElements(), ContainerVT, DL, DAG,
10215-
Subtarget);
10199+
SDValue VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
1021610200
SDValue SlidedownAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
1021710201
SDValue Slidedown =
1021810202
getVSlidedown(DAG, Subtarget, DL, ContainerVT,
@@ -10287,8 +10271,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
1028710271
SDValue SlidedownAmt = DAG.getElementCount(DL, XLenVT, RemIdx);
1028810272
auto [Mask, VL] = getDefaultScalableVLOps(InterSubVT, DL, DAG, Subtarget);
1028910273
if (SubVecVT.isFixedLengthVector())
10290-
VL = getVLOp(SubVecVT.getVectorNumElements(), InterSubVT, DL, DAG,
10291-
Subtarget);
10274+
VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
1029210275
SDValue Slidedown =
1029310276
getVSlidedown(DAG, Subtarget, DL, InterSubVT, DAG.getUNDEF(InterSubVT),
1029410277
Vec, SlidedownAmt, Mask, VL);
@@ -10668,7 +10651,7 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
1066810651
return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
1066910652
}
1067010653

10671-
SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG, Subtarget);
10654+
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
1067210655

1067310656
bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
1067410657
SDValue IntID = DAG.getTargetConstant(
@@ -10715,7 +10698,6 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
1071510698
SDValue NewValue =
1071610699
convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
1071710700

10718-
1071910701
// If we know the exact VLEN and our fixed length vector completely fills
1072010702
// the container, use a whole register store instead.
1072110703
const auto [MinVLMAX, MaxVLMAX] =
@@ -10728,8 +10710,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
1072810710
MMO->getFlags(), MMO->getAAInfo());
1072910711
}
1073010712

10731-
SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
10732-
Subtarget);
10713+
SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
1073310714

1073410715
bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
1073510716
SDValue IntID = DAG.getTargetConstant(

llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
4747
const TargetInstrInfo *TII;
4848
MachineRegisterInfo *MRI;
4949
const TargetRegisterInfo *TRI;
50+
const RISCVSubtarget *ST;
5051
RISCVVectorPeephole() : MachineFunctionPass(ID) {}
5152

5253
bool runOnMachineFunction(MachineFunction &MF) override;
@@ -64,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
6465
bool convertVMergeToVMv(MachineInstr &MI) const;
6566

6667
bool isAllOnesMask(const MachineInstr *MaskDef) const;
68+
std::optional<unsigned> getConstant(const MachineOperand &VL) const;
6769

6870
/// Maps uses of V0 to the corresponding def of V0.
6971
DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
@@ -76,13 +78,44 @@ char RISCVVectorPeephole::ID = 0;
7678
INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
7779
false)
7880

79-
// If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
80-
// to the VLMAX sentinel value.
81+
/// Check if an operand is an immediate or a materialized ADDI $x0, imm.
82+
std::optional<unsigned>
83+
RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
84+
if (VL.isImm())
85+
return VL.getImm();
86+
87+
MachineInstr *Def = MRI->getVRegDef(VL.getReg());
88+
if (!Def || Def->getOpcode() != RISCV::ADDI ||
89+
Def->getOperand(1).getReg() != RISCV::X0)
90+
return std::nullopt;
91+
return Def->getOperand(2).getImm();
92+
}
93+
94+
/// Convert AVLs that are known to be VLMAX to the VLMAX sentinel.
8195
bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
8296
if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
8397
!RISCVII::hasSEWOp(MI.getDesc().TSFlags))
8498
return false;
99+
100+
auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
101+
// Fixed-point value, denominator=8
102+
unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
103+
unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
104+
// A Log2SEW of 0 is an operation on mask registers only
105+
unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
106+
assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
107+
assert(8 * LMULFixed / SEW > 0);
108+
109+
// If the exact VLEN is known then we know VLMAX, check if the AVL == VLMAX.
85110
MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
111+
if (auto VLen = ST->getRealVLen(), AVL = getConstant(VL);
112+
VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8) {
113+
VL.ChangeToImmediate(RISCV::VLMaxSentinel);
114+
return true;
115+
}
116+
117+
// If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert
118+
// it to the VLMAX sentinel value.
86119
if (!VL.isReg())
87120
return false;
88121
MachineInstr *Def = MRI->getVRegDef(VL.getReg());
@@ -105,15 +138,6 @@ bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
105138
if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
106139
return false;
107140

108-
auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
109-
// Fixed-point value, denominator=8
110-
unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
111-
unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
112-
// A Log2SEW of 0 is an operation on mask registers only
113-
unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
114-
assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
115-
assert(8 * LMULFixed / SEW > 0);
116-
117141
// AVL = (VLENB * Scale)
118142
//
119143
// VLMAX = (VLENB * 8 * LMUL) / SEW
@@ -302,11 +326,11 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
302326
return false;
303327

304328
// Skip if the vector extension is not enabled.
305-
const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
306-
if (!ST.hasVInstructions())
329+
ST = &MF.getSubtarget<RISCVSubtarget>();
330+
if (!ST->hasVInstructions())
307331
return false;
308332

309-
TII = ST.getInstrInfo();
333+
TII = ST->getInstrInfo();
310334
MRI = &MF.getRegInfo();
311335
TRI = MRI->getTargetRegisterInfo();
312336

llvm/test/CodeGen/RISCV/rvv/pr83017.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ define void @aliasing(ptr %p) {
3535
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
3636
; CHECK-NEXT: vmv.v.i v8, 0
3737
; CHECK-NEXT: vs1r.v v8, (a2)
38-
; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
39-
; CHECK-NEXT: vmv.v.i v12, 0
40-
; CHECK-NEXT: vs4r.v v12, (a0)
4138
; CHECK-NEXT: addi a2, a0, 64
4239
; CHECK-NEXT: vs1r.v v8, (a2)
40+
; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
41+
; CHECK-NEXT: vmv.v.i v8, 0
42+
; CHECK-NEXT: vs4r.v v8, (a0)
4343
; CHECK-NEXT: sw a1, 84(a0)
4444
; CHECK-NEXT: ret
4545
%q = getelementptr inbounds i8, ptr %p, i64 84

llvm/test/CodeGen/RISCV/rvv/pr90559.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ define void @f(ptr %p) vscale_range(2,2) {
3232
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
3333
; CHECK-NEXT: vmv.v.i v8, 0
3434
; CHECK-NEXT: vs1r.v v8, (a2)
35-
; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
36-
; CHECK-NEXT: vmv.v.i v12, 0
37-
; CHECK-NEXT: vs4r.v v12, (a0)
3835
; CHECK-NEXT: addi a2, a0, 64
3936
; CHECK-NEXT: vs1r.v v8, (a2)
37+
; CHECK-NEXT: vsetvli a2, zero, e8, m4, ta, ma
38+
; CHECK-NEXT: vmv.v.i v8, 0
39+
; CHECK-NEXT: vs4r.v v8, (a0)
4040
; CHECK-NEXT: sw a1, 84(a0)
4141
; CHECK-NEXT: ret
4242
%q = getelementptr inbounds i8, ptr %p, i64 84

0 commit comments

Comments
 (0)