Skip to content

[mlir][IR] Add VectorTypeElementInterface with !llvm.ptr #133455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"

Expand Down Expand Up @@ -259,7 +260,8 @@ def LLVMStructType : LLVMType<"LLVMStruct", "struct", [
def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"getIndexBitwidth", "areCompatible", "verifyEntries",
"getPreferredAlignment"]>]> {
"getPreferredAlignment"]>,
VectorElementTypeInterface]> {
let summary = "LLVM pointer type";
let description = [{
The `!llvm.ptr` type is an LLVM pointer type. This type typically represents
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Ptr_Type<string name, string typeMnemonic, list<Trait> traits = []>

def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
MemRefElementTypeInterface,
VectorElementTypeInterface,
DeclareTypeInterfaceMethods<DataLayoutTypeInterface, [
"areCompatible", "getIndexBitwidth", "verifyEntries",
"getPreferredAlignment"]>
Expand Down
32 changes: 31 additions & 1 deletion mlir/include/mlir/IR/BuiltinTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,37 @@

include "mlir/IR/OpBase.td"

def FloatTypeInterface : TypeInterface<"FloatType"> {
//===----------------------------------------------------------------------===//
// VectorElementTypeInterface
//===----------------------------------------------------------------------===//

def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> {
let cppNamespace = "::mlir";
let description = [{
Implementing this interface establishes a contract between this type and the
vector type, indicating that this type can be used as element of vectors.

Vector element types are treated as a bag of bits without any assumed
structure. The size of an element type must be a compile-time constant.
However, the bit-width may remain opaque or unavailable during
transformations that do not depend on the element type.

Note: This type interface is still evolving. It currently has no methods
and is just used as marker to allow types to opt into being vector elements.
This may change in the future, for example, to require types to provide
their size or alignment given a data layout. Please post an RFC before
adding this interface to additional types. Implementing this interface on
downstream types is discourged, until we specified the exact properties of
a vector element type in more detail.
}];
}

//===----------------------------------------------------------------------===//
// FloatTypeInterface
//===----------------------------------------------------------------------===//

def FloatTypeInterface : TypeInterface<"FloatType",
[VectorElementTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
Expand Down
12 changes: 9 additions & 3 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//

def Builtin_Index : Builtin_Type<"Index", "index"> {
def Builtin_Index : Builtin_Type<"Index", "index",
[VectorElementTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
Expand Down Expand Up @@ -495,7 +496,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//

def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
def Builtin_Integer : Builtin_Type<"Integer", "integer",
[VectorElementTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
Expand Down Expand Up @@ -1267,7 +1269,11 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//

def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
// Note: VectorType uses this type constraint instead of a plain
// VectorElementTypeInterface, so that methods with mlir::Type are generated.
// We may want to drop this in future and require VectorElementTypeInterface
// for all methods.
def Builtin_VectorTypeElementType : AnyTypeOf<[VectorElementTypeInterface]> {
let cppFunctionName = "isValidVectorTypeElementType";
}

Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,13 @@ static bool isSupportedTypeForConversion(Type type) {
if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
return false;

// Scalable types are not supported.
if (auto vectorType = dyn_cast<VectorType>(type))
if (auto vectorType = dyn_cast<VectorType>(type)) {
// Vectors of pointers cannot be casted.
if (isa<LLVM::LLVMPointerType>(vectorType.getElementType()))
return false;
// Scalable types are not supported.
return !vectorType.isScalable();
}
return true;
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError,
}

bool LLVMFixedVectorType::isValidElementType(Type type) {
return llvm::isa<LLVMPointerType, LLVMPPCFP128Type>(type);
return llvm::isa<LLVMPPCFP128Type>(type);
}

LogicalResult
Expand Down Expand Up @@ -892,7 +892,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
if (auto intType = llvm::dyn_cast<IntegerType>(elementType))
return intType.isSignless();
return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type,
Float80Type, Float128Type>(elementType);
Float80Type, Float128Type, LLVMPointerType>(elementType);
}
return false;
}
Expand Down
57 changes: 28 additions & 29 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2002,8 +2002,8 @@ func.func @gather(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1
}

// CHECK-LABEL: func @gather
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: return %[[G]] : vector<3xf32>

// -----
Expand All @@ -2015,8 +2015,8 @@ func.func @gather_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
}

// CHECK-LABEL: func @gather_scalable
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: return %[[G]] : vector<[3]xf32>

// -----
Expand All @@ -2028,8 +2028,8 @@ func.func @gather_global_memory(%arg0: memref<?xf32, 1>, %arg1: vector<3xi32>, %
}

// CHECK-LABEL: func @gather_global_memory
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> !llvm.vec<3 x ptr<1>>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<3xi32>) -> vector<3x!llvm.ptr<1>>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr<1>>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: return %[[G]] : vector<3xf32>

// -----
Expand All @@ -2041,8 +2041,8 @@ func.func @gather_global_memory_scalable(%arg0: memref<?xf32, 1>, %arg1: vector<
}

// CHECK-LABEL: func @gather_global_memory_scalable
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr<1>>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr<1>, vector<[3]xi32>) -> vector<[3]x!llvm.ptr<1>>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr<1>>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: return %[[G]] : vector<[3]xf32>

// -----
Expand All @@ -2055,8 +2055,8 @@ func.func @gather_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2: v
}

// CHECK-LABEL: func @gather_index
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64>
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<3xi64> to vector<3xindex>

// -----
Expand All @@ -2068,13 +2068,12 @@ func.func @gather_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xindex
}

// CHECK-LABEL: func @gather_index_scalable
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xi64>) -> vector<[3]xi64>
// CHECK: %{{.*}} = builtin.unrealized_conversion_cast %[[G]] : vector<[3]xi64> to vector<[3]xindex>

