Skip to content

Commit 62d32c2

Browse files
[mlir][memref][NFC] Simplify constifyIndexValues (#135940)
Simplify the code by removing function pointers.
1 parent 1906c18 commit 62d32c2

File tree

1 file changed

+48
-101
lines changed

1 file changed

+48
-101
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

+48-101
Original file line numberDiff line numberDiff line change
@@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
8888
// Utility functions for propagating static information
8989
//===----------------------------------------------------------------------===//
9090

91-
/// Helper function that infers the constant values from a list of \p values,
92-
/// a \p memRefTy, and another helper function \p getAttributes.
93-
/// The inferred constant values replace the related `OpFoldResult` in
94-
/// \p values.
91+
/// Helper function that sets values[i] to constValues[i] if the latter is a
92+
/// static value, as indicated by ShapedType::kDynamic.
9593
///
96-
/// \note This function shouldn't be used directly, instead, use the
97-
/// `getConstifiedMixedXXX` methods from the related operations.
98-
///
99-
/// \p getAttributes retuns a list of potentially constant values, as determined
100-
/// by \p isDynamic, from the given \p memRefTy. The returned list must have as
101-
/// many elements as \p values or be empty.
102-
///
103-
/// E.g., consider the following example:
104-
/// ```
105-
/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] :
106-
/// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
107-
/// ```
108-
/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`.
109-
/// Now using this helper function with:
110-
/// - `values == [2, %dyn_stride]`,
111-
/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>`
112-
/// - `getAttributes == getConstantStrides` (i.e., a wrapper around
113-
/// `getStridesAndOffset`), and
114-
/// - `isDynamic == ShapedType::isDynamic`
115-
/// Will yield: `values == [2, 1]`
116-
static void constifyIndexValues(
117-
SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy,
118-
MLIRContext *ctxt,
119-
llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes,
120-
llvm::function_ref<bool(int64_t)> isDynamic) {
121-
SmallVector<int64_t> constValues = getAttributes(memRefTy);
122-
Builder builder(ctxt);
123-
for (const auto &it : llvm::enumerate(constValues)) {
124-
int64_t constValue = it.value();
125-
if (!isDynamic(constValue))
126-
values[it.index()] = builder.getIndexAttr(constValue);
127-
}
128-
for (OpFoldResult &ofr : values) {
129-
if (auto attr = dyn_cast<Attribute>(ofr)) {
130-
// FIXME: We shouldn't need to do that, but right now, the static indices
131-
// are created with the wrong type: `i64` instead of `index`.
132-
// As a result, if we were to keep the attribute as is, we may fail to see
133-
// that two attributes are equal because one would have the i64 type and
134-
// the other the index type.
135-
// The alternative would be to create constant indices with getI64Attr in
136-
// this and the previous loop, but it doesn't logically make sense (we are
137-
// dealing with indices here) and would only strenghten the inconsistency
138-
// around how static indices are created (some places use getI64Attr,
139-
// others use getIndexAttr).
140-
// The workaround here is to stick to the IndexAttr type for all the
141-
// values, hence we recreate the attribute even when it is already static
142-
// to make sure the type is consistent.
143-
ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt());
94+
/// If constValues[i] is dynamic, tries to extract a constant value from
95+
/// value[i] to allow for additional folding opportunities. Also convertes all
96+
/// existing attributes to index attributes. (They may be i64 attributes.)
97+
static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
98+
ArrayRef<int64_t> constValues) {
99+
assert(constValues.size() == values.size() &&
100+
"incorrect number of const values");
101+
for (auto [i, cstVal] : llvm::enumerate(constValues)) {
102+
Builder builder(values[i].getContext());
103+
if (!ShapedType::isDynamic(cstVal)) {
104+
// Constant value is known, use it directly.
105+
values[i] = builder.getIndexAttr(cstVal);
144106
continue;
145107
}
146-
std::optional<int64_t> maybeConstant =
147-
getConstantIntValue(cast<Value>(ofr));
148-
if (maybeConstant)
149-
ofr = builder.getIndexAttr(*maybeConstant);
108+
if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
109+
// Try to extract a constant or convert an existing to index.
110+
values[i] = builder.getIndexAttr(*cst);
111+
}
150112
}
151113
}
152114

153-
/// Wrapper around `getShape` that conforms to the function signature
154-
/// expected for `getAttributes` in `constifyIndexValues`.
155-
static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) {
156-
ArrayRef<int64_t> sizes = memRefTy.getShape();
157-
return SmallVector<int64_t>(sizes);
158-
}
159-
160-
/// Wrapper around `getStridesAndOffset` that returns only the offset and
161-
/// conforms to the function signature expected for `getAttributes` in
162-
/// `constifyIndexValues`.
163-
static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) {
164-
SmallVector<int64_t> strides;
165-
int64_t offset;
166-
LogicalResult hasStaticInformation =
167-
memrefType.getStridesAndOffset(strides, offset);
168-
if (failed(hasStaticInformation))
169-
return SmallVector<int64_t>();
170-
return SmallVector<int64_t>(1, offset);
171-
}
172-
173-
/// Wrapper around `getStridesAndOffset` that returns only the strides and
174-
/// conforms to the function signature expected for `getAttributes` in
175-
/// `constifyIndexValues`.
176-
static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) {
177-
SmallVector<int64_t> strides;
178-
int64_t offset;
179-
LogicalResult hasStaticInformation =
180-
memrefType.getStridesAndOffset(strides, offset);
181-
if (failed(hasStaticInformation))
182-
return SmallVector<int64_t>();
183-
return strides;
184-
}
185-
186115
//===----------------------------------------------------------------------===//
187116
// AllocOp / AllocaOp
188117
//===----------------------------------------------------------------------===//
@@ -1445,24 +1374,34 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
14451374

14461375
SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
14471376
SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1448-
constifyIndexValues(values, getSource().getType(), getContext(),
1449-
getConstantSizes, ShapedType::isDynamic);
1377+
constifyIndexValues(values, getSource().getType().getShape());
14501378
return values;
14511379
}
14521380

14531381
SmallVector<OpFoldResult>
14541382
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
14551383
SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1456-
constifyIndexValues(values, getSource().getType(), getContext(),
1457-
getConstantStrides, ShapedType::isDynamic);
1384+
SmallVector<int64_t> staticValues;
1385+
int64_t unused;
1386+
LogicalResult status =
1387+
getSource().getType().getStridesAndOffset(staticValues, unused);
1388+
(void)status;
1389+
assert(succeeded(status) && "could not get strides from type");
1390+
constifyIndexValues(values, staticValues);
14581391
return values;
14591392
}
14601393

14611394
OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
14621395
OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
14631396
SmallVector<OpFoldResult> values(1, offsetOfr);
1464-
constifyIndexValues(values, getSource().getType(), getContext(),
1465-
getConstantOffset, ShapedType::isDynamic);
1397+
SmallVector<int64_t> staticValues, unused;
1398+
int64_t offset;
1399+
LogicalResult status =
1400+
getSource().getType().getStridesAndOffset(unused, offset);
1401+
(void)status;
1402+
assert(succeeded(status) && "could not get offset from type");
1403+
staticValues.push_back(offset);
1404+
constifyIndexValues(values, staticValues);
14661405
return values[0];
14671406
}
14681407

@@ -1975,24 +1914,32 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
19751914

19761915
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
19771916
SmallVector<OpFoldResult> values = getMixedSizes();
1978-
constifyIndexValues(values, getType(), getContext(), getConstantSizes,
1979-
ShapedType::isDynamic);
1917+
constifyIndexValues(values, getType().getShape());
19801918
return values;
19811919
}
19821920

19831921
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
19841922
SmallVector<OpFoldResult> values = getMixedStrides();
1985-
constifyIndexValues(values, getType(), getContext(), getConstantStrides,
1986-
ShapedType::isDynamic);
1923+
SmallVector<int64_t> staticValues;
1924+
int64_t unused;
1925+
LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
1926+
(void)status;
1927+
assert(succeeded(status) && "could not get strides from type");
1928+
constifyIndexValues(values, staticValues);
19871929
return values;
19881930
}
19891931

19901932
OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
19911933
SmallVector<OpFoldResult> values = getMixedOffsets();
19921934
assert(values.size() == 1 &&
19931935
"reinterpret_cast must have one and only one offset");
1994-
constifyIndexValues(values, getType(), getContext(), getConstantOffset,
1995-
ShapedType::isDynamic);
1936+
SmallVector<int64_t> staticValues, unused;
1937+
int64_t offset;
1938+
LogicalResult status = getType().getStridesAndOffset(unused, offset);
1939+
(void)status;
1940+
assert(succeeded(status) && "could not get offset from type");
1941+
staticValues.push_back(offset);
1942+
constifyIndexValues(values, staticValues);
19961943
return values[0];
19971944
}
19981945

0 commit comments

Comments
 (0)