Skip to content

[mlir][IR] Add ScalarTypeInterface and use as VectorType element type #132400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed

Conversation

matthias-springer
Copy link
Member

This commit adds a new builtin type interface: ScalarTypeInterface

Instead of maintaining a list of valid element types for VectorType, restrict valid element types to ScalarTypeInterface.

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-mlir-cf
@llvm/pr-subscribers-mlir-quant
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new builtin type interface: ScalarTypeInterface

Instead of maintaining a list of valid element types for VectorType, restrict valid element types to ScalarTypeInterface.


Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff

54 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2)
  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
  • (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1)
  • (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded/store from/to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalable type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vector...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new builtin type interface: ScalarTypeInterface

Instead of maintaining a list of valid element types for VectorType, restrict valid element types to ScalarTypeInterface.


Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff

54 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2)
  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
  • (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1)
  • (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded/store from/to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalable type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vector...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir-nvgpu

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new builtin type interface: ScalarTypeInterface

Instead of maintaining a list of valid element types for VectorType, restrict valid element types to ScalarTypeInterface.


Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff

54 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2)
  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
  • (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1)
  • (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded/store from/to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalable type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vector...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir-affine

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new builtin type interface: ScalarTypeInterface

Instead of maintaining a list of valid element types for VectorType, restrict valid element types to ScalarTypeInterface.


Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff

54 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2)
  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
  • (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1)
  • (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded/store from/to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalable type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vector...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new builtin type interface: ScalarTypeInterface

Instead of maintaining a list of valid element types for VectorType, restrict valid element types to ScalarTypeInterface.


Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff

54 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2)
  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
  • (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1)
  • (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded/store from/to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalable type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vector...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new builtin type interface: ScalarTypeInterface

Instead of maintaining a list of valid element types for VectorType, restrict valid element types to ScalarTypeInterface.


Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff

54 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2)
  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
  • (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1)
  • (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded/store from/to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalable type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vector...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir-ods

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new builtin type interface: ScalarTypeInterface

Instead of maintaining a list of valid element types for VectorType, restrict valid element types to ScalarTypeInterface.


Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff

54 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2)
  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
  • (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1)
  • (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded/store from/to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalable type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vector...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir-math

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new builtin type interface: ScalarTypeInterface

Instead of maintaining a list of valid element types for VectorType, restrict valid element types to ScalarTypeInterface.


Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff

54 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1)
  • (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1)
  • (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4)
  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2)
  • (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1)
  • (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2)
  • (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2)
  • (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1)
  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2)
  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4)
  • (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6)
  • (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2)
  • (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1)
  • (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1)
  • (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
  • (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1)
  • (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1)
  • (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded/store from/to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalable type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vector...
[truncated]

@matthias-springer
Copy link
Member Author

Should I add a builder for VectorType that accepts mlir::Type?

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 53a395fda32cb0edd899202b6614595185b01ef1 85c0b6be5c046b342987ff3523836bd87806e971 --extensions cpp,h -- mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h mlir/include/mlir/IR/BuiltinTypes.h mlir/lib/AsmParser/TypeParser.cpp mlir/lib/CAPI/IR/BuiltinTypes.cpp mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp mlir/lib/Dialect/ArmSME/IR/Utils.cpp mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp mlir/lib/Dialect/Quant/IR/QuantTypes.cpp mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp mlir/lib/Dialect/Traits.cpp mlir/lib/Dialect/Vector/IR/VectorOps.cpp mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp mlir/lib/IR/BuiltinTypes.cpp mlir/lib/Target/LLVMIR/ModuleImport.cpp mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp mlir/unittests/IR/ShapedTypeTest.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 54c540b28f..5b073a1dbc 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
   if (iface.isConvertibleInstruction(inst->getOpcode()))
     return iface.convertInstruction(odsBuilder, inst, llvmOperands,
                                     moduleImport);
-    // TODO: Implement the `convertInstruction` hooks in the
-    // `LLVMDialectLLVMIRImportInterface` and move the following include there.
+  // TODO: Implement the `convertInstruction` hooks in the
+  // `LLVMDialectLLVMIRImportInterface` and move the following include there.
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
   return failure();
 }

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for spending the time on the code to illustrate the proposal! We should wait for the RFC thread to converge first.

abstract types like function types or async tokens cannot be scalar types.

Scalar types should be limited to types that can lower to something that
egress dialects would consider a valid vector element type.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I agree wit this last paragraph: it seems "out of the blue" with respect to the rest of this description, and is not very precise (what are the egress dialect one has in mind here? If we think of LLVM we should just call this out).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would be worth calling out shaped types being explicitly disallowed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A scalar type is a type that people would generally consider a scalar type. Integers, floats, and pointers are examples of scalars, shaped types (like tensors, vectors, and arrays), tuples, and other complex structures are examples of non-scalar types.

(There's my mostly-joking submission to the definition question)

abstract types like function types or async tokens cannot be scalar types.

Scalar types should be limited to types that can lower to something that
egress dialects would consider a valid vector element type.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it would be worth calling out shaped types being explicitly disallowed?

@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
unwrap(elementType)));
cast<ScalarTypeInterface>(unwrap(elementType))));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought you added a builder that performs the cast, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this PR does not have it yet. I wanted to avoid the cast in the builder so that we can have better static type checking in some places. But now the diff became to large, so maybe it's better to add it.

Copy link
Member

@kuhar kuhar Mar 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think it's not worth to be that explicit -- any type issues will be caught in the callee anyway

@matthias-springer
Copy link
Member Author

Closed in favor of #133455.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants