Skip to content

Commit 5e58169

Browse files
committed
add IntegerLikeTypeInterface to enable out-of-tree uses of int attribute parsers
1 parent 0fdb908 commit 5e58169

8 files changed

+85
-31
lines changed

mlir/include/mlir/IR/BuiltinAttributes.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -548,7 +548,9 @@ class DenseElementsAttr : public Attribute {
548548
std::enable_if_t<std::is_same<T, APInt>::value>;
549549
template <typename T, typename = APIntValueTemplateCheckT<T>>
550550
FailureOr<iterator_range_impl<IntElementIterator>> tryGetValues() const {
551-
if (!getElementType().isIntOrIndex())
551+
auto intLikeType =
552+
llvm::dyn_cast<IntegerLikeTypeInterface>(getElementType());
553+
if (!intLikeType)
552554
return failure();
553555
return iterator_range_impl<IntElementIterator>(getType(), raw_int_begin(),
554556
raw_int_end());

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

+38
Original file line numberDiff line numberDiff line change
@@ -257,4 +257,42 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
257257
}];
258258
}
259259

260+
def IntegerLikeTypeInterface : TypeInterface<"IntegerLikeTypeInterface"> {
261+
let cppNamespace = "::mlir";
262+
let description = [{
263+
This type interface is for types that behave like integers. It provides
264+
the API that allows MLIR utilities to treat them the same was as MLIR
265+
treats integer types in settings like parsing and printing.
266+
}];
267+
268+
let methods = [
269+
InterfaceMethod<
270+
/*desc=*/[{
271+
Returns the storage bit width for this type.
272+
}],
273+
/*retTy=*/"unsigned",
274+
/*methodName=*/"getStorageBitWidth",
275+
/*args=*/(ins)
276+
>,
277+
InterfaceMethod<
278+
/*desc=*/[{
279+
Returns true if this type is signed.
280+
}],
281+
/*retTy=*/"bool",
282+
/*methodName=*/"isSigned",
283+
/*args=*/(ins),
284+
/*defaultImplementation=*/"return true;"
285+
>,
286+
InterfaceMethod<
287+
/*desc=*/[{
288+
Returns true if this type is signless.
289+
}],
290+
/*retTy=*/"bool",
291+
/*methodName=*/"isSignless",
292+
/*args=*/(ins),
293+
/*defaultImplementation=*/"return true;"
294+
>,
295+
];
296+
}
297+
260298
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_

mlir/include/mlir/IR/BuiltinTypes.td

+4-2
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,8 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
466466
//===----------------------------------------------------------------------===//
467467

