Skip to content

[LLT] Add and use isPointerVector and isPointerOrPointerVector. NFC. #81283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 13, 2024

Conversation

jayfoad
Copy link
Contributor

@jayfoad jayfoad commented Feb 9, 2024

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Feb 9, 2024

@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-backend-aarch64

Author: Jay Foad (jayfoad)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/81283.diff

9 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h (+1-1)
  • (modified) llvm/include/llvm/CodeGenTypes/LowLevelType.h (+8-6)
  • (modified) llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp (+2-3)
  • (modified) llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp (+1-1)
  • (modified) llvm/lib/CodeGen/MachineVerifier.cpp (+6-8)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+1-1)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp (+4-7)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp (+2-3)
  • (modified) llvm/unittests/CodeGen/LowLevelTypeTest.cpp (+3)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 5fd80e5f199ce8..637c2c71b02411 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -1084,7 +1084,7 @@ class LegalizeRuleSet {
         },
         [=](const LegalityQuery &Query) {
           LLT T = Query.Types[LargeTypeIdx];
-          if (T.isVector() && T.getElementType().isPointer())
+          if (T.isPointerVector())
             T = T.changeElementType(LLT::scalar(T.getScalarSizeInBits()));
           return std::make_pair(TypeIdx, T);
         });
diff --git a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h
index cc33152e06f4db..5a16cffb80b32b 100644
--- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h
+++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h
@@ -134,15 +134,17 @@ class LLT {
 
   explicit LLT(MVT VT);
 
-  constexpr bool isValid() const { return IsScalar || RawData != 0; }
+  constexpr bool isValid() const { return IsScalar || IsPointer || IsVector; }
 
   constexpr bool isScalar() const { return IsScalar; }
 
-  constexpr bool isPointer() const {
-    return isValid() && IsPointer && !IsVector;
-  }
+  constexpr bool isPointer() const { return IsPointer && !IsVector; }
+
+  constexpr bool isPointerVector() const { return IsPointer && IsVector; }
+
+  constexpr bool isPointerOrPointerVector() const { return IsPointer; }
 
-  constexpr bool isVector() const { return isValid() && IsVector; }
+  constexpr bool isVector() const { return IsVector; }
 
   /// Returns the number of elements in a vector LLT. Must only be called on
   /// vector types.
@@ -209,7 +211,7 @@ class LLT {
   /// but the new element size. Otherwise, return the new element type. Invalid
   /// for pointer types. For pointer types, use changeElementType.
   constexpr LLT changeElementSize(unsigned NewEltSize) const {
-    assert(!getScalarType().isPointer() &&
+    assert(!isPointerOrPointerVector() &&
            "invalid to directly change element size for pointers");
     return isVector() ? LLT::vector(getElementCount(), NewEltSize)
                       : LLT::scalar(NewEltSize);
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 464ff0864d146f..e39fdae1ccbedb 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -1716,8 +1716,7 @@ Register LegalizerHelper::coerceToScalar(Register Val) {
   Register NewVal = Val;
 
   assert(Ty.isVector());
-  LLT EltTy = Ty.getElementType();
-  if (EltTy.isPointer())
+  if (Ty.isPointerVector())
     NewVal = MIRBuilder.buildPtrToInt(NewTy, NewVal).getReg(0);
   return MIRBuilder.buildBitcast(NewTy, NewVal).getReg(0);
 }
@@ -7964,7 +7963,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerSelect(MachineInstr &MI) {
   auto [DstReg, DstTy, MaskReg, MaskTy, Op1Reg, Op1Ty, Op2Reg, Op2Ty] =
       MI.getFirst4RegLLTs();
 
-  bool IsEltPtr = DstTy.getScalarType().isPointer();
+  bool IsEltPtr = DstTy.isPointerOrPointerVector();
   if (IsEltPtr) {
     LLT ScalarPtrTy = LLT::scalar(DstTy.getScalarSizeInBits());
     LLT NewTy = DstTy.changeElementType(ScalarPtrTy);
diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
index a5827c26c04f48..d58b62871817d8 100644
--- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
@@ -199,7 +199,7 @@ void MachineIRBuilder::validateShiftOp(const LLT Res, const LLT Op0,
 MachineInstrBuilder
 MachineIRBuilder::buildPtrAdd(const DstOp &Res, const SrcOp &Op0,
                               const SrcOp &Op1, std::optional<unsigned> Flags) {
-  assert(Res.getLLTTy(*getMRI()).getScalarType().isPointer() &&
+  assert(Res.getLLTTy(*getMRI()).isPointerOrPointerVector() &&
          Res.getLLTTy(*getMRI()) == Op0.getLLTTy(*getMRI()) && "type mismatch");
   assert(Op1.getLLTTy(*getMRI()).getScalarType().isScalar() && "invalid offset type");
 
diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index c65e91741533c8..2632b5b9feac9d 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -1288,10 +1288,10 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
     if (!DstTy.isValid() || !PtrTy.isValid() || !OffsetTy.isValid())
       break;
 
-    if (!PtrTy.getScalarType().isPointer())
+    if (!PtrTy.isPointerOrPointerVector())
       report("gep first operand must be a pointer", MI);
 
-    if (OffsetTy.getScalarType().isPointer())
+    if (OffsetTy.isPointerOrPointerVector())
       report("gep offset operand must not be a pointer", MI);
 
     // TODO: Is the offset allowed to be a scalar with a vector?
@@ -1304,7 +1304,7 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
     if (!DstTy.isValid() || !SrcTy.isValid() || !MaskTy.isValid())
       break;
 
-    if (!DstTy.getScalarType().isPointer())
+    if (!DstTy.isPointerOrPointerVector())
       report("ptrmask result type must be a pointer", MI);
 
     if (!MaskTy.getScalarType().isScalar())
@@ -1330,15 +1330,13 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
     if (!DstTy.isValid() || !SrcTy.isValid())
       break;
 
-    LLT DstElTy = DstTy.getScalarType();
-    LLT SrcElTy = SrcTy.getScalarType();
-    if (DstElTy.isPointer() || SrcElTy.isPointer())
+    if (DstTy.isPointerOrPointerVector() || SrcTy.isPointerOrPointerVector())
       report("Generic extend/truncate can not operate on pointers", MI);
 
     verifyVectorElementMatch(DstTy, SrcTy, MI);
 
-    unsigned DstSize = DstElTy.getSizeInBits();
-    unsigned SrcSize = SrcElTy.getSizeInBits();
+    unsigned DstSize = DstTy.getScalarSizeInBits();
+    unsigned SrcSize = SrcTy.getScalarSizeInBits();
     switch (MI->getOpcode()) {
     default:
       if (DstSize <= SrcSize)
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 2515991fbea114..53a0213e9bdb65 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -2091,7 +2091,7 @@ bool AArch64InstructionSelector::preISelLower(MachineInstr &I) {
   case AArch64::G_DUP: {
     // Convert the type from p0 to s64 to help selection.
     LLT DstTy = MRI.getType(I.getOperand(0).getReg());
-    if (!DstTy.getElementType().isPointer())
+    if (!DstTy.isPointerVector())
       return false;
     auto NewSrc = MIB.buildCopy(LLT::scalar(64), I.getOperand(1).getReg());
     MRI.setType(I.getOperand(0).getReg(),
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index cbf5655706e694..ab25e2b9af8b45 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -343,10 +343,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
 
   auto IsPtrVecPred = [=](const LegalityQuery &Query) {
     const LLT &ValTy = Query.Types[0];
-    if (!ValTy.isVector())
-      return false;
-    const LLT EltTy = ValTy.getElementType();
-    return EltTy.isPointer() && EltTy.getAddressSpace() == 0;
+    return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };
 
   getActionDefinitionsBuilder(G_LOAD)
@@ -521,7 +518,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
           [=](const LegalityQuery &Query) {
             const LLT &Ty = Query.Types[0];
             const LLT &SrcTy = Query.Types[1];
-            return Ty.isVector() && !SrcTy.getElementType().isPointer() &&
+            return Ty.isVector() && !SrcTy.isPointerVector() &&
                    Ty.getElementType() != SrcTy.getElementType();
           },
           0, 1)
@@ -555,7 +552,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
           [=](const LegalityQuery &Query) {
             const LLT &Ty = Query.Types[0];
             const LLT &SrcTy = Query.Types[1];
-            return Ty.isVector() && !SrcTy.getElementType().isPointer() &&
+            return Ty.isVector() && !SrcTy.isPointerVector() &&
                    Ty.getElementType() != SrcTy.getElementType();
           },
           0, 1)
@@ -1649,7 +1646,7 @@ bool AArch64LegalizerInfo::legalizeLoadStore(
     return true;
   }
 
-  if (!ValTy.isVector() || !ValTy.getElementType().isPointer() ||
+  if (!ValTy.isPointerVector() ||
       ValTy.getElementType().getAddressSpace() != 0) {
     LLVM_DEBUG(dbgs() << "Tried to do custom legalization on wrong load/store");
     return false;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 97952de3e6a37b..fd476aa89bfcdc 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -416,11 +416,10 @@ static bool loadStoreBitcastWorkaround(const LLT Ty) {
   if (!Ty.isVector())
     return true;
 
-  LLT EltTy = Ty.getElementType();
-  if (EltTy.isPointer())
+  if (Ty.isPointerVector())
     return true;
 
-  unsigned EltSize = EltTy.getSizeInBits();
+  unsigned EltSize = Ty.getScalarSizeInBits();
   return EltSize != 32 && EltSize != 64;
 }
 
diff --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
index d13cfeeffe56a0..cb34802a5de271 100644
--- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
+++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
@@ -259,6 +259,7 @@ TEST(LowLevelTypeTest, Pointer) {
       // Test kind.
       ASSERT_TRUE(Ty.isValid());
       ASSERT_TRUE(Ty.isPointer());
+      ASSERT_TRUE(Ty.isPointerOrPointerVector());
 
       ASSERT_FALSE(Ty.isScalar());
       ASSERT_FALSE(Ty.isVector());
@@ -266,6 +267,8 @@ TEST(LowLevelTypeTest, Pointer) {
       ASSERT_TRUE(VTy.isValid());
       ASSERT_TRUE(VTy.isVector());
       ASSERT_TRUE(VTy.getElementType().isPointer());
+      ASSERT_TRUE(VTy.isPointerVector());
+      ASSERT_TRUE(VTy.isPointerOrPointerVector());
 
       EXPECT_EQ(Ty, VTy.getElementType());
       EXPECT_EQ(Ty.getSizeInBits(), VTy.getScalarSizeInBits());

@llvmbot
Copy link
Member

llvmbot commented Feb 9, 2024

@llvm/pr-subscribers-llvm-globalisel

Author: Jay Foad (jayfoad)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/81283.diff

9 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h (+1-1)
  • (modified) llvm/include/llvm/CodeGenTypes/LowLevelType.h (+8-6)
  • (modified) llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp (+2-3)
  • (modified) llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp (+1-1)
  • (modified) llvm/lib/CodeGen/MachineVerifier.cpp (+6-8)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp (+1-1)
  • (modified) llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp (+4-7)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp (+2-3)
  • (modified) llvm/unittests/CodeGen/LowLevelTypeTest.cpp (+3)
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
index 5fd80e5f199ce8..637c2c71b02411 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/LegalizerInfo.h
@@ -1084,7 +1084,7 @@ class LegalizeRuleSet {
         },
         [=](const LegalityQuery &Query) {
           LLT T = Query.Types[LargeTypeIdx];
-          if (T.isVector() && T.getElementType().isPointer())
+          if (T.isPointerVector())
             T = T.changeElementType(LLT::scalar(T.getScalarSizeInBits()));
           return std::make_pair(TypeIdx, T);
         });
diff --git a/llvm/include/llvm/CodeGenTypes/LowLevelType.h b/llvm/include/llvm/CodeGenTypes/LowLevelType.h
index cc33152e06f4db..5a16cffb80b32b 100644
--- a/llvm/include/llvm/CodeGenTypes/LowLevelType.h
+++ b/llvm/include/llvm/CodeGenTypes/LowLevelType.h
@@ -134,15 +134,17 @@ class LLT {
 
   explicit LLT(MVT VT);
 
-  constexpr bool isValid() const { return IsScalar || RawData != 0; }
+  constexpr bool isValid() const { return IsScalar || IsPointer || IsVector; }
 
   constexpr bool isScalar() const { return IsScalar; }
 
-  constexpr bool isPointer() const {
-    return isValid() && IsPointer && !IsVector;
-  }
+  constexpr bool isPointer() const { return IsPointer && !IsVector; }
+
+  constexpr bool isPointerVector() const { return IsPointer && IsVector; }
+
+  constexpr bool isPointerOrPointerVector() const { return IsPointer; }
 
-  constexpr bool isVector() const { return isValid() && IsVector; }
+  constexpr bool isVector() const { return IsVector; }
 
   /// Returns the number of elements in a vector LLT. Must only be called on
   /// vector types.
@@ -209,7 +211,7 @@ class LLT {
   /// but the new element size. Otherwise, return the new element type. Invalid
   /// for pointer types. For pointer types, use changeElementType.
   constexpr LLT changeElementSize(unsigned NewEltSize) const {
-    assert(!getScalarType().isPointer() &&
+    assert(!isPointerOrPointerVector() &&
            "invalid to directly change element size for pointers");
     return isVector() ? LLT::vector(getElementCount(), NewEltSize)
                       : LLT::scalar(NewEltSize);
diff --git a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
index 464ff0864d146f..e39fdae1ccbedb 100644
--- a/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
@@ -1716,8 +1716,7 @@ Register LegalizerHelper::coerceToScalar(Register Val) {
   Register NewVal = Val;
 
   assert(Ty.isVector());
-  LLT EltTy = Ty.getElementType();
-  if (EltTy.isPointer())
+  if (Ty.isPointerVector())
     NewVal = MIRBuilder.buildPtrToInt(NewTy, NewVal).getReg(0);
   return MIRBuilder.buildBitcast(NewTy, NewVal).getReg(0);
 }
@@ -7964,7 +7963,7 @@ LegalizerHelper::LegalizeResult LegalizerHelper::lowerSelect(MachineInstr &MI) {
   auto [DstReg, DstTy, MaskReg, MaskTy, Op1Reg, Op1Ty, Op2Reg, Op2Ty] =
       MI.getFirst4RegLLTs();
 
-  bool IsEltPtr = DstTy.getScalarType().isPointer();
+  bool IsEltPtr = DstTy.isPointerOrPointerVector();
   if (IsEltPtr) {
     LLT ScalarPtrTy = LLT::scalar(DstTy.getScalarSizeInBits());
     LLT NewTy = DstTy.changeElementType(ScalarPtrTy);
diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
index a5827c26c04f48..d58b62871817d8 100644
--- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
@@ -199,7 +199,7 @@ void MachineIRBuilder::validateShiftOp(const LLT Res, const LLT Op0,
 MachineInstrBuilder
 MachineIRBuilder::buildPtrAdd(const DstOp &Res, const SrcOp &Op0,
                               const SrcOp &Op1, std::optional<unsigned> Flags) {
-  assert(Res.getLLTTy(*getMRI()).getScalarType().isPointer() &&
+  assert(Res.getLLTTy(*getMRI()).isPointerOrPointerVector() &&
          Res.getLLTTy(*getMRI()) == Op0.getLLTTy(*getMRI()) && "type mismatch");
   assert(Op1.getLLTTy(*getMRI()).getScalarType().isScalar() && "invalid offset type");
 
diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp
index c65e91741533c8..2632b5b9feac9d 100644
--- a/llvm/lib/CodeGen/MachineVerifier.cpp
+++ b/llvm/lib/CodeGen/MachineVerifier.cpp
@@ -1288,10 +1288,10 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
     if (!DstTy.isValid() || !PtrTy.isValid() || !OffsetTy.isValid())
       break;
 
-    if (!PtrTy.getScalarType().isPointer())
+    if (!PtrTy.isPointerOrPointerVector())
       report("gep first operand must be a pointer", MI);
 
-    if (OffsetTy.getScalarType().isPointer())
+    if (OffsetTy.isPointerOrPointerVector())
       report("gep offset operand must not be a pointer", MI);
 
     // TODO: Is the offset allowed to be a scalar with a vector?
@@ -1304,7 +1304,7 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
     if (!DstTy.isValid() || !SrcTy.isValid() || !MaskTy.isValid())
       break;
 
-    if (!DstTy.getScalarType().isPointer())
+    if (!DstTy.isPointerOrPointerVector())
       report("ptrmask result type must be a pointer", MI);
 
     if (!MaskTy.getScalarType().isScalar())
@@ -1330,15 +1330,13 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {
     if (!DstTy.isValid() || !SrcTy.isValid())
       break;
 
-    LLT DstElTy = DstTy.getScalarType();
-    LLT SrcElTy = SrcTy.getScalarType();
-    if (DstElTy.isPointer() || SrcElTy.isPointer())
+    if (DstTy.isPointerOrPointerVector() || SrcTy.isPointerOrPointerVector())
       report("Generic extend/truncate can not operate on pointers", MI);
 
     verifyVectorElementMatch(DstTy, SrcTy, MI);
 
-    unsigned DstSize = DstElTy.getSizeInBits();
-    unsigned SrcSize = SrcElTy.getSizeInBits();
+    unsigned DstSize = DstTy.getScalarSizeInBits();
+    unsigned SrcSize = SrcTy.getScalarSizeInBits();
     switch (MI->getOpcode()) {
     default:
       if (DstSize <= SrcSize)
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
index 2515991fbea114..53a0213e9bdb65 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -2091,7 +2091,7 @@ bool AArch64InstructionSelector::preISelLower(MachineInstr &I) {
   case AArch64::G_DUP: {
     // Convert the type from p0 to s64 to help selection.
     LLT DstTy = MRI.getType(I.getOperand(0).getReg());
-    if (!DstTy.getElementType().isPointer())
+    if (!DstTy.isPointerVector())
       return false;
     auto NewSrc = MIB.buildCopy(LLT::scalar(64), I.getOperand(1).getReg());
     MRI.setType(I.getOperand(0).getReg(),
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
index cbf5655706e694..ab25e2b9af8b45 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -343,10 +343,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
 
   auto IsPtrVecPred = [=](const LegalityQuery &Query) {
     const LLT &ValTy = Query.Types[0];
-    if (!ValTy.isVector())
-      return false;
-    const LLT EltTy = ValTy.getElementType();
-    return EltTy.isPointer() && EltTy.getAddressSpace() == 0;
+    return ValTy.isPointerVector() && ValTy.getAddressSpace() == 0;
   };
 
   getActionDefinitionsBuilder(G_LOAD)
@@ -521,7 +518,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
           [=](const LegalityQuery &Query) {
             const LLT &Ty = Query.Types[0];
             const LLT &SrcTy = Query.Types[1];
-            return Ty.isVector() && !SrcTy.getElementType().isPointer() &&
+            return Ty.isVector() && !SrcTy.isPointerVector() &&
                    Ty.getElementType() != SrcTy.getElementType();
           },
           0, 1)
@@ -555,7 +552,7 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
           [=](const LegalityQuery &Query) {
             const LLT &Ty = Query.Types[0];
             const LLT &SrcTy = Query.Types[1];
-            return Ty.isVector() && !SrcTy.getElementType().isPointer() &&
+            return Ty.isVector() && !SrcTy.isPointerVector() &&
                    Ty.getElementType() != SrcTy.getElementType();
           },
           0, 1)
@@ -1649,7 +1646,7 @@ bool AArch64LegalizerInfo::legalizeLoadStore(
     return true;
   }
 
-  if (!ValTy.isVector() || !ValTy.getElementType().isPointer() ||
+  if (!ValTy.isPointerVector() ||
       ValTy.getElementType().getAddressSpace() != 0) {
     LLVM_DEBUG(dbgs() << "Tried to do custom legalization on wrong load/store");
     return false;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
index 97952de3e6a37b..fd476aa89bfcdc 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
@@ -416,11 +416,10 @@ static bool loadStoreBitcastWorkaround(const LLT Ty) {
   if (!Ty.isVector())
     return true;
 
-  LLT EltTy = Ty.getElementType();
-  if (EltTy.isPointer())
+  if (Ty.isPointerVector())
     return true;
 
-  unsigned EltSize = EltTy.getSizeInBits();
+  unsigned EltSize = Ty.getScalarSizeInBits();
   return EltSize != 32 && EltSize != 64;
 }
 
diff --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
index d13cfeeffe56a0..cb34802a5de271 100644
--- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
+++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
@@ -259,6 +259,7 @@ TEST(LowLevelTypeTest, Pointer) {
       // Test kind.
       ASSERT_TRUE(Ty.isValid());
       ASSERT_TRUE(Ty.isPointer());
+      ASSERT_TRUE(Ty.isPointerOrPointerVector());
 
       ASSERT_FALSE(Ty.isScalar());
       ASSERT_FALSE(Ty.isVector());
@@ -266,6 +267,8 @@ TEST(LowLevelTypeTest, Pointer) {
       ASSERT_TRUE(VTy.isValid());
       ASSERT_TRUE(VTy.isVector());
       ASSERT_TRUE(VTy.getElementType().isPointer());
+      ASSERT_TRUE(VTy.isPointerVector());
+      ASSERT_TRUE(VTy.isPointerOrPointerVector());
 
       EXPECT_EQ(Ty, VTy.getElementType());
       EXPECT_EQ(Ty.getSizeInBits(), VTy.getScalarSizeInBits());

@@ -134,15 +134,17 @@ class LLT {

explicit LLT(MVT VT);

constexpr bool isValid() const { return IsScalar || RawData != 0; }
constexpr bool isValid() const { return IsScalar || IsPointer || IsVector; }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Writing isValid this way just makes it clear that there is no need to call it from the other is methods below, since they are all testing at least one of the Is flags already.

@jayfoad jayfoad merged commit d57515b into llvm:main Feb 13, 2024
@jayfoad jayfoad deleted the llt-pointervector branch February 13, 2024 08:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants