Skip to content

Commit ae646a2

Browse files
committed
[irdl] Add support for attribute wrappers
1 parent 8825fd5 commit ae646a2

18 files changed

+628
-10
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
//===- AttributeWrapper.h - IRDL type wrapper definition --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares wrappers around attribute and type definitions to
10+
// manipulate them in a unified way, using their names and a list of parameters
11+
// encoded as a list of attributes. These wrappers are necessary for IRDL, since
12+
// attributes and types don't have names, nor a way to interact with them in a
13+
// generic way.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef MLIR_DIALECT_IRDL_ATTRIBUTEWRAPPER_H_
18+
#define MLIR_DIALECT_IRDL_ATTRIBUTEWRAPPER_H_
19+
20+
#include "mlir/IR/BuiltinAttributes.h"
21+
#include "mlir/IR/Dialect.h"
22+
#include "mlir/IR/ExtensibleDialect.h"
23+
#include "mlir/IR/OpDefinition.h"
24+
#include "llvm/ADT/SmallString.h"
25+
26+
namespace mlir {
27+
namespace irdl {
28+
29+
/// A wrapper around an attribute definition to manipulate its attribute
30+
/// instances in a unified way, using a name and a list of parameters encoded as
31+
/// attributes.
32+
///
33+
/// If the attribute is defined in C++, CppAttributeWrapper should be preferred.
34+
class AttributeWrapper {
35+
public:
36+
virtual ~AttributeWrapper(){};
37+
38+
/// Return the unique identifier of the attribute definition.
39+
virtual TypeID getTypeID() = 0;
40+
41+
/// Check if the given attribute is an instance of the one wrapped.
42+
virtual bool isCorrectAttribute(Attribute attr) = 0;
43+
44+
/// Get the parameters of an attribute. The attribute is expected to be an
45+
/// instance of the wrapped attribute, which is checked with
46+
/// `isCorrectAttribute`.
47+
virtual SmallVector<Attribute> getAttributeParameters(Attribute attr) = 0;
48+
49+
/// Return the attribute definition name, including the dialect prefix.
50+
virtual StringRef getName() = 0;
51+
52+
/// Instantiate the attribute from parameters.
53+
/// It is expected that the amount of parameters is correct, which is checked
54+
/// with `getParameterAmount`.
55+
virtual Attribute
56+
instantiateAttribute(function_ref<InFlightDiagnostic()> emitError,
57+
ArrayRef<Attribute> parameters) = 0;
58+
59+
/// Return the amount of parameters the attribute expects.
60+
virtual size_t getParameterAmount() = 0;
61+
};
62+
63+
/// A wrapper around a type definition to manipulate its type instances
64+
/// in an unified way, using a name and a list of parameters encoded as
65+
/// attributes. The wrappers also acts as an attribute wrapper, and expects
66+
/// the type to be nested in a `TypeAttr`.
67+
///
68+
/// If the type is defined as a C++ class, CppTypeWrapper should be preferred.
69+
class TypeWrapper : public AttributeWrapper {
70+
public:
71+
virtual ~TypeWrapper(){};
72+
73+
/// Check if the given type is an instance of the one wrapped.
74+
virtual bool isCorrectType(Type t) = 0;
75+
76+
/// Get the parameters of a type. The type is expected to be an instance of
77+
/// the wrapped type, which is checked with `isCorrectType`.
78+
virtual SmallVector<Attribute> getTypeParameters(Type t) = 0;
79+
80+
/// Instantiate the type from parameters.
81+
/// It is expected that the amount of parameters is correct, which is checked
82+
/// with `getParameterAmount`.
83+
virtual Type instantiateType(function_ref<InFlightDiagnostic()> emitError,
84+
ArrayRef<Attribute> parameters) = 0;
85+
86+
/// Check if the given attribute is a `TypeAttr`, and that it contains an
87+
/// instance of the type wrapped.
88+
bool isCorrectAttribute(Attribute attr) override;
89+
90+
/// Get the parameters of a type. The type is expected to be nested in a
91+
/// `TypeAttr`, and to be an instance of the wrapped type, which is checked
92+
/// with `isCorrectType`.
93+
SmallVector<Attribute> getAttributeParameters(Attribute attr) override;
94+
95+
/// Instantiate the type from parameters, and wrap it in a `TypeAttr`.
96+
/// It is expected that the amount of parameters is correct, which is checked
97+
/// with `getParameterAmount`.
98+
Attribute instantiateAttribute(function_ref<InFlightDiagnostic()> emitError,
99+
ArrayRef<Attribute> parameters) override;
100+
};
101+
102+
using AttributeWrapperPtr = AttributeWrapper *;
103+
using TypeWrapperPtr = TypeWrapper *;
104+
105+
/// A wrapper around an attribute definition defined with a C++ class to
106+
/// manipulate its attribute instances in a unified way, using a name and a list
107+
/// of parameters encoded as attributes.
108+
template <typename A>
109+
class CppAttributeWrapper : public AttributeWrapper {
110+
public:
111+
TypeID getTypeID() override { return TypeID::get<A>(); }
112+
113+
/// Get the parameters of an attribute.
114+
virtual SmallVector<Attribute> getAttributeParameters(A attr) = 0;
115+
116+
SmallVector<Attribute> getAttributeParameters(Attribute attr) override {
117+
return getAttributeParameters(cast<A>(attr));
118+
};
119+
120+
bool isCorrectAttribute(Attribute attr) override { return isa<A>(attr); }
121+
};
122+
123+
/// A wrapper around a type definition defined with a C++ class to
124+
/// manipulate its type instances in a unified way, using a name and a list
125+
/// of parameters encoded as attributes.
126+
template <typename T>
127+
class CppTypeWrapper : public TypeWrapper {
128+
public:
129+
TypeID getTypeID() override { return TypeID::get<T>(); }
130+
131+
/// Get the parameters of a type.
132+
virtual SmallVector<Attribute> getTypeParameters(T t) = 0;
133+
134+
SmallVector<Attribute> getTypeParameters(Type type) override {
135+
return getTypeParameters(cast<T>(type));
136+
};
137+
138+
bool isCorrectType(Type type) override { return isa<T>(type); }
139+
};
140+
141+
} // namespace irdl
142+
} // namespace mlir
143+
144+
#endif // MLIR_DIALECT_IRDL_ATTRIBUTEWRAPPER_H_

mlir/include/mlir/Dialect/IRDL/IR/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@ mlir_tablegen(IRDLTypesGen.h.inc -gen-typedef-decls)
2020
mlir_tablegen(IRDLTypesGen.cpp.inc -gen-typedef-defs)
2121
add_public_tablegen_target(MLIRIRDLTypesIncGen)
2222
add_dependencies(mlir-generic-headers MLIRIRDLTypesIncGen)
23+
24+
# Add IRDL attributes
25+
set(LLVM_TARGET_DEFINITIONS IRDLAttributes.td)
26+
mlir_tablegen(IRDLAttributes.h.inc -gen-attrdef-decls)
27+
mlir_tablegen(IRDLAttributes.cpp.inc -gen-attrdef-defs)
28+
add_public_tablegen_target(MLIRIRDLAttributesIncGen)
29+
add_dependencies(mlir-generic-headers MLIRIRDLAttributesIncGen)

mlir/include/mlir/Dialect/IRDL/IR/IRDL.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
#ifndef MLIR_DIALECT_IRDL_IR_IRDL_H_
1414
#define MLIR_DIALECT_IRDL_IR_IRDL_H_
1515

16+
#include "mlir/Dialect/IRDL/AttributeWrapper.h"
1617
#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
1718
#include "mlir/Dialect/IRDL/IR/IRDLTraits.h"
19+
#include "mlir/Dialect/IRDL/IRDLContext.h"
1820
#include "mlir/IR/SymbolTable.h"
1921
#include "mlir/Interfaces/InferTypeOpInterface.h"
2022
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -37,6 +39,9 @@ class OpDefAttr;
3739
#define GET_TYPEDEF_CLASSES
3840
#include "mlir/Dialect/IRDL/IR/IRDLTypesGen.h.inc"
3941

42+
#define GET_ATTRDEF_CLASSES
43+
#include "mlir/Dialect/IRDL/IR/IRDLAttributes.h.inc"
44+
4045
#define GET_OP_CLASSES
4146
#include "mlir/Dialect/IRDL/IR/IRDLOps.h.inc"
4247

mlir/include/mlir/Dialect/IRDL/IR/IRDL.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ def IRDL_Dialect : Dialect {
7070
}];
7171

7272
let useDefaultTypePrinterParser = 1;
73+
let useDefaultAttributePrinterParser = 1;
74+
75+
let extraClassDeclaration = [{
76+
public:
77+
/// Contains registered attribute and type wrappers.
78+
::mlir::irdl::IRDLContext irdlContext;
79+
}];
7380

7481
let name = "irdl";
7582
let cppNamespace = "::mlir::irdl";
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
//===- IRDLAttributes.td - IR Definition Language Dialect --*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the attributes used in IRDL.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
14+
#define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
15+
16+
include "IRDL.td"
17+
include "mlir/IR/AttrTypeBase.td"
18+
19+
class IRDL_AttrDef<string name, string attrMnemonic, list<Trait> traits = []>
20+
: AttrDef<IRDL_Dialect, name, traits, "::mlir::Attribute"> {
21+
let mnemonic = attrMnemonic;
22+
}
23+
24+
def TypeOrAttrDefRefAttr : IRDL_AttrDef<"TypeOrAttrDefRef",
25+
"type_or_attr_def_ref"> {
26+
let summary = "reference to a type or attribute definition";
27+
let description = [{
28+
An `irdl.type_or_attr_def_ref` references a type or attribute definition.
29+
A type or attribute definition can either be a `TypeWrapper *` or
30+
an `AttrWrapper *`, which refers to a C++-defined type or attribute,
31+
or a `SymbolRefAttr` which refers to a type or attribute defined with IRDL.
32+
}];
33+
let parameters =
34+
(ins OptionalParameter<"mlir::irdl::TypeWrapperPtr">:$typeWrapper,
35+
OptionalParameter<"mlir::irdl::AttributeWrapperPtr">:$attrWrapper,
36+
OptionalParameter<"mlir::SymbolRefAttr">:$symRef);
37+
38+
let builders = [
39+
AttrBuilder<(ins "mlir::irdl::TypeWrapper*":$typeWrapper), [{
40+
return $_get($_ctxt, typeWrapper, nullptr, mlir::SymbolRefAttr());
41+
}]>,
42+
AttrBuilder<(ins "mlir::irdl::AttributeWrapper*":$attrWrapper), [{
43+
return $_get($_ctxt, nullptr, attrWrapper, mlir::SymbolRefAttr());
44+
}]>,
45+
AttrBuilderWithInferredContext<(ins "mlir::SymbolRefAttr":$symRef), [{
46+
return $_get(symRef.getContext(), nullptr, nullptr, symRef);
47+
}]>
48+
];
49+
50+
let genVerifyDecl = 1;
51+
let hasCustomAssemblyFormat = 1;
52+
}
53+
54+
#endif // MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES

mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_IRDL_IR_IRDLOPS
1515

1616
include "IRDL.td"
17+
include "IRDLAttributes.td"
1718
include "IRDLTypes.td"
1819
include "IRDLInterfaces.td"
1920
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -302,8 +303,8 @@ def IRDL_Parametric : IRDL_ConstraintOp<"parametric",
302303
let description = [{
303304
`irdl.parametric` defines a constraint that accepts only a single type
304305
or attribute base. The attribute base is defined by a symbolic reference
305-
to the corresponding definition. It will additionally constraint the
306-
parameters of the type/attribute.
306+
to the corresponding definition, or a type/attribute wrapper. It will
307+
additionally constraint the parameters of the type/attribute.
307308

308309
Example:
309310

@@ -325,7 +326,7 @@ def IRDL_Parametric : IRDL_ConstraintOp<"parametric",
325326
for any `T` takes a `cmath.complex` with parameter `T` and returns a `T`.
326327
}];
327328

