Skip to content

Commit f325085

Browse files
authored
[mlir][vector] Relax strides check for 1-element vector load/stores (#108998)
Single elememst vector load/stores are equivalent to scalar load/stores, so they don't need memref to be contigious.
1 parent d267daa commit f325085

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -4769,7 +4769,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
47694769
//===----------------------------------------------------------------------===//
47704770

47714771
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4772+
VectorType vecTy,
47724773
MemRefType memRefTy) {
4774+
// If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
4775+
// need any strides limitations.
4776+
if (!vecTy.isScalable() &&
4777+
(vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
4778+
return success();
4779+
47734780
if (!isLastMemrefDimUnitStride(memRefTy))
47744781
return op->emitOpError("most minor memref dim must have unit stride");
47754782
return success();
@@ -4779,7 +4786,7 @@ LogicalResult vector::LoadOp::verify() {
47794786
VectorType resVecTy = getVectorType();
47804787
MemRefType memRefTy = getMemRefType();
47814788

4782-
if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4789+
if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
47834790
return failure();
47844791

47854792
// Checks for vector memrefs.
@@ -4811,7 +4818,7 @@ LogicalResult vector::StoreOp::verify() {
48114818
VectorType valueVecTy = getVectorType();
48124819
MemRefType memRefTy = getMemRefType();
48134820

4814-
if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4821+
if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
48154822
return failure();
48164823

48174824
// Checks for vector memrefs.

mlir/test/Dialect/Vector/ops.mlir

+20
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,26 @@ func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
732732
return
733733
}
734734

735+
// CHECK-LABEL: @vector_load_and_store_0d_scalar_strided_memref
736+
func.func @vector_load_and_store_0d_scalar_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
737+
%i : index, %j : index) {
738+
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
739+
%0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
740+
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
741+
vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
742+
return
743+
}
744+
745+
// CHECK-LABEL: @vector_load_and_store_unit_vec_strided_memref
746+
func.func @vector_load_and_store_unit_vec_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
747+
%i : index, %j : index) {
748+
// CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
749+
%0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
750+
// CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
751+
vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
752+
return
753+
}
754+
735755
// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
736756
func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
737757
%i : index, %j : index) {

0 commit comments

Comments
 (0)