-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][IR] Add ScalarTypeInterface
and use as VectorType
element type
#132400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Add ScalarTypeInterface
and use as VectorType
element type
#132400
Conversation
@llvm/pr-subscribers-mlir-amdgpu @llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new builtin type interface: Instead of maintaining a list of valid element types for Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff 54 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indication that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/store from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalable type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vector...
[truncated]
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new builtin type interface: Instead of maintaining a list of valid element types for Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff 54 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indication that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/store from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalable type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vector...
[truncated]
|
@llvm/pr-subscribers-mlir-nvgpu Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new builtin type interface: Instead of maintaining a list of valid element types for Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff 54 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indication that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/store from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalable type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vector...
[truncated]
|
@llvm/pr-subscribers-mlir-affine Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new builtin type interface: Instead of maintaining a list of valid element types for Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff 54 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indication that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/store from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalable type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vector...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new builtin type interface: Instead of maintaining a list of valid element types for Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff 54 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indication that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/store from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalable type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vector...
[truncated]
|
@llvm/pr-subscribers-backend-amdgpu Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new builtin type interface: Instead of maintaining a list of valid element types for Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff 54 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indication that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/store from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalable type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vector...
[truncated]
|
@llvm/pr-subscribers-mlir-ods Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new builtin type interface: Instead of maintaining a list of valid element types for Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff 54 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indication that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/store from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalable type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vector...
[truncated]
|
@llvm/pr-subscribers-mlir-math Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new builtin type interface: Instead of maintaining a list of valid element types for Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff 54 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indication that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/store from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalable type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vector...
[truncated]
|
Should I add a builder for |
You can test this locally with the following command:git-clang-format --diff 53a395fda32cb0edd899202b6614595185b01ef1 85c0b6be5c046b342987ff3523836bd87806e971 --extensions cpp,h -- mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h mlir/include/mlir/IR/BuiltinTypes.h mlir/lib/AsmParser/TypeParser.cpp mlir/lib/CAPI/IR/BuiltinTypes.cpp mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp mlir/lib/Dialect/ArmSME/IR/Utils.cpp mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp mlir/lib/Dialect/Quant/IR/QuantTypes.cpp mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp mlir/lib/Dialect/Traits.cpp mlir/lib/Dialect/Vector/IR/VectorOps.cpp mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp mlir/lib/IR/BuiltinTypes.cpp mlir/lib/Target/LLVMIR/ModuleImport.cpp mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp mlir/unittests/IR/ShapedTypeTest.cpp View the diff from clang-format here.diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 54c540b28f..5b073a1dbc 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
if (iface.isConvertibleInstruction(inst->getOpcode()))
return iface.convertInstruction(odsBuilder, inst, llvmOperands,
moduleImport);
- // TODO: Implement the `convertInstruction` hooks in the
- // `LLVMDialectLLVMIRImportInterface` and move the following include there.
+ // TODO: Implement the `convertInstruction` hooks in the
+ // `LLVMDialectLLVMIRImportInterface` and move the following include there.
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
return failure();
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for spending the time on the code to illustrate the proposal! We should wait for the RFC thread to converge first.
abstract types like function types or async tokens cannot be scalar types. | ||
|
||
Scalar types should be limited to types that can lower to something that | ||
egress dialects would consider a valid vector element type. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I agree wit this last paragraph: it seems "out of the blue" with respect to the rest of this description, and is not very precise (what are the egress dialect one has in mind here? If we think of LLVM we should just call this out).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it would be worth calling out shaped types being explicitly disallowed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A scalar type is a type that people would generally consider a scalar type. Integers, floats, and pointers are examples of scalars, shaped types (like tensors, vectors, and arrays), tuples, and other complex structures are examples of non-scalar types.
(There's my mostly-joking submission to the definition question)
abstract types like function types or async tokens cannot be scalar types. | ||
|
||
Scalar types should be limited to types that can lower to something that | ||
egress dialects would consider a valid vector element type. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe it would be worth calling out shaped types being explicitly disallowed?
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) { | |||
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, | |||
MlirType elementType) { | |||
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)), | |||
unwrap(elementType))); | |||
cast<ScalarTypeInterface>(unwrap(elementType)))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought you added a builder that performs the cast, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this PR does not have it yet. I wanted to avoid the cast in the builder so that we can have better static type checking in some places. But now the diff became to large, so maybe it's better to add it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I think it's not worth to be that explicit -- any type issues will be caught in the callee anyway
Closed in favor of #133455. |
This commit adds a new builtin type interface:
ScalarTypeInterface
Instead of maintaining a list of valid element types for
VectorType
, restrict valid element types toScalarTypeInterface
.