328-
let arguments = (ins SymbolRefAttr:$base_type,
329+
let arguments = (ins TypeOrAttrDefRefAttr:$base_type,
329330
Variadic<IRDL_AttributeType>:$args);
330331
let results = (outs IRDL_AttributeType:$output);
331332
let assemblyFormat = " $base_type `<` $args `>` ` ` attr-dict ";
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
//===- IRDLContext.h - IRDL context -----------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Manages the registration context of IRDL dialects.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_IRDL_IRDLCONTEXT_H_
14+
#define MLIR_DIALECT_IRDL_IRDLCONTEXT_H_
15+
16+
#include "mlir/Dialect/IRDL/AttributeWrapper.h"
17+
#include "mlir/Support/LLVM.h"
18+
#include "llvm/ADT/StringMap.h"
19+
20+
namespace mlir {
21+
namespace irdl {
22+
23+
/// Context for the runtime registration of IRDL dialect definitions.
24+
/// This class keeps track of all the attribute and types defined in C++
25+
/// that can be used in IRDL.
26+
class IRDLContext {
27+
llvm::StringMap<std::unique_ptr<AttributeWrapper>> attributes;
28+
llvm::StringMap<std::unique_ptr<TypeWrapper>> types;
29+
DenseMap<TypeID, AttributeWrapper *> typeIDToAttributeWrapper;
30+
DenseMap<TypeID, TypeWrapper *> typeIDToTypeWrapper;
31+
32+
public:
33+
IRDLContext();
34+
35+
/// Add a concrete attribute wrapper to IRDL.
36+
/// The attribute definition wrapped can then be used in IRDL with its name.
37+
template <typename A>
38+
void addAttributeWrapper() {
39+
addAttributeWrapper(std::make_unique<A>());
40+
}
41+
42+
/// Add an attribute wrapper to IRDL.
43+
/// The attribute definition wrapped can then be used in IRDL with its name.
44+
void addAttributeWrapper(std::unique_ptr<AttributeWrapper> wrapper);
45+
46+
AttributeWrapper *getAttributeWrapper(StringRef typeName);
47+
AttributeWrapper *getAttributeWrapper(TypeID typeID);
48+
49+
/// Add a concrete type wrapper to IRDL.
50+
/// The type definition wrapped can then be used in IRDL with its name.
51+
template <typename T>
52+
void addTypeWrapper() {
53+
addTypeWrapper(std::make_unique<T>());
54+
}
55+
56+
/// Add a type wrapper to IRDL.
57+
/// The type definition wrapped can then be used in IRDL with its name.
58+
void addTypeWrapper(std::unique_ptr<TypeWrapper> wrapper);
59+
60+
TypeWrapper *getTypeWrapper(StringRef typeName);
61+
TypeWrapper *getTypeWrapper(TypeID typeID);
62+
63+
llvm::StringMap<std::unique_ptr<AttributeWrapper>> const &getAllAttributes() {
64+
return attributes;
65+
}
66+
67+
llvm::StringMap<std::unique_ptr<TypeWrapper>> const &getAllTypes() {
68+
return types;
69+
}
70+
};
71+
72+
} // namespace irdl
73+
} // namespace mlir
74+
75+
#endif // MLIR_DIALECT_IRDL_IRDLCONTEXT_H_

mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
1414
#define MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
1515

16+
#include "mlir/Dialect/IRDL/AttributeWrapper.h"
1617
#include "mlir/IR/Attributes.h"
1718
#include "mlir/Support/LLVM.h"
1819
#include "llvm/ADT/ArrayRef.h"
@@ -97,6 +98,28 @@ class IsConstraint : public Constraint {
9798
Attribute expectedAttribute;
9899
};
99100

101+
/// A constraint that checks that an attribute that has an attribute wrapper is
102+
/// of a specific attribute definition, and that all of its parameters satisfy
103+
/// the given constraints.
104+
/// This class also takes care of types, by expecting a `TypeAttr` attribute.
105+
class ParametricConstraint : public Constraint {
106+
public:
107+
ParametricConstraint(::mlir::irdl::AttributeWrapper *expectedAttribute,
108+
SmallVector<unsigned> constraints)
109+
: expectedAttribute(expectedAttribute),
110+
constraints(std::move(constraints)) {}
111+
112+
virtual ~ParametricConstraint() = default;
113+
114+
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
115+
Attribute attr,
116+
ConstraintVerifier &context) const override;
117+
118+
private:
119+
::mlir::irdl::AttributeWrapper *expectedAttribute;
120+
SmallVector<unsigned> constraints;
121+
};
122+
100123
/// A constraint that checks that an attribute is of a
101124
/// specific dynamic attribute definition, and that all of its parameters
102125
/// satisfy the given constraints.

0 commit comments

Comments
 (0)