@@ -113,15 +113,9 @@ static T pyTryCast(py::handle object) {
113
113
// / A python-wrapped dense array attribute with an element type and a derived
114
114
// / implementation class.
115
115
template <typename EltTy, typename DerivedT>
116
- class PyDenseArrayAttribute
117
- : public PyConcreteAttribute<PyDenseArrayAttribute<EltTy, DerivedT>> {
116
+ class PyDenseArrayAttribute : public PyConcreteAttribute <DerivedT> {
118
117
public:
119
- static constexpr typename PyConcreteAttribute<
120
- PyDenseArrayAttribute<EltTy, DerivedT>>::IsAFunctionTy isaFunction =
121
- DerivedT::isaFunction;
122
- static constexpr const char *pyClassName = DerivedT::pyClassName;
123
- using PyConcreteAttribute<
124
- PyDenseArrayAttribute<EltTy, DerivedT>>::PyConcreteAttribute;
118
+ using PyConcreteAttribute<DerivedT>::PyConcreteAttribute;
125
119
126
120
// / Iterator over the integer elements of a dense array.
127
121
class PyDenseArrayIterator {
@@ -158,33 +152,29 @@ class PyDenseArrayAttribute
158
152
EltTy getItem (intptr_t i) { return DerivedT::getElement (*this , i); }
159
153
160
154
// / Bind the attribute class.
161
- static void bindDerived (typename PyConcreteAttribute<
162
- PyDenseArrayAttribute<EltTy, DerivedT>>::ClassTy &c) {
155
+ static void bindDerived (typename PyConcreteAttribute<DerivedT>::ClassTy &c) {
163
156
// Bind the constructor.
164
157
c.def_static (
165
158
" get" ,
166
159
[](const std::vector<EltTy> &values, DefaultingPyMlirContext ctx) {
167
160
MlirAttribute attr =
168
161
DerivedT::getAttribute (ctx->get (), values.size (), values.data ());
169
- return PyDenseArrayAttribute<EltTy, DerivedT> (ctx->getRef (), attr);
162
+ return DerivedT (ctx->getRef (), attr);
170
163
},
171
164
py::arg (" values" ), py::arg (" context" ) = py::none (),
172
165
" Gets a uniqued dense array attribute" );
173
166
// Bind the array methods.
174
- c.def (" __getitem__" ,
175
- [](PyDenseArrayAttribute<EltTy, DerivedT> &arr, intptr_t i) {
176
- if (i >= mlirDenseArrayGetNumElements (arr))
177
- throw py::index_error (" DenseArray index out of range" );
178
- return arr.getItem (i);
179
- });
180
- c.def (" __len__" , [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) {
181
- return mlirDenseArrayGetNumElements (arr);
167
+ c.def (" __getitem__" , [](DerivedT &arr, intptr_t i) {
168
+ if (i >= mlirDenseArrayGetNumElements (arr))
169
+ throw py::index_error (" DenseArray index out of range" );
170
+ return arr.getItem (i);
182
171
});
183
- c.def (" __iter__ " , [](const PyDenseArrayAttribute<EltTy, DerivedT> &arr) {
184
- return PyDenseArrayIterator (arr);
172
+ c.def (" __len__ " , [](const DerivedT &arr) {
173
+ return mlirDenseArrayGetNumElements (arr);
185
174
});
186
- c.def (" __add__" , [](PyDenseArrayAttribute<EltTy, DerivedT> &arr,
187
- py::list extras) {
175
+ c.def (" __iter__" ,
176
+ [](const DerivedT &arr) { return PyDenseArrayIterator (arr); });
177
+ c.def (" __add__" , [](DerivedT &arr, py::list extras) {
188
178
std::vector<EltTy> values;
189
179
intptr_t numOldElements = mlirDenseArrayGetNumElements (arr);
190
180
values.reserve (numOldElements + py::len (extras));
@@ -194,7 +184,7 @@ class PyDenseArrayAttribute
194
184
values.push_back (pyTryCast<EltTy>(attr));
195
185
MlirAttribute attr = DerivedT::getAttribute (arr.getContext ()->get (),
196
186
values.size (), values.data ());
197
- return PyDenseArrayAttribute<EltTy, DerivedT> (arr.getContext (), attr);
187
+ return DerivedT (arr.getContext (), attr);
198
188
});
199
189
}
200
190
};
0 commit comments