Skip to content

[mlir][IR] Add VectorTypeElementInterface with !llvm.ptr #133455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Mar 28, 2025

This commit extends the MLIR vector type to support pointer-like types such as !llvm.ptr and !ptr.ptr, as indicated by the newly added VectorTypeElementInterface. This makes the LLVM dialect closer to LLVM IR. LLVM IR already supports pointers as vector element type.

Only integers, floats, pointers and index are valid vector element types for now. Additional vector element types may be added in the future after further discussions. The interface is still evolving and may eventually turn into one of the alternatives that were discussed on the RFC.

This commit also disallows !llvm.ptr as an element type of !llvm.vec. This type exists due to limitations of the MLIR vector type.

RFC: https://discourse.llvm.org/t/rfc-allow-pointers-as-element-type-of-vector/85360

@llvmbot
Copy link
Member

llvmbot commented Mar 28, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This commit extends the MLIR vector type to support pointer-like types such as !llvm.ptr and !ptr.ptr, as indicated by the newly added VectorTypeElementInterface. This makes the LLVM dialect closer to LLVM IR. LLVM IR already supports pointers as vector element type.

This commit also disallows !llvm.ptr as an element type of !llvm.vec. This type exists due to limitations of the MLIR vector type.

RFC: https://discourse.llvm.org/t/rfc-allow-pointers-as-element-type-of-vector/85360


Patch is 49.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133455.diff

23 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td (+3-1)
  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td (+1)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+24-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+9-3)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+6-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+2-2)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+28-29)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+4-4)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+5-5)
  • (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+1-1)
  • (modified) mlir/test/Dialect/LLVMIR/opaque-ptr.mlir (+3-3)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+7-7)
  • (modified) mlir/test/Dialect/LLVMIR/types.mlir (+2-2)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
  • (modified) mlir/test/IR/test-verifiers-type.mlir (+15)
  • (modified) mlir/test/Target/LLVMIR/Import/constant.ll (+4-4)
  • (modified) mlir/test/Target/LLVMIR/Import/incorrect-scalable-vector-check.ll (+1-1)
  • (modified) mlir/test/Target/LLVMIR/Import/intrinsic.ll (+7-37)
  • (modified) mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir (+10-58)
  • (modified) mlir/test/Target/LLVMIR/llvmir-invalid.mlir (+8-8)
  • (modified) mlir/test/Target/LLVMIR/llvmir-types.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+3-3)
  • (modified) mlir/test/Target/LLVMIR/opaque-ptr.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 2ecbf8f50c50c..438df08cd1f28 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -11,6 +11,7 @@
 
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 
@@ -259,7 +260,8 @@ def LLVMStructType : LLVMType<"LLVMStruct", "struct", [
 def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
       "getIndexBitwidth", "areCompatible", "verifyEntries",
-      "getPreferredAlignment"]>]> {
+      "getPreferredAlignment"]>,
+    VectorElementTypeInterface]> {
   let summary = "LLVM pointer type";
   let description = [{
     The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index 857e68cec8c76..73b2a0857cef3 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>
 
 def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
     MemRefElementTypeInterface,
+    VectorElementTypeInterface,
     DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
       "areCompatible", "getIndexBitwidth", "verifyEntries",
       "getPreferredAlignment"]>
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..cf2098d7d64e0 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,30 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// VectorElementTypeInterface
+//===----------------------------------------------------------------------===//
+
+def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type can be used as element in vector types.
+
+    Implementing this interface establishes a contract between this type and the
+    vector type indicating that this type can be used as element of vectors.
+
+    The interface currently has no methods and is used by types to opt into
+    being vector elements. This may change in the future, in particular to
+    require types to provide their size or alignment given a data layout.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType",
+    [VectorElementTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..1db52f1a041ee 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -447,7 +447,8 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index",
+    [VectorElementTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -477,7 +478,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer : Builtin_Type<"Integer", "integer",
+    [VectorElementTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -1249,7 +1251,11 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
+// Note: VectorType uses this type constraint instead of a plain
+// VectorElementTypeInterface, so that methods with mlir::Type are generated.
+// We may want to drop this in future and require VectorElementTypeInterface
+// for all methods.
+def Builtin_VectorTypeElementType : AnyTypeOf<[VectorElementTypeInterface]> {
   let cppFunctionName = "isValidVectorTypeElementType";
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 1053a25515d5c..51dcb071f9c18 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -140,9 +140,13 @@ static bool isSupportedTypeForConversion(Type type) {
   if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
     return false;
 
-  // Scalable types are not supported.
-  if (auto vectorType = dyn_cast<VectorType>(type))
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
+    // Vectors of pointers cannot be casted.
+    if (isa<LLVM::LLVMPointerType>(vectorType.getElementType()))
+      return false;
+    // Scalable types are not supported.
     return !vectorType.isScalable();
+  }
   return true;
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 8f39ede721c92..403756765268e 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -690,7 +690,7 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 }
 
 bool LLVMFixedVectorType::isValidElementType(Type type) {
-  return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
+  return llvm::isa<LLVMPPCFP128Type>(type);
 }
 
 LogicalResult
@@ -890,7 +890,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
     if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
       return intType.isSignless();
     return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
-                     Float80Type, Float128Type>(elementType);
+                     Float80Type, Float128Type, LLVMPointerType>(elementType);
   }
   return false;
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 44b4a25a051f1..3df14528bac39 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2002,8 +2002,8 @@ func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1
 }
 
 // CHECK-LABEL: func @gather
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: return %[[G]] : vector<3xf32>
 
 // -----
@@ -2015,8 +2015,8 @@ func.func @gather_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
 }
 
 // CHECK-LABEL: func @gather_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
 // CHECK: return %[[G]] : vector<[3]xf32>
 
 // -----
@@ -2028,8 +2028,8 @@ func.func @gather_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %
 }
 
 // CHECK-LABEL: func @gather_global_memory
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> vector<3x!llvm.ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 // CHECK: return %[[G]] : vector<3xf32>
 
 // -----
@@ -2041,8 +2041,8 @@ func.func @gather_global_memory_scalable(%arg0: memref<?xf32, 1>, %arg1: vector<
 }
 
 // CHECK-LABEL: func @gather_global_memory_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr<1>>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> vector<[3]x!llvm.ptr<1>>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
 // CHECK: return %[[G]] : vector<[3]xf32>
 
 // -----
@@ -2055,8 +2055,8 @@ func.func @gather_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: v
 }
 
 // CHECK-LABEL: func @gather_index
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
 // CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<3xi64> to vector<3xindex>
 
 // -----
