Skip to content

Commit 9478bf0

Browse files
authored
[mlir] Introduce trailingNDimsContiguous for MemRefs (llvm#78247)
Extracts logic from `vector::isContiguousSlice` to check whether the trailing dim of a memref are contiguous into a dedicated hook in BuiitinTypes.{h|cpp}. Follow-up for llvm#76848.
1 parent 44436a9 commit 9478bf0

File tree

3 files changed

+46
-29
lines changed

3 files changed

+46
-29
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,16 @@ bool isStrided(MemRefType t);
518518
/// stride. Also return "true" for types with no strides.
519519
bool isLastMemrefDimUnitStride(MemRefType type);
520520

521+
/// Return "true" if the last N dimensions of the given type are contiguous.
522+
///
523+
/// Examples:
524+
/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
525+
/// considering both _all_ and _only_ the trailing 3 dims,
526+
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
527+
/// considering the trailing 3 dims.
528+
///
529+
bool trailingNDimsContiguous(MemRefType type, int64_t n);
530+
521531
} // namespace mlir
522532

523533
#endif // MLIR_IR_BUILTINTYPES_H

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -257,38 +257,13 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
257257
ArrayRef<int64_t> vectorShape = vectorType.getShape();
258258
auto vecRank = vectorType.getRank();
259259

260-
// Extract the trailing dims and strides of the input memref
261-
auto memrefShape = memrefType.getShape().take_back(vecRank);
262-
int64_t offset;
263-
SmallVector<int64_t> stridesFull;
264-
if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
265-
return false;
266-
auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
267-
memrefType.getLayout().isIdentity();
268-
269-
// TODO: Add support for memref with trailing dynamic shapes. Memrefs
270-
// with leading dynamic dimensions are already supported.
271-
if (ShapedType::isDynamicShape(memrefShape))
260+
if (!trailingNDimsContiguous(memrefType, vecRank))
272261
return false;
273262

274-
// Cond 1: Check whether `memrefType` is contiguous.
275-
if (!strides.empty()) {
276-
// Cond 1.1: A contiguous memref will always have a unit trailing stride.
277-
if (strides.back() != 1)
278-
return false;
279-
280-
// Cond 1.2: Strides of a contiguous memref have to match the flattened
281-
// dims.
282-
strides = strides.drop_back(1);
283-
SmallVector<int64_t> flattenedDims;
284-
for (size_t i = 1; i < memrefShape.size(); i++)
285-
flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
286-
287-
if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
288-
return false;
289-
}
263+
// Extract the trailing dims and strides of the input memref
264+
auto memrefShape = memrefType.getShape().take_back(vecRank);
290265

291-
// Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
266+
// Compare the dims of `vectorType` against `memrefType` (in reverse).
292267
// In the most basic case, all dims will match.
293268
auto firstNonMatchingDim =
294269
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,3 +967,35 @@ bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
967967
auto successStrides = getStridesAndOffset(type, strides, offset);
968968
return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
969969
}
970+
971+
bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
972+
if (!isLastMemrefDimUnitStride(type))
973+
return false;
974+
975+
auto memrefShape = type.getShape().take_back(n);
976+
if (ShapedType::isDynamicShape(memrefShape))
977+
return false;
978+
979+
if (type.getLayout().isIdentity())
980+
return true;
981+
982+
int64_t offset;
983+
SmallVector<int64_t> stridesFull;
984+
if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
985+
return false;
986+
auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
987+
988+
if (strides.empty())
989+
return true;
990+
991+
// Check whether strides match "flattened" dims.
992+
SmallVector<int64_t> flattenedDims;
993+
auto dimProduct = 1;
994+
for (auto dim : llvm::reverse(memrefShape.drop_front(1))) {
995+
dimProduct *= dim;
996+
flattenedDims.push_back(dimProduct);
997+
}
998+
999+
strides = strides.drop_back(1);
1000+
return llvm::equal(strides, llvm::reverse(flattenedDims));
1001+
}

0 commit comments

Comments
 (0)