[mlir][vector] Relax strides check for 1-element vector load/stores #108998

Merged

Hardcode84 merged 1 commit from the unit-vec-load branch into llvm:main on Sep 19, 2024

Conversation

Hardcode84
Contributor

Single-element vector load/stores are equivalent to scalar load/stores, so they don't need the memref to be contiguous.
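
For illustration, here is a minimal example adapted from the test cases this patch adds to mlir/test/Dialect/Vector/ops.mlir (the function name is invented for the example). The verifier now accepts it: only a single element is ever accessed, so the dynamic, possibly non-unit strides of the memref never come into play; previously the unit-stride check on the innermost dimension would have rejected this IR.

func.func @single_element_load_store(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
                                     %i : index, %j : index) {
  // 1-element vector load/store on a strided memref: behaves like a scalar
  // load/store, so the layout's strides are irrelevant.
  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
  vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
  return
}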
@llvmbot
Member

llvmbot commented Sep 17, 2024

@llvm/pr-subscribers-mlir-vector

Author: Ivan Butygin (Hardcode84)

Changes

Single-element vector load/stores are equivalent to scalar load/stores, so they don't need the memref to be contiguous.


Full diff: https://github.com/llvm/llvm-project/pull/108998.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+9-2)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+20)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d3aef4ac38af03..816447713de417 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4769,7 +4769,14 @@ void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 
 static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
+                                                 VectorType vecTy,
                                                  MemRefType memRefTy) {
+  // If rank==0 or size==1 it's equivalent to scalar load/store, so we don't
+  // need any strides limitations.
+  if (!vecTy.isScalable() &&
+      (vecTy.getRank() == 0 || vecTy.getNumElements() == 1))
+    return success();
+
   if (!isLastMemrefDimUnitStride(memRefTy))
     return op->emitOpError("most minor memref dim must have unit stride");
   return success();
@@ -4779,7 +4786,7 @@ LogicalResult vector::LoadOp::verify() {
   VectorType resVecTy = getVectorType();
   MemRefType memRefTy = getMemRefType();
 
-  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+  if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
     return failure();
 
   // Checks for vector memrefs.
@@ -4811,7 +4818,7 @@ LogicalResult vector::StoreOp::verify() {
   VectorType valueVecTy = getVectorType();
   MemRefType memRefTy = getMemRefType();
 
-  if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
+  if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
     return failure();
 
   // Checks for vector memrefs.
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 4759fcc9511fb2..08d1a189231bcc 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -732,6 +732,26 @@ func.func @vector_load_and_store_0d_scalar_memref(%memref : memref<200x100xf32>,
   return
 }
 
+// CHECK-LABEL: @vector_load_and_store_0d_scalar_strided_memref
+func.func @vector_load_and_store_0d_scalar_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+                                                          %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<f32>
+  return
+}
+
+// CHECK-LABEL: @vector_load_and_store_unit_vec_strided_memref
+func.func @vector_load_and_store_unit_vec_strided_memref(%memref : memref<200x100xf32, strided<[?, ?], offset: ?>>,
+                                                         %i : index, %j : index) {
+  // CHECK: %[[ld:.*]] = vector.load %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  %0 = vector.load %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  // CHECK: vector.store %[[ld]], %{{.*}}[%{{.*}}] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  vector.store %0, %memref[%i, %j] : memref<200x100xf32, strided<[?, ?], offset: ?>>, vector<1xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_load_and_store_1d_scalar_memref
 func.func @vector_load_and_store_1d_scalar_memref(%memref : memref<200x100xf32>,
                                              %i : index, %j : index) {

@llvmbot
Member

llvmbot commented Sep 17, 2024

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

(Same changes summary and full diff as in the @llvm/pr-subscribers-mlir-vector comment above.)

@Hardcode84 Hardcode84 merged commit f325085 into llvm:main Sep 19, 2024
12 checks passed
@Hardcode84 Hardcode84 deleted the unit-vec-load branch September 19, 2024 10:12
@banach-space
Contributor

Please, could you also update the docs? https://mlir.llvm.org/docs/Dialects/Vector/#vectorload-vectorloadop

@Hardcode84
Contributor Author

@banach-space #109267

Hardcode84 added a commit that referenced this pull request Sep 19, 2024
…109267)

Follow up to #108998.

Non-contiguous strides are allowed now for 1-element vector load/stores.
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
…lvm#108998)

Single-element vector load/stores are equivalent to scalar load/stores,
so they don't need the memref to be contiguous.
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
…lvm#109267)

Follow up to llvm#108998.

Non-contiguous strides are allowed now for 1-element vector load/stores.