[RISCV] Move exact VLEN VLMAX transform to RISCVVectorPeephole #100551


Merged: 1 commit into llvm:main, Jul 25, 2024

Conversation

@lukel97 (Contributor) commented on Jul 25, 2024

When the exact VLEN is known, we can teach RISCVVectorPeephole to detect when an AVL is equal to VLMAX and use the VLMAX sentinel instead, removing the need for getVLOp in RISCVISelLowering. This keeps all the VLMAX logic in one place.
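
For reference, here is a minimal standalone sketch of the check the peephole performs (illustrative only: avlEqualsVLMAX is a hypothetical helper name, not LLVM API, though the arithmetic mirrors the patch). LMUL is carried as a fixed-point value with denominator 8 so the fractional LMULs mf8..mf2 stay in integer arithmetic:

#include <optional>

// Hypothetical standalone helper, not LLVM API. With
// LMULFixed = 8 * LMUL, the test
//   AVL == VLMAX, where VLMAX = VLEN * LMUL / SEW
// becomes integer-only:
//   (VLEN * LMULFixed) / SEW == AVL * 8
static bool avlEqualsVLMAX(std::optional<unsigned> VLen, // exact VLEN, if known
                           std::optional<unsigned> AVL,  // constant AVL, if known
                           unsigned LMULFixed, unsigned SEW) {
  return VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8;
}

For example, with VLEN = 128 an e8/m1 operation has VLMAX = 128 * 1 / 8 = 16, so a constant AVL of 16 folds to the sentinel; InsertVSETVLI can still pick an immediate encoding later if profitable.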

@llvmbot (Member) commented on Jul 25, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Full diff: https://github.com/llvm/llvm-project/pull/100551.diff

4 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+8-27)
  • (modified) llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp (+38-14)
  • (modified) llvm/test/CodeGen/RISCV/rvv/pr83017.ll (+3-3)
  • (modified) llvm/test/CodeGen/RISCV/rvv/pr90559.ll (+3-3)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d40d4997d7614..0339b302fb218 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2758,19 +2758,6 @@ static SDValue getAllOnesMask(MVT VecVT, SDValue VL, const SDLoc &DL,
   return DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL);
 }
 
-static SDValue getVLOp(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
-                       SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
-  // If we know the exact VLEN, and our VL is exactly equal to VLMAX,
-  // canonicalize the representation.  InsertVSETVLI will pick the immediate
-  // encoding later if profitable.
-  const auto [MinVLMAX, MaxVLMAX] =
-      RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
-  if (MinVLMAX == MaxVLMAX && NumElts == MinVLMAX)
-    return DAG.getRegister(RISCV::X0, Subtarget.getXLenVT());
-
-  return DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
-}
-
 static std::pair<SDValue, SDValue>
 getDefaultScalableVLOps(MVT VecVT, const SDLoc &DL, SelectionDAG &DAG,
                         const RISCVSubtarget &Subtarget) {
@@ -2784,7 +2771,7 @@ static std::pair<SDValue, SDValue>
 getDefaultVLOps(uint64_t NumElts, MVT ContainerVT, const SDLoc &DL,
                 SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
   assert(ContainerVT.isScalableVector() && "Expecting scalable container type");
-  SDValue VL = getVLOp(NumElts, ContainerVT, DL, DAG, Subtarget);
+  SDValue VL = DAG.getConstant(NumElts, DL, Subtarget.getXLenVT());
   SDValue Mask = getAllOnesMask(ContainerVT, VL, DL, DAG);
   return {Mask, VL};
 }
@@ -9427,8 +9414,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
     MVT VT = Op->getSimpleValueType(0);
     MVT ContainerVT = getContainerForFixedLengthVector(VT);
 
-    SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
-                         Subtarget);
+    SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
     SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
     auto *Load = cast<MemIntrinsicSDNode>(Op);
     SmallVector<EVT, 9> ContainerVTs(NF, ContainerVT);
@@ -9507,8 +9493,7 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
     MVT VT = Op->getOperand(2).getSimpleValueType();
     MVT ContainerVT = getContainerForFixedLengthVector(VT);
 
-    SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
-                         Subtarget);
+    SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
     SDValue IntID = DAG.getTargetConstant(VssegInts[NF - 2], DL, XLenVT);
     SDValue Ptr = Op->getOperand(NF + 2);
 
@@ -9974,7 +9959,7 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
     // Set the vector length to only the number of elements we care about. Note
     // that for slideup this includes the offset.
     unsigned EndIndex = OrigIdx + SubVecVT.getVectorNumElements();
-    SDValue VL = getVLOp(EndIndex, ContainerVT, DL, DAG, Subtarget);
+    SDValue VL = DAG.getConstant(EndIndex, DL, XLenVT);
 
     // Use tail agnostic policy if we're inserting over Vec's tail.
     unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
@@ -10211,8 +10196,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
         getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget).first;
     // Set the vector length to only the number of elements we care about. This
     // avoids sliding down elements we're going to discard straight away.
-    SDValue VL = getVLOp(SubVecVT.getVectorNumElements(), ContainerVT, DL, DAG,
-                         Subtarget);
+    SDValue VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
     SDValue SlidedownAmt = DAG.getConstant(OrigIdx, DL, XLenVT);
     SDValue Slidedown =
         getVSlidedown(DAG, Subtarget, DL, ContainerVT,
@@ -10287,8 +10271,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
   SDValue SlidedownAmt = DAG.getElementCount(DL, XLenVT, RemIdx);
   auto [Mask, VL] = getDefaultScalableVLOps(InterSubVT, DL, DAG, Subtarget);
   if (SubVecVT.isFixedLengthVector())
-    VL = getVLOp(SubVecVT.getVectorNumElements(), InterSubVT, DL, DAG,
-                 Subtarget);
+    VL = DAG.getConstant(SubVecVT.getVectorNumElements(), DL, XLenVT);
   SDValue Slidedown =
       getVSlidedown(DAG, Subtarget, DL, InterSubVT, DAG.getUNDEF(InterSubVT),
                     Vec, SlidedownAmt, Mask, VL);
@@ -10668,7 +10651,7 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
     return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL);
   }
 
-  SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG, Subtarget);
+  SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
 
   bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
   SDValue IntID = DAG.getTargetConstant(
@@ -10715,7 +10698,6 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
   SDValue NewValue =
       convertToScalableVector(ContainerVT, StoreVal, DAG, Subtarget);
 
-
   // If we know the exact VLEN and our fixed length vector completely fills
   // the container, use a whole register store instead.
   const auto [MinVLMAX, MaxVLMAX] =
@@ -10728,8 +10710,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
                         MMO->getFlags(), MMO->getAAInfo());
   }
 
-  SDValue VL = getVLOp(VT.getVectorNumElements(), ContainerVT, DL, DAG,
-                       Subtarget);
+  SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT);
 
   bool IsMaskOp = VT.getVectorElementType() == MVT::i1;
   SDValue IntID = DAG.getTargetConstant(
diff --git a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
index b083e64cfc8d7..f328c55e1d3ba 100644
--- a/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
+++ b/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -47,6 +47,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
   const TargetInstrInfo *TII;
   MachineRegisterInfo *MRI;
   const TargetRegisterInfo *TRI;
+  const RISCVSubtarget *ST;
   RISCVVectorPeephole() : MachineFunctionPass(ID) {}
 
   bool runOnMachineFunction(MachineFunction &MF) override;
@@ -64,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
   bool convertVMergeToVMv(MachineInstr &MI) const;
 
   bool isAllOnesMask(const MachineInstr *MaskDef) const;
+  std::optional<unsigned> getConstant(const MachineOperand &VL) const;
 
   /// Maps uses of V0 to the corresponding def of V0.
   DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
@@ -76,13 +78,44 @@ char RISCVVectorPeephole::ID = 0;
 INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
                 false)
 
-// If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
-// to the VLMAX sentinel value.
+/// Check if an operand is an immediate or a materialized ADDI $x0, imm.
+std::optional<unsigned>
+RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
+  if (VL.isImm())
+    return VL.getImm();
+
+  MachineInstr *Def = MRI->getVRegDef(VL.getReg());
+  if (!Def || Def->getOpcode() != RISCV::ADDI ||
+      Def->getOperand(1).getReg() != RISCV::X0)
+    return std::nullopt;
+  return Def->getOperand(2).getImm();
+}
+
+/// Convert AVLs that are known to be VLMAX to the VLMAX sentinel.
 bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
   if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
       !RISCVII::hasSEWOp(MI.getDesc().TSFlags))
     return false;
+
+  auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
+  // Fixed-point value, denominator=8
+  unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
+  unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
+  // A Log2SEW of 0 is an operation on mask registers only
+  unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
+  assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
+  assert(8 * LMULFixed / SEW > 0);
+
+  // If the exact VLEN is known then we know VLMAX, check if the AVL == VLMAX.
   MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
+  if (auto VLen = ST->getRealVLen(), AVL = getConstant(VL);
+      VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8) {
+    VL.ChangeToImmediate(RISCV::VLMaxSentinel);
+    return true;
+  }
+
+  // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert
+  // it to the VLMAX sentinel value.
   if (!VL.isReg())
     return false;
   MachineInstr *Def = MRI->getVRegDef(VL.getReg());
@@ -105,15 +138,6 @@ bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
   if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
     return false;
 
-  auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
-  // Fixed-point value, denominator=8
-  unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
-  unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
-  // A Log2SEW of 0 is an operation on mask registers only
-  unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
-  assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
-  assert(8 * LMULFixed / SEW > 0);
-
   // AVL = (VLENB * Scale)
   //
   // VLMAX = (VLENB * 8 * LMUL) / SEW
@@ -302,11 +326,11 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
     return false;
 
   // Skip if the vector extension is not enabled.
-  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
-  if (!ST.hasVInstructions())
+  ST = &MF.getSubtarget<RISCVSubtarget>();
+  if (!ST->hasVInstructions())
     return false;
 
-  TII = ST.getInstrInfo();
+  TII = ST->getInstrInfo();
   MRI = &MF.getRegInfo();
   TRI = MRI->getTargetRegisterInfo();
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/pr83017.ll b/llvm/test/CodeGen/RISCV/rvv/pr83017.ll
index 3719a2ad994d6..beca480378a35 100644
--- a/llvm/test/CodeGen/RISCV/rvv/pr83017.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/pr83017.ll
@@ -35,11 +35,11 @@ define void @aliasing(ptr %p) {
 ; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:    vs1r.v v8, (a2)
-; CHECK-NEXT:    vsetvli a2, zero, e8, m4, ta, ma
-; CHECK-NEXT:    vmv.v.i v12, 0
-; CHECK-NEXT:    vs4r.v v12, (a0)
 ; CHECK-NEXT:    addi a2, a0, 64
 ; CHECK-NEXT:    vs1r.v v8, (a2)
+; CHECK-NEXT:    vsetvli a2, zero, e8, m4, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vs4r.v v8, (a0)
 ; CHECK-NEXT:    sw a1, 84(a0)
 ; CHECK-NEXT:    ret
   %q = getelementptr inbounds i8, ptr %p, i64 84
diff --git a/llvm/test/CodeGen/RISCV/rvv/pr90559.ll b/llvm/test/CodeGen/RISCV/rvv/pr90559.ll
index 8d330b12055ae..7e109f307c4a5 100644
--- a/llvm/test/CodeGen/RISCV/rvv/pr90559.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/pr90559.ll
@@ -32,11 +32,11 @@ define void @f(ptr %p) vscale_range(2,2) {
 ; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:    vs1r.v v8, (a2)
-; CHECK-NEXT:    vsetvli a2, zero, e8, m4, ta, ma
-; CHECK-NEXT:    vmv.v.i v12, 0
-; CHECK-NEXT:    vs4r.v v12, (a0)
 ; CHECK-NEXT:    addi a2, a0, 64
 ; CHECK-NEXT:    vs1r.v v8, (a2)
+; CHECK-NEXT:    vsetvli a2, zero, e8, m4, ta, ma
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    vs4r.v v8, (a0)
 ; CHECK-NEXT:    sw a1, 84(a0)
 ; CHECK-NEXT:    ret
   %q = getelementptr inbounds i8, ptr %p, i64 84
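
A note on the two test diffs above (a reading of the output, under the assumption stated here): both tests run with an exact known VLEN (pr90559.ll pins it via vscale_range(2,2), i.e. VLEN = 128), so the e8/m4 splat has VLMAX = 128 * 4 / 8 = 64 and keeps the vsetvli a2, zero, e8, m4 VLMAX form. Only scheduling and register assignment change (v12 becomes v8), because the canonicalization now runs in RISCVVectorPeephole after instruction selection instead of during lowering. Separately, the new getConstant helper lets the peephole see through an AVL materialized as li rd, imm (i.e. addi rd, x0, imm), not just immediate VL operands.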

@wangpc-pp (Contributor) left a comment:

LGTM in general.

@preames (Collaborator) left a comment:

LGTM

@lukel97 merged commit 754dc9f into llvm:main on Jul 25, 2024
9 checks passed