@@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
88
88
// Utility functions for propagating static information
89
89
// ===----------------------------------------------------------------------===//
90
90
91
- // / Helper function that infers the constant values from a list of \p values,
92
- // / a \p memRefTy, and another helper function \p getAttributes.
93
- // / The inferred constant values replace the related `OpFoldResult` in
94
- // / \p values.
91
+ // / Helper function that sets values[i] to constValues[i] if the latter is a
92
+ // / static value, as indicated by ShapedType::kDynamic.
95
93
// /
96
- // / \note This function shouldn't be used directly, instead, use the
97
- // / `getConstifiedMixedXXX` methods from the related operations.
98
- // /
99
- // / \p getAttributes retuns a list of potentially constant values, as determined
100
- // / by \p isDynamic, from the given \p memRefTy. The returned list must have as
101
- // / many elements as \p values or be empty.
102
- // /
103
- // / E.g., consider the following example:
104
- // / ```
105
- // / memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
106
- // / memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
107
- // / ```
108
- // / `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
109
- // / Now using this helper function with:
110
- // / - `values == [2, %dyn_stride]`,
111
- // / - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
112
- // / - `getAttributes == getConstantStrides` (i.e., a wrapper around
113
- // / `getStridesAndOffset`), and
114
- // / - `isDynamic == ShapedType::isDynamic`
115
- // / Will yield: `values == [2, 1]`
116
- static void constifyIndexValues (
117
- SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
118
- MLIRContext *ctxt,
119
- llvm::function_ref<SmallVector<int64_t >(MemRefType)> getAttributes,
120
- llvm::function_ref<bool(int64_t )> isDynamic) {
121
- SmallVector<int64_t > constValues = getAttributes (memRefTy);
122
- Builder builder (ctxt);
123
- for (const auto &it : llvm::enumerate (constValues)) {
124
- int64_t constValue = it.value ();
125
- if (!isDynamic (constValue))
126
- values[it.index ()] = builder.getIndexAttr (constValue);
127
- }
128
- for (OpFoldResult &ofr : values) {
129
- if (auto attr = dyn_cast<Attribute>(ofr)) {
130
- // FIXME: We shouldn't need to do that, but right now, the static indices
131
- // are created with the wrong type: `i64` instead of `index`.
132
- // As a result, if we were to keep the attribute as is, we may fail to see
133
- // that two attributes are equal because one would have the i64 type and
134
- // the other the index type.
135
- // The alternative would be to create constant indices with getI64Attr in
136
- // this and the previous loop, but it doesn't logically make sense (we are
137
- // dealing with indices here) and would only strenghten the inconsistency
138
- // around how static indices are created (some places use getI64Attr,
139
- // others use getIndexAttr).
140
- // The workaround here is to stick to the IndexAttr type for all the
141
- // values, hence we recreate the attribute even when it is already static
142
- // to make sure the type is consistent.
143
- ofr = builder.getIndexAttr (llvm::cast<IntegerAttr>(attr).getInt ());
94
+ // / If constValues[i] is dynamic, tries to extract a constant value from
95
+ // / value[i] to allow for additional folding opportunities. Also convertes all
96
+ // / existing attributes to index attributes. (They may be i64 attributes.)
97
+ static void constifyIndexValues (SmallVectorImpl<OpFoldResult> &values,
98
+ ArrayRef<int64_t > constValues) {
99
+ assert (constValues.size () == values.size () &&
100
+ " incorrect number of const values" );
101
+ for (auto [i, cstVal] : llvm::enumerate (constValues)) {
102
+ Builder builder (values[i].getContext ());
103
+ if (!ShapedType::isDynamic (cstVal)) {
104
+ // Constant value is known, use it directly.
105
+ values[i] = builder.getIndexAttr (cstVal);
144
106
continue ;
145
107
}
146
- std::optional<int64_t > maybeConstant =
147
- getConstantIntValue (cast<Value>(ofr));
148
- if (maybeConstant)
149
- ofr = builder. getIndexAttr (*maybeConstant);
108
+ if ( std::optional<int64_t > cst = getConstantIntValue (values[i])) {
109
+ // Try to extract a constant or convert an existing to index.
110
+ values[i] = builder. getIndexAttr (*cst);
111
+ }
150
112
}
151
113
}
152
114
153
- // / Wrapper around `getShape` that conforms to the function signature
154
- // / expected for `getAttributes` in `constifyIndexValues`.
155
- static SmallVector<int64_t > getConstantSizes (MemRefType memRefTy) {
156
- ArrayRef<int64_t > sizes = memRefTy.getShape ();
157
- return SmallVector<int64_t >(sizes);
158
- }
159
-
160
- // / Wrapper around `getStridesAndOffset` that returns only the offset and
161
- // / conforms to the function signature expected for `getAttributes` in
162
- // / `constifyIndexValues`.
163
- static SmallVector<int64_t > getConstantOffset (MemRefType memrefType) {
164
- SmallVector<int64_t > strides;
165
- int64_t offset;
166
- LogicalResult hasStaticInformation =
167
- memrefType.getStridesAndOffset (strides, offset);
168
- if (failed (hasStaticInformation))
169
- return SmallVector<int64_t >();
170
- return SmallVector<int64_t >(1 , offset);
171
- }
172
-
173
- // / Wrapper around `getStridesAndOffset` that returns only the strides and
174
- // / conforms to the function signature expected for `getAttributes` in
175
- // / `constifyIndexValues`.
176
- static SmallVector<int64_t > getConstantStrides (MemRefType memrefType) {
177
- SmallVector<int64_t > strides;
178
- int64_t offset;
179
- LogicalResult hasStaticInformation =
180
- memrefType.getStridesAndOffset (strides, offset);
181
- if (failed (hasStaticInformation))
182
- return SmallVector<int64_t >();
183
- return strides;
184
- }
185
-
186
115
// ===----------------------------------------------------------------------===//
187
116
// AllocOp / AllocaOp
188
117
// ===----------------------------------------------------------------------===//
@@ -1445,24 +1374,34 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1445
1374
1446
1375
SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes () {
1447
1376
SmallVector<OpFoldResult> values = getAsOpFoldResult (getSizes ());
1448
- constifyIndexValues (values, getSource ().getType (), getContext (),
1449
- getConstantSizes, ShapedType::isDynamic);
1377
+ constifyIndexValues (values, getSource ().getType ().getShape ());
1450
1378
return values;
1451
1379
}
1452
1380
1453
1381
SmallVector<OpFoldResult>
1454
1382
ExtractStridedMetadataOp::getConstifiedMixedStrides () {
1455
1383
SmallVector<OpFoldResult> values = getAsOpFoldResult (getStrides ());
1456
- constifyIndexValues (values, getSource ().getType (), getContext (),
1457
- getConstantStrides, ShapedType::isDynamic);
1384
+ SmallVector<int64_t > staticValues;
1385
+ int64_t unused;
1386
+ LogicalResult status =
1387
+ getSource ().getType ().getStridesAndOffset (staticValues, unused);
1388
+ (void )status;
1389
+ assert (succeeded (status) && " could not get strides from type" );
1390
+ constifyIndexValues (values, staticValues);
1458
1391
return values;
1459
1392
}
1460
1393
1461
1394
OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset () {
1462
1395
OpFoldResult offsetOfr = getAsOpFoldResult (getOffset ());
1463
1396
SmallVector<OpFoldResult> values (1 , offsetOfr);
1464
- constifyIndexValues (values, getSource ().getType (), getContext (),
1465
- getConstantOffset, ShapedType::isDynamic);
1397
+ SmallVector<int64_t > staticValues, unused;
1398
+ int64_t offset;
1399
+ LogicalResult status =
1400
+ getSource ().getType ().getStridesAndOffset (unused, offset);
1401
+ (void )status;
1402
+ assert (succeeded (status) && " could not get offset from type" );
1403
+ staticValues.push_back (offset);
1404
+ constifyIndexValues (values, staticValues);
1466
1405
return values[0 ];
1467
1406
}
1468
1407
@@ -1975,24 +1914,32 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1975
1914
1976
1915
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes () {
1977
1916
SmallVector<OpFoldResult> values = getMixedSizes ();
1978
- constifyIndexValues (values, getType (), getContext (), getConstantSizes,
1979
- ShapedType::isDynamic);
1917
+ constifyIndexValues (values, getType ().getShape ());
1980
1918
return values;
1981
1919
}
1982
1920
1983
1921
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides () {
1984
1922
SmallVector<OpFoldResult> values = getMixedStrides ();
1985
- constifyIndexValues (values, getType (), getContext (), getConstantStrides,
1986
- ShapedType::isDynamic);
1923
+ SmallVector<int64_t > staticValues;
1924
+ int64_t unused;
1925
+ LogicalResult status = getType ().getStridesAndOffset (staticValues, unused);
1926
+ (void )status;
1927
+ assert (succeeded (status) && " could not get strides from type" );
1928
+ constifyIndexValues (values, staticValues);
1987
1929
return values;
1988
1930
}
1989
1931
1990
1932
OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset () {
1991
1933
SmallVector<OpFoldResult> values = getMixedOffsets ();
1992
1934
assert (values.size () == 1 &&
1993
1935
" reinterpret_cast must have one and only one offset" );
1994
- constifyIndexValues (values, getType (), getContext (), getConstantOffset,
1995
- ShapedType::isDynamic);
1936
+ SmallVector<int64_t > staticValues, unused;
1937
+ int64_t offset;
1938
+ LogicalResult status = getType ().getStridesAndOffset (unused, offset);
1939
+ (void )status;
1940
+ assert (succeeded (status) && " could not get offset from type" );
1941
+ staticValues.push_back (offset);
1942
+ constifyIndexValues (values, staticValues);
1996
1943
return values[0 ];
1997
1944
}
1998
1945
0 commit comments