// -----


func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> {
%0 = arith.constant 3 : index
%1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
Expand All @@ -2083,8 +2082,8 @@ func.func @gather_1d_from_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2

// CHECK-LABEL: func @gather_1d_from_2d
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<4 x ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<4x!llvm.ptr>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
// CHECK: return %[[G]] : vector<4xf32>

// -----
Expand All @@ -2097,8 +2096,8 @@ func.func @gather_1d_from_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]x

// CHECK-LABEL: func @gather_1d_from_2d_scalable
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 4 x ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
// CHECK: %[[G:.*]] = llvm.intr.masked.gather %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[4]x!llvm.ptr>, vector<[4]xi1>, vector<[4]xf32>) -> vector<[4]xf32>
// CHECK: return %[[G]] : vector<[4]xf32>

// -----
Expand All @@ -2114,8 +2113,8 @@ func.func @scatter(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi
}

// CHECK-LABEL: func @scatter
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> !llvm.vec<3 x ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into !llvm.vec<3 x ptr>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi32>) -> vector<3x!llvm.ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>

// -----

Expand All @@ -2126,8 +2125,8 @@ func.func @scatter_scalable(%arg0: memref<?xf32>, %arg1: vector<[3]xi32>, %arg2:
}

// CHECK-LABEL: func @scatter_scalable
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> !llvm.vec<? x 3 x ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi32>) -> vector<[3]x!llvm.ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[3]xf32>, vector<[3]xi1> into vector<[3]x!llvm.ptr>

// -----

Expand All @@ -2138,8 +2137,8 @@ func.func @scatter_index(%arg0: memref<?xindex>, %arg1: vector<3xindex>, %arg2:
}

// CHECK-LABEL: func @scatter_index
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> !llvm.vec<3 x ptr>, i64
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into !llvm.vec<3 x ptr>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<3xi64>) -> vector<3x!llvm.ptr>, i64
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into vector<3x!llvm.ptr>

// -----

Expand All @@ -2150,8 +2149,8 @@ func.func @scatter_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[3]xinde
}

// CHECK-LABEL: func @scatter_index_scalable
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> !llvm.vec<? x 3 x ptr>, i64
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into !llvm.vec<? x 3 x ptr>
// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, vector<[3]xi64>) -> vector<[3]x!llvm.ptr>, i64
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 8 : i32} : vector<[3]xi64>, vector<[3]xi1> into vector<[3]x!llvm.ptr>

