[mlir][vector] Relax strides check for 1-element vector load/stores #108998
Conversation
Single-element vector load/stores are equivalent to scalar load/stores, so they don't need the memref to be contiguous.
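To make the claim concrete, here is a minimal hypothetical snippet (not part of this PR) showing the equivalence the relaxation relies on: a 1-element vector.load touches exactly one element, just like memref.load, so the strides of the layout never come into play.

```mlir
// Hypothetical example: both ops access the single element at [%i, %j],
// so neither depends on the stride of the minor dimension.
func.func @unit_load_equivalence(%m : memref<200x100xf32, strided<[?, ?], offset: ?>>,
                                 %i : index, %j : index) -> (f32, vector<1xf32>) {
  // Scalar load: always legal on a strided memref.
  %s = memref.load %m[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>
  // 1-element vector load: accepted once this check is relaxed.
  %v = vector.load %m[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
  return %s, %v : f32, vector<1xf32>
}
```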
@llvm/pr-subscribers-mlir-vector

Author: Ivan Butygin (Hardcode84)

Changes

Single-element vector load/stores are equivalent to scalar load/stores, so they don't need the memref to be contiguous.

Full diff: https://github.com/llvm/llvm-project/pull/108998.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..816447713de417 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4769,7 +4769,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+ VectorType vecTy,
MemRefType memRefTy) {
+ // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
+ // need any strides limitations.
+ if (!vecTy.isScalable() &&
+ (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
+ return success();
+
if (!isLastMemrefDimUnitStride(memRefTy))
return op->emitOpError("most minor memref dim must have unit stride");
return success();
@@ -4779,7 +4786,7 @@ LogicalResult vector::LoadOp::verify() {
VectorType resVecTy = getVectorType();
MemRefType memRefTy = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+ if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
return failure();
// Checks for vector memrefs.
@@ -4811,7 +4818,7 @@ LogicalResult vector::StoreOp::verify() {
VectorType valueVecTy = getVectorType();
MemRefType memRefTy = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+ if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
return failure();
// Checks for vector memrefs.
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4759fcc9511fb2..08d1a189231bcc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -732,6 +732,26 @@ func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
return
}
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_strided_memref
+func.func @vector_load_and_store_0d_scalar_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+ %i : index, %j : index) {
+ // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+ return
+}
+
+// CHECK-LABEL: @vector_load_and_store_unit_vec_strided_memref
+func.func @vector_load_and_store_unit_vec_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+ %i : index, %j : index) {
+ // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+ return
+}
+
// CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
%i : index, %j : index) {
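One detail worth calling out in the verifier change above: the relaxation is deliberately limited to fixed-length vectors, because a scalable vector such as vector<[1]xf32> may hold more than one element at runtime. A hypothetical snippet (not among this PR's tests) of a case the verifier still rejects:

```mlir
// Still rejected: the vector type is scalable, so its runtime length is
// unknown and the minor memref dimension must keep unit stride.
func.func @scalable_unit_vec_strided_memref(%m : memref<200x100xf32, strided<[?, ?], offset: ?>>,
                                            %i : index, %j : index) {
  // expected-error: most minor memref dim must have unit stride
  %0 = vector.load %m[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<[1]xf32>
  return
}
```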
Please, could you also update the docs? https://mlir.llvm.org/docs/Dialects/Vector/#vectorload-vectorloadop
…lvm#108998) Single-element vector load/stores are equivalent to scalar load/stores, so they don't need the memref to be contiguous.
…lvm#109267) Follow-up to llvm#108998. Non-contiguous strides are now allowed for 1-element vector load/stores.
Single-element vector load/stores are equivalent to scalar load/stores, so they don't need the memref to be contiguous.