468468
def Builtin_Index : Builtin_Type<"Index", "index",
469-
[VectorElementTypeInterface]> {
469+
[VectorElementTypeInterface,
470+
DeclareTypeInterfaceMethods<IntegerLikeTypeInterface, ["getStorageBitWidth"]>]> {
470471
let summary = "Integer-like type with unknown platform-dependent bit width";
471472
let description = [{
472473
Syntax:
@@ -497,7 +498,8 @@ def Builtin_Index : Builtin_Type<"Index", "index",
497498
//===----------------------------------------------------------------------===//
498499

499500
def Builtin_Integer : Builtin_Type<"Integer", "integer",
500-
[VectorElementTypeInterface]> {
501+
[VectorElementTypeInterface,
502+
DeclareTypeInterfaceMethods<IntegerLikeTypeInterface, ["getStorageBitWidth"]>]> {
501503
let summary = "Integer type with arbitrary precision up to a fixed limit";
502504
let description = [{
503505
Syntax:

mlir/lib/AsmParser/AttributeParser.cpp

+9-7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/AffineMap.h"
1717
#include "mlir/IR/BuiltinAttributes.h"
1818
#include "mlir/IR/BuiltinDialect.h"
19+
#include "mlir/IR/BuiltinTypeInterfaces.h"
1920
#include "mlir/IR/BuiltinTypes.h"
2021
#include "mlir/IR/DialectResourceBlobManager.h"
2122
#include "mlir/IR/IntegerSet.h"
@@ -366,8 +367,12 @@ static std::optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
366367
return std::nullopt;
367368

368369
// Extend or truncate the bitwidth to the right size.
369-
unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
370-
: type.getIntOrFloatBitWidth();
370+
unsigned width;
371+
if (auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(type)) {
372+
width = intLikeType.getStorageBitWidth();
373+
} else {
374+
width = type.getIntOrFloatBitWidth();
375+
}
371376

372377
if (width > result.getBitWidth()) {
373378
result = result.zext(width);
@@ -425,10 +430,6 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
425430
return FloatAttr::get(floatType, *result);
426431
}
427432

428-
if (!isa<IntegerType, IndexType>(type))
429-
return emitError(loc, "integer literal not valid for specified type"),
430-
nullptr;
431-
432433
if (isNegative && type.isUnsignedInteger()) {
433434
emitError(loc,
434435
"negative integer literal not valid for unsigned integer type");
@@ -584,7 +585,8 @@ DenseElementsAttr TensorLiteralParser::getAttr(SMLoc loc, ShapedType type) {
584585
}
585586

586587
// Handle integer and index types.
587-
if (eltType.isIntOrIndex()) {
588+
auto integerLikeType = dyn_cast<IntegerLikeTypeInterface>(eltType);
589+
if (integerLikeType || eltType.isIntOrIndex()) {
588590
std::vector<APInt> intValues;
589591
if (failed(getIntAttrElements(loc, eltType, intValues)))
590592
return nullptr;

mlir/lib/IR/AsmPrinter.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -2656,7 +2656,7 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
26562656
os << ")";
26572657
});
26582658
}
2659-
} else if (elementType.isIntOrIndex()) {
2659+
} else if (isa<IntegerLikeTypeInterface>(elementType)) {
26602660
auto valueIt = attr.value_begin<APInt>();
26612661
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
26622662
printDenseIntElement(*(valueIt + index), os, elementType);

mlir/lib/IR/AttributeDetail.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ inline size_t getDenseElementBitWidth(Type eltType) {
3737
// Align the width for complex to 8 to make storage and interpretation easier.
3838
if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType))
3939
return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
40-
if (eltType.isIndex())
41-
return IndexType::kInternalStorageBitWidth;
40+
if (auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(eltType))
41+
return intLikeType.getStorageBitWidth();
42+
4243
return eltType.getIntOrFloatBitWidth();
4344
}
4445

mlir/lib/IR/BuiltinAttributes.cpp

+17-18
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "AttributeDetail.h"
1111
#include "mlir/IR/AffineMap.h"
1212
#include "mlir/IR/BuiltinDialect.h"
13+
#include "mlir/IR/BuiltinTypeInterfaces.h"
1314
#include "mlir/IR/Dialect.h"
1415
#include "mlir/IR/DialectResourceBlobManager.h"
1516
#include "mlir/IR/IntegerSet.h"
@@ -379,22 +380,20 @@ APSInt IntegerAttr::getAPSInt() const {
379380

380381
LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
381382
Type type, APInt value) {
382-
if (IntegerType integerType = llvm::dyn_cast<IntegerType>(type)) {
383-
if (integerType.getWidth() != value.getBitWidth())
384-
return emitError() << "integer type bit width (" << integerType.getWidth()
385-
<< ") doesn't match value bit width ("
386-
<< value.getBitWidth() << ")";
387-
return success();
383+
unsigned width;
384+
if (auto intLikeType = dyn_cast<IntegerLikeTypeInterface>(type)) {
385+
width = intLikeType.getStorageBitWidth();
386+
} else {
387+
return emitError() << "expected integer-like type";
388388
}
389-
if (llvm::isa<IndexType>(type)) {
390-
if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
391-
return emitError()
392-
<< "value bit width (" << value.getBitWidth()
393-
<< ") doesn't match index type internal storage bit width ("
394-
<< IndexType::kInternalStorageBitWidth << ")";
395-
return success();
389+
390+
if (width != value.getBitWidth()) {
391+
return emitError() << "integer-like type bit width (" << width
392+
<< ") doesn't match value bit width ("
393+
<< value.getBitWidth() << ")";
396394
}
397-
return emitError() << "expected integer or index type";
395+
396+
return success();
398397
}
399398

400399
BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
@@ -1019,7 +1018,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
10191018
/// element type of 'type'.
10201019
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
10211020
ArrayRef<APInt> values) {
1022-
assert(type.getElementType().isIntOrIndex());
1021+
assert(isa<IntegerLikeTypeInterface>(type.getElementType()));
10231022
assert(hasSameNumElementsOrSplat(type, values));
10241023
size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
10251024
return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
@@ -1130,11 +1129,11 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
11301129
if (type.isIndex())
11311130
return true;
11321131

1133-
auto intType = llvm::dyn_cast<IntegerType>(type);
1132+
auto intType = llvm::dyn_cast<IntegerLikeTypeInterface>(type);
11341133
if (!intType) {
11351134
LLVM_DEBUG(llvm::dbgs()
1136-
<< "expected integer type when isInt is true, but found " << type
1137-
<< "\n");
1135+
<< "expected integer-like type when isInt is true, but found "
1136+
<< type << "\n");
11381137
return false;
11391138
}
11401139

mlir/lib/IR/BuiltinTypes.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
5959
return success();
6060
}
6161

62+
//===----------------------------------------------------------------------===//
63+
// Index Type
64+
//===----------------------------------------------------------------------===//
65+
66+
unsigned IndexType::getStorageBitWidth() const {
67+
return kInternalStorageBitWidth;
68+
}
69+
6270
//===----------------------------------------------------------------------===//
6371
// Integer Type
6472
//===----------------------------------------------------------------------===//
@@ -86,6 +94,8 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
8694
return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
8795
}
8896

97+
unsigned IntegerType::getStorageBitWidth() const { return getWidth(); }
98+
8999
//===----------------------------------------------------------------------===//
90100
// Float Types
91101
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)