// -----

Expand All @@ -2163,8 +2162,8 @@ func.func @scatter_1d_into_2d(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg

// CHECK-LABEL: func @scatter_1d_into_2d
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> !llvm.vec<4 x ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into !llvm.vec<4 x ptr>
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<4xi32>) -> vector<4x!llvm.ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<4xf32>, vector<4xi1> into vector<4x!llvm.ptr>

// -----

Expand All @@ -2176,8 +2175,8 @@ func.func @scatter_1d_into_2d_scalable(%arg0: memref<4x?xf32>, %arg1: vector<[4]

// CHECK-LABEL: func @scatter_1d_into_2d_scalable
// CHECK: %[[B:.*]] = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> !llvm.vec<? x 4 x ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into !llvm.vec<? x 4 x ptr>
// CHECK: %[[P:.*]] = llvm.getelementptr %[[B]][%{{.*}}] : (!llvm.ptr, vector<[4]xi32>) -> vector<[4]x!llvm.ptr>, f32
// CHECK: llvm.intr.masked.scatter %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : vector<[4]xf32>, vector<[4]xi1> into vector<[4]x!llvm.ptr>

// -----

Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1669,8 +1669,8 @@ func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2:
}

// CHECK-LABEL: func @gather_with_mask
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<3x!llvm.ptr>, vector<3xi1>, vector<3xf32>) -> vector<3xf32>

// -----

Expand All @@ -1685,8 +1685,8 @@ func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi
}

// CHECK-LABEL: func @gather_with_mask_scalable
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.vec<? x 3 x ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: %[[G0:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>
// CHECK: %[[G1:.*]] = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 4 : i32} : (vector<[3]x!llvm.ptr>, vector<[3]xi1>, vector<[3]xf32>) -> vector<[3]xf32>


// -----
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1337,16 +1337,16 @@ func.func @invalid_bitcast_i64_to_ptr() {

// -----

func.func @invalid_bitcast_vec_to_ptr(%arg : !llvm.vec<4 x ptr>) {
func.func @invalid_bitcast_vec_to_ptr(%arg : vector<4x!llvm.ptr>) {
// expected-error@+1 {{cannot cast vector of pointers to pointer}}
%0 = llvm.bitcast %arg : !llvm.vec<4 x ptr> to !llvm.ptr
%0 = llvm.bitcast %arg : vector<4x!llvm.ptr> to !llvm.ptr
}

// -----

func.func @invalid_bitcast_ptr_to_vec(%arg : !llvm.ptr) {
// expected-error@+1 {{cannot cast pointer to vector of pointers}}
%0 = llvm.bitcast %arg : !llvm.ptr to !llvm.vec<4 x ptr>
%0 = llvm.bitcast %arg : !llvm.ptr to vector<4x!llvm.ptr>
}

// -----
Expand All @@ -1358,9 +1358,9 @@ func.func @invalid_bitcast_addr_cast(%arg : !llvm.ptr<1>) {

// -----

func.func @invalid_bitcast_addr_cast_vec(%arg : !llvm.vec<4 x ptr<1>>) {
func.func @invalid_bitcast_addr_cast_vec(%arg : vector<4x!llvm.ptr<1>>) {
// expected-error@+1 {{cannot cast pointers of different address spaces, use 'llvm.addrspacecast' instead}}
%0 = llvm.bitcast %arg : !llvm.vec<4 x ptr<1>> to !llvm.vec<4 x ptr>
%0 = llvm.bitcast %arg : vector<4x!llvm.ptr<1>> to vector<4x!llvm.ptr>
}

// -----
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/LLVMIR/mem2reg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1011,7 +1011,7 @@ llvm.func @load_first_vector_elem() -> i16 {
llvm.func @load_first_llvm_vector_elem() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x !llvm.vec<4 x ptr> : (i32) -> !llvm.ptr
%1 = llvm.alloca %0 x vector<4x!llvm.ptr> : (i32) -> !llvm.ptr
%2 = llvm.load %1 : !llvm.ptr -> i16
llvm.return %2 : i16
}
Expand Down
Loading