Skip to content

[mlir][memref] memref.subview: Verify result strides with rank reductions #80158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2756,17 +2756,26 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
}

/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
/// static value).
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) {
/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
/// marked as dropped in `droppedDims`.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
const llvm::SmallBitVector &droppedDims) {
assert(t1.getRank() == droppedDims.size() && "incorrect number of bits");
assert(t1.getRank() - t2.getRank() == droppedDims.count() &&
"incorrect number of dropped dims");
int64_t t1Offset, t2Offset;
SmallVector<int64_t> t1Strides, t2Strides;
auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset);
auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset);
if (failed(res1) || failed(res2))
return false;
for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides))
if (s1 != s2)
for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
if (droppedDims[i])
continue;
if (t1Strides[i] != t2Strides[j])
return false;
++j;
}
return true;
}

Expand Down Expand Up @@ -2843,10 +2852,8 @@ LogicalResult SubViewOp::verify() {
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
*this, expectedType);

// Strides must match if there are no rank reductions.
// TODO: Verify strides when there are rank reductions. Strides are partially
// checked in `computeMemRefRankReductionMask`.
if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType))
// Strides must match.
if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
*this, expectedType);

Expand Down
17 changes: 13 additions & 4 deletions mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,25 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter,
SmallVector<OpFoldResult> finalStrides;
finalStrides.reserve(subRank);

#ifndef NDEBUG
// Iteration variable for result dimensions of the subview op.
int64_t j = 0;
#endif // NDEBUG
for (unsigned i = 0; i < sourceRank; ++i) {
if (droppedDims.test(i))
continue;

finalSizes.push_back(subSizes[i]);
finalStrides.push_back(strides[i]);
// TODO: Assert that the computed stride matches the respective stride of
// the result type of the subview op (if both are static), once the verifier
// of memref.subview verfies result strides correctly for ops with rank
// reductions.
#ifndef NDEBUG
// Assert that the computed stride matches the stride of the result type of
// the subview op (if both are static).
std::optional<int64_t> computedStride = getConstantIntValue(strides[i]);
if (computedStride && !ShapedType::isDynamic(resultStrides[j]))
assert(*computedStride == resultStrides[j] &&
"mismatch between computed stride and result type stride");
++j;
#endif // NDEBUG
}
assert(finalSizes.size() == subRank &&
"Should have populated all the values at this point");
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ func.func @subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
// -----

func.func @rank_reducing_subview_canonicalize(%arg0 : memref<?x?x?xf32>, %arg1 : index,
%arg2 : index) -> memref<?x?xf32, strided<[?, 1], offset: ?>>
%arg2 : index) -> memref<?x?xf32, strided<[?, ?], offset: ?>>
{
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
return %0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
%0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref<?x?x?xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
return %0 : memref<?x?xf32, strided<[?, ?], offset: ?>>
}
// CHECK-LABEL: func @rank_reducing_subview_canonicalize
// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x?xf32>
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -613,9 +613,9 @@ func.func @subview_of_subview_rank_reducing(%m: memref<?x?x?xf32>,
{
%0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1]
: memref<?x?x?xf32>
to memref<?xf32, strided<[1], offset: ?>>
to memref<?xf32, strided<[?], offset: ?>>
%1 = memref.subview %0[6] [1] [1]
: memref<?xf32, strided<[1], offset: ?>>
: memref<?xf32, strided<[?], offset: ?>>
to memref<f32, strided<[], offset: ?>>
return %1 : memref<f32, strided<[], offset: ?>>
}
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1082,3 +1082,12 @@ func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) {
: memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32>
return
}

// -----

func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>) {
// expected-error @below{{expected result type to be 'memref<7x11x1x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}}
%subview = memref.subview %m[0, 0, 0, 0] [7, 11, 1, 4444] [1, 2, 1, 1]
: memref<7x22x333x4444xi32> to memref<7x11x4444xi32>
return
}