From 17c0745a45af86b521739b111b453c16c1301810 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 31 Jan 2024 16:33:43 +0000 Subject: [PATCH] [mlir][memref] `memref.subview`: Verify result strides with rank reductions This is a follow-up on #79865. Result strides are now also verified if the `memref.subview` op has rank reductions. --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 23 ++++++++++++------- .../Transforms/ExpandStridedMetadata.cpp | 17 ++++++++++---- mlir/test/Dialect/MemRef/canonicalize.mlir | 6 ++--- .../Dialect/MemRef/fold-memref-alias-ops.mlir | 4 ++-- mlir/test/Dialect/MemRef/invalid.mlir | 9 ++++++++ 5 files changed, 42 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index f43217f6f27ae..841c5d1686b44 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2756,17 +2756,26 @@ static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) { } /// Return true if `t1` and `t2` have equal strides (both dynamic or of same -/// static value). -static bool haveCompatibleStrides(MemRefType t1, MemRefType t2) { +/// static value). Dimensions of `t1` may be dropped in `t2`; these must be +/// marked as dropped in `droppedDims`. +static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, + const llvm::SmallBitVector &droppedDims) { + assert(t1.getRank() == droppedDims.size() && "incorrect number of bits"); + assert(t1.getRank() - t2.getRank() == droppedDims.count() && + "incorrect number of dropped dims"); int64_t t1Offset, t2Offset; SmallVector t1Strides, t2Strides; auto res1 = getStridesAndOffset(t1, t1Strides, t1Offset); auto res2 = getStridesAndOffset(t2, t2Strides, t2Offset); if (failed(res1) || failed(res2)) return false; - for (auto [s1, s2] : llvm::zip_equal(t1Strides, t2Strides)) - if (s1 != s2) + for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) { + if (droppedDims[i]) + continue; + if (t1Strides[i] != t2Strides[j]) return false; + ++j; + } return true; } @@ -2843,10 +2852,8 @@ LogicalResult SubViewOp::verify() { return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch, *this, expectedType); - // Strides must match if there are no rank reductions. - // TODO: Verify strides when there are rank reductions. Strides are partially - // checked in `computeMemRefRankReductionMask`. - if (unusedDims->none() && !haveCompatibleStrides(expectedType, subViewType)) + // Strides must match. + if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims)) return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch, *this, expectedType); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index f6af0791ba756..96eb7cfd2db69 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -144,16 +144,25 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter, SmallVector finalStrides; finalStrides.reserve(subRank); +#ifndef NDEBUG + // Iteration variable for result dimensions of the subview op. + int64_t j = 0; +#endif // NDEBUG for (unsigned i = 0; i < sourceRank; ++i) { if (droppedDims.test(i)) continue; finalSizes.push_back(subSizes[i]); finalStrides.push_back(strides[i]); - // TODO: Assert that the computed stride matches the respective stride of - // the result type of the subview op (if both are static), once the verifier - // of memref.subview verfies result strides correctly for ops with rank - // reductions. +#ifndef NDEBUG + // Assert that the computed stride matches the stride of the result type of + // the subview op (if both are static). + std::optional computedStride = getConstantIntValue(strides[i]); + if (computedStride && !ShapedType::isDynamic(resultStrides[j])) + assert(*computedStride == resultStrides[j] && + "mismatch between computed stride and result type stride"); + ++j; +#endif // NDEBUG } assert(finalSizes.size() == subRank && "Should have populated all the values at this point"); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index 993ef32edc9d4..a772a25da5738 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -62,13 +62,13 @@ func.func @subview_canonicalize(%arg0 : memref, %arg1 : index, // ----- func.func @rank_reducing_subview_canonicalize(%arg0 : memref, %arg1 : index, - %arg2 : index) -> memref> + %arg2 : index) -> memref> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref to memref> - return %0 : memref> + %0 = memref.subview %arg0[%c0, %arg1, %c1] [%c4, 1, %arg2] [%c1, %c1, %c1] : memref to memref> + return %0 : memref> } // CHECK-LABEL: func @rank_reducing_subview_canonicalize // CHECK-SAME: %[[ARG0:.+]]: memref diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 3407bdbc7c8f9..5b853a6cc5a37 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -613,9 +613,9 @@ func.func @subview_of_subview_rank_reducing(%m: memref, { %0 = memref.subview %m[3, 1, 8] [1, %sz, 1] [1, 1, 1] : memref - to memref> + to memref> %1 = memref.subview %0[6] [1] [1] - : memref> + : memref> to memref> return %1 : memref> } diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index be60a3dcb1b20..8f5ba5ea8fc78 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -1082,3 +1082,12 @@ func.func @subview_invalid_strides(%m: memref<7x22x333x4444xi32>) { : memref<7x22x333x4444xi32> to memref<7x11x333x4444xi32> return } + +// ----- + +func.func @subview_invalid_strides_rank_reduction(%m: memref<7x22x333x4444xi32>) { + // expected-error @below{{expected result type to be 'memref<7x11x1x4444xi32, strided<[32556744, 2959704, 4444, 1]>>' or a rank-reduced version. (mismatch of result layout)}} + %subview = memref.subview %m[0, 0, 0, 0] [7, 11, 1, 4444] [1, 2, 1, 1] + : memref<7x22x333x4444xi32> to memref<7x11x4444xi32> + return +}