@@ -2068,13 +2068,12 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
 }
 
 // CHECK-LABEL: func @gather_index_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
 // CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<[3]xi64> to vector<[3]xindex>
 
 // -----
 
-
 func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : index
   %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
@@ -2083,8 +2082,8 @@ func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2
 
 // CHECK-LABEL: func @gather_1d_from_2d
 // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<4x!llvm.ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
 // CHECK: return %[[G]] : vector<4xf32>
 
 // -----
@@ -2097,8 +2096,8 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x
 
 // CHECK-LABEL: func @gather_1d_from_2d_scalable
 // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
-// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 4 x ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
+// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[4]x!llvm.ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
 // CHECK: return %[[G]] : vector<[4]xf32>
 
 // -----
@@ -2114,8 +2113,8 @@ func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi
 }
 
 // CHECK-LABEL: func @scatter
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
 
 // -----
 
@@ -2126,8 +2125,8 @@ func.func @scatter_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
 }
 
 // CHECK-LABEL: func @scatter_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
 
 // -----
 
@@ -2138,8 +2137,8 @@ func.func @scatter_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2:
 }
 
 // CHECK-LABEL: func @scatter_index
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into !llvm.vec<3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into vector<3x!llvm.ptr>
 
 // -----
 
@@ -2150,8 +2149,8 @@ func.func @scatter_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xinde
 }
 
 // CHECK-LABEL: func @scatter_index_scalable
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into vector<[3]x!llvm.ptr>
 
 // -----
 
@@ -2163,8 +2162,8 @@ func.func @scatter_1d_into_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg
 
 // CHECK-LABEL: func @scatter_1d_into_2d
 // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into !llvm.vec<4 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into vector<4x!llvm.ptr>
 
 // -----
 
@@ -2176,8 +2175,8 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]
 
 // CHECK-LABEL: func @scatter_1d_into_2d_scalable
 // CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
-// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
-// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.vec<? x 4 x ptr>
+// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
+// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>
 
 // -----
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ae73dfef0c74f..64e51f5554628 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1669,8 +1669,8 @@ func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2:
 }
 
 // CHECK-LABEL: func @gather_with_mask
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
 
 // -----
 
@@ -1685,8 +1685,8 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
 }
 
 // CHECK-LABEL: func @gather_with_mask_scalable
-// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
-// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
+// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
 
 
 // -----
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fb9631d99b91a..a3f6448e09d11 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -1337,16 +1337,16 @@ func.func @invalid_bitcast_i64_to_ptr() {
 
 // -----
 
-func.func @invalid_bitca...
[truncated]

@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_element_type_interface branch from 4ae95d3 to bda37d9 Compare March 28, 2025 15:52
@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_element_type_interface branch from bda37d9 to 316f420 Compare April 2, 2025 18:50
@joker-eph
Copy link
Collaborator

That's reasonable to me, but we should wait for more feedback of course.

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a bunch of tests that seem plainly removed from translation. Is this a mis-rebase?

@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_element_type_interface branch from a79dd2f to f0e7fe2 Compare April 4, 2025 18:07
@matthias-springer matthias-springer requested a review from ftynse April 4, 2025 18:17
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_element_type_interface branch from 1725617 to 804e7c6 Compare April 8, 2025 14:07
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, LGTM!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants