Skip to content

Commit 2e21dff

Browse files
MacDuecjdb
authored andcommitted
[memref] Handle edge case in subview of full static size fold (llvm#105635)
It is possible to have a subview with a fully static size and a type that matches the source type, but a dynamic offset that may be different. However, currently the memref dialect folds: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { %0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>> return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` To: ```mlir func.func @subview_of_static_full_size( %arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %arg1: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> { return %arg0 : memref<16x4xf32, strided<[4, 1], offset: ?>> } ``` Which drops the dynamic offset from the `subview` op.
1 parent e4401d0 commit 2e21dff

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

mlir/include/mlir/IR/BuiltinAttributes.td

+4
Original file line numberDiff line numberDiff line change
@@ -1012,6 +1012,10 @@ def StridedLayoutAttr : Builtin_Attr<"StridedLayout", "strided_layout",
10121012
let extraClassDeclaration = [{
10131013
/// Print the attribute to the given output stream.
10141014
void print(raw_ostream &os) const;
1015+
1016+
/// Returns true if this layout is static, i.e. the strides and offset all
1017+
/// have a known value > 0.
1018+
bool hasStaticLayout() const;
10151019
}];
10161020
}
10171021

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -3279,11 +3279,14 @@ void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
32793279
}
32803280

32813281
OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3282-
auto resultShapedType = llvm::cast<ShapedType>(getResult().getType());
3283-
auto sourceShapedType = llvm::cast<ShapedType>(getSource().getType());
3284-
3285-
if (resultShapedType.hasStaticShape() &&
3286-
resultShapedType == sourceShapedType) {
3282+
MemRefType sourceMemrefType = getSource().getType();
3283+
MemRefType resultMemrefType = getResult().getType();
3284+
auto resultLayout =
3285+
dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3286+
3287+
if (resultMemrefType == sourceMemrefType &&
3288+
resultMemrefType.hasStaticShape() &&
3289+
(!resultLayout || resultLayout.hasStaticLayout())) {
32873290
return getViewSource();
32883291
}
32893292

@@ -3301,7 +3304,7 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
33013304
strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
33023305
bool allSizesSame = llvm::equal(sizes, srcSizes);
33033306
if (allOffsetsZero && allStridesOne && allSizesSame &&
3304-
resultShapedType == sourceShapedType)
3307+
resultMemrefType == sourceMemrefType)
33053308
return getViewSource();
33063309
}
33073310

mlir/lib/IR/BuiltinAttributes.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,13 @@ void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
229229
os << ">";
230230
}
231231

232+
/// Returns true if this layout is static, i.e. the strides and offset all have
233+
/// a known value > 0.
234+
bool StridedLayoutAttr::hasStaticLayout() const {
235+
return !ShapedType::isDynamic(getOffset()) &&
236+
!ShapedType::isDynamicShape(getStrides());
237+
}
238+
232239
/// Returns the strided layout as an affine map.
233240
AffineMap StridedLayoutAttr::getAffineMap() const {
234241
return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());

mlir/test/Dialect/MemRef/canonicalize.mlir

+13
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,19 @@ func.func @subview_of_static_full_size(%arg0 : memref<4x6x16x32xi8>) -> memref<4
7070

7171
// -----
7272

73+
// CHECK-LABEL: func @negative_subview_of_static_full_size
74+
// CHECK-SAME: %[[ARG0:.+]]: memref<16x4xf32, strided<[4, 1], offset: ?>>
75+
// CHECK-SAME: %[[IDX:.+]]: index
76+
// CHECK: %[[S:.+]] = memref.subview %[[ARG0]][%[[IDX]], 0] [16, 4] [1, 1]
77+
// CHECK-SAME: to memref<16x4xf32, strided<[4, 1], offset: ?>>
78+
// CHECK: return %[[S]] : memref<16x4xf32, strided<[4, 1], offset: ?>>
79+
func.func @negative_subview_of_static_full_size(%arg0: memref<16x4xf32, strided<[4, 1], offset: ?>>, %idx: index) -> memref<16x4xf32, strided<[4, 1], offset: ?>> {
80+
%0 = memref.subview %arg0[%idx, 0][16, 4][1, 1] : memref<16x4xf32, strided<[4, 1], offset: ?>> to memref<16x4xf32, strided<[4, 1], offset: ?>>
81+
return %0 : memref<16x4xf32, strided<[4, 1], offset: ?>>
82+
}
83+
84+
// -----
85+
7386
func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
7487
%arg2 : index) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
7588
{

0 commit comments

Comments
 (0)