diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af0..816447713de41 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4769,7 +4769,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+                                                 VectorType vecTy,
                                                  MemRefType memRefTy) {
+  // If rank == 0 or size == 1, this is equivalent to a scalar load/store, so
+  // we don't need any stride limitations.
+  if (!vecTy.isScalable() &&
+      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
+    return success();
+
   if (!isLastMemrefDimUnitStride(memRefTy))
     return op->emitOpError("most minor memref dim must have unit stride");
   return success();
@@ -4779,7 +4786,7 @@ LogicalResult vector::LoadOp::verify() {
   VectorType resVecTy = getVectorType();
   MemRefType memRefTy = getMemRefType();
 
-  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
     return failure();
 
   // Checks for vector memrefs.
@@ -4811,7 +4818,7 @@ LogicalResult vector::StoreOp::verify() {
   VectorType valueVecTy = getVectorType();
   MemRefType memRefTy = getMemRefType();
 
-  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
     return failure();
 
   // Checks for vector memrefs.
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4759fcc9511fb..08d1a189231bc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -732,6 +732,26 @@ func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
   return
 }
 
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_strided_memref
+func.func @vector_load_and_store_0d_scalar_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+                                                          %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_unit_vec_strided_memref
+func.func @vector_load_and_store_unit_vec_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+                                                         %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
 func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
                                                   %i : index, %j : index) {
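
Note on the remaining restriction: the early exit in verifyLoadStoreMemRefLayout deliberately excludes scalable vectors, since a scalable vector<[1]xf32> may hold more than one element at runtime. A minimal illustrative sketch of a case the verifier still rejects (not part of this patch's test suite; the function name is hypothetical):

// Illustrative only: still rejected, because a multi-element vector requires
// a unit-strided minor dim, so the verifier emits
// "most minor memref dim must have unit stride".
func.func @still_rejected_multi_element(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
                                        %i : index, %j : index) -> vector<8xf32> {
  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<8xf32>
  return %0 : vector<8xf32>
}

The same error applies to vector<[1]xf32> loads and stores, which fall through to the isLastMemrefDimUnitStride check because of the !vecTy.isScalable() guard.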