Skip to content

[mlir][vector] Update tests/patterns for vector.transpose #91359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 13, 2024

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented May 7, 2024

Pretty much all logic that we have today for lowering vector.transpose
assumes fixed length vectors (it's done via vector.shuffle that don't
support scalable vectors). This patch updates related tests and patterns
to capture and document that limitation more explicitly.

Note that vector.transpose is a valid operation in the context of
scalable vectors, but we are yet to implement the missing lowerings.

Most changes are implemented in the test file. Here's a summary:

  • @transpose_nx8x2xf32 is renamed as @transpose_scalable and move near
    other test using lowering_strategy = "shuffle_1d" (to avoid
    duplicating TD sequences)
  • tests specific to X86 (avx2_lowering_strategy = true) are moved to
    a dedicated file (to seperate generic tests from target-specific
    tests)
  • @transpose10_nx4xnx1xf32 duplicated @transpose10_4xnx1xf32 and was
    deleted (the latter is renamed as @transpose10_4x1xf32_scalable to
    match its fixed-width counterpart: @transpose10_4x1xf32)

@llvmbot
Copy link
Member

llvmbot commented May 7, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Pretty much all logic that we have today for lowering vector.transpose
assumes fixed length vectors (it's done via vector.shuffle that don't
support scalable vectors). This patch updates related tests and patterns
to capture and document that limitation more explicitly.

Note that vector.transpose is a valid operation in the context of
scalable vectors, but we are yet to implement the missing lowerings.

Most changes are implemented in the test file. Here's a summary:

  • @transpose_nx8x2xf32 is renamed as @transpose_scalable and move near
    other test using lowering_strategy = "shuffle_1d" (to avoid
    duplicating TD sequences)
  • tests specific to X86 (avx2_lowering_strategy = true) are moved to
    a dedicated file (to seperate generic tests from target-specific
    tests)
  • @<!-- -->transpose10_nx4xnx1xf32 duplicated @<!-- -->transpose10_4xnx1xf32 and was
    deleted (the latter is renamed as @<!-- -->transpose10_4x1xf32_scalable to
    match its fixed-width counterpart: @<!-- -->transpose10_4x1xf32)
  • The changes in LowerVectorTranspose.cpp are NFCs - they just make sure
    that "scalable" vectors are filtered out at the very beginning

Patch is 44.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91359.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp (+4-3)
  • (modified) mlir/test/Dialect/Vector/vector-transpose-lowering.mlir (+39-522)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 792550dcfaf222..ebb28983145314 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -320,6 +320,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
 
   LogicalResult matchAndRewrite(vector::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.getSourceVectorType().isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "scalable vectors are not supported by this pattern");
+
     auto loc = op.getLoc();
 
     Value input = op.getVector();
@@ -352,9 +356,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
       return success();
     }
 
-    if (inputType.isScalable())
-      return failure();
-
     // Handle a true 2-D matrix transpose differently when requested.
     if (vectorTransformOptions.vectorTransposeLowering ==
             vector::VectorTransposeLowering::Flat &&
diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
index 628a8ce5095994..219a72df52a19c 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
@@ -110,6 +110,17 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
   return %0 : vector<4x2xf32>
 }
 
+/// Scalable vectors are not supported
+
+// CHECK-LABEL: func @transpose_scalable
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.shape_cast
+// CHECK: vector.transpose
+func.func @transpose_scalable(%arg0: vector<2x[4]xf32>) -> vector<[4]x2xf32> {
+  %0 = vector.transpose %arg0, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+  return %0 : vector<[4]x2xf32>
+}
+
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
@@ -132,502 +143,22 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
   return %0 : vector<4x2xf32>
 }
 
+/// Scalable vectors are not supported
 
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
-    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
-    transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose"
-    } : !transform.op<"func.func">
-    transform.yield
-  }
-}
-
-// -----
-
-// CHECK-LABEL: func @transpose4x8
-func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
-  //      CHECK: vector.extract {{.*}}[0]
-  // CHECK-NEXT: vector.extract {{.*}}[1]
-  // CHECK-NEXT: vector.extract {{.*}}[2]
-  // CHECK-NEXT: vector.extract {{.*}}[3]
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.insert {{.*}}[0]
-  // CHECK-NEXT: vector.insert {{.*}}[1]
-  // CHECK-NEXT: vector.insert {{.*}}[2]
-  // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<8x4xf32>
-  %0 = vector.transpose %arg0, [1, 0] : vector<4x8xf32> to vector<8x4xf32>
-  return %0 : vector<8x4xf32>
-}
-
-// CHECK-LABEL: func @transpose021_1x4x8
-func.func @transpose021_1x4x8xf32(%arg0: vector<1x4x8xf32>) -> vector<1x8x4xf32> {
-  //      CHECK: vector.extract {{.*}}[0, 0]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 1]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 2]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 3]
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 8, 9, 4, 5, 12, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 3, 10, 11, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.insert {{.*}}[0]
-  // CHECK-NEXT: vector.insert {{.*}}[1]
-  // CHECK-NEXT: vector.insert {{.*}}[2]
-  // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<4x8xf32> to vector<32xf32>
-  // CHECK-NEXT: vector.shape_cast {{.*}} vector<32xf32> to vector<1x8x4xf32>
-  %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x4x8xf32> to vector<1x8x4xf32>
-  return %0 : vector<1x8x4xf32>
-}
-
-// CHECK-LABEL: func @transpose8x8
-func.func @transpose8x8xf32(%arg0: vector<8x8xf32>) -> vector<8x8xf32> {
-  //      CHECK: vector.extract {{.*}}[0]
-  // CHECK-NEXT: vector.extract {{.*}}[1]
-  // CHECK-NEXT: vector.extract {{.*}}[2]
-  // CHECK-NEXT: vector.extract {{.*}}[3]
-  // CHECK-NEXT: vector.extract {{.*}}[4]
-  // CHECK-NEXT: vector.extract {{.*}}[5]
-  // CHECK-NEXT: vector.extract {{.*}}[6]
-  // CHECK-NEXT: vector.extract {{.*}}[7]
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.insert {{.*}}[0]
-  // CHECK-NEXT: vector.insert {{.*}}[1]
-  // CHECK-NEXT: vector.insert {{.*}}[2]
-  // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.insert {{.*}}[4]
-  // CHECK-NEXT: vector.insert {{.*}}[5]
-  // CHECK-NEXT: vector.insert {{.*}}[6]
-  // CHECK-NEXT: vector.insert {{.*}}[7]
-  %0 = vector.transpose %arg0, [1, 0] : vector<8x8xf32> to vector<8x8xf32>
-  return %0 : vector<8x8xf32>
-}
-
-// CHECK-LABEL: func @transpose021_1x8x8
-func.func @transpose021_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<1x8x8xf32> {
-  //      CHECK: vector.extract {{.*}}[0, 0]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 1]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 2]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 3]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 4]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 5]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 6]
-  // CHECK-NEXT: vector.extract {{.*}}[0, 7]
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.insert {{.*}}[0]
-  // CHECK-NEXT: vector.insert {{.*}}[1]
-  // CHECK-NEXT: vector.insert {{.*}}[2]
-  // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.insert {{.*}}[4]
-  // CHECK-NEXT: vector.insert {{.*}}[5]
-  // CHECK-NEXT: vector.insert {{.*}}[6]
-  // CHECK-NEXT: vector.insert {{.*}}[7]
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32>
-  %0 = vector.transpose %arg0, [0, 2, 1] : vector<1x8x8xf32> to vector<1x8x8xf32>
-  return %0 : vector<1x8x8xf32>
-}
-
-// CHECK-LABEL: func @transpose120_8x1x8
-func.func @transpose120_8x1x8xf32(%arg0: vector<8x1x8xf32>) -> vector<1x8x8xf32> {
-  //      CHECK: vector.extract {{.*}}[0, 0]
-  // CHECK-NEXT: vector.extract {{.*}}[1, 0]
-  // CHECK-NEXT: vector.extract {{.*}}[2, 0]
-  // CHECK-NEXT: vector.extract {{.*}}[3, 0]
-  // CHECK-NEXT: vector.extract {{.*}}[4, 0]
-  // CHECK-NEXT: vector.extract {{.*}}[5, 0]
-  // CHECK-NEXT: vector.extract {{.*}}[6, 0]
-  // CHECK-NEXT: vector.extract {{.*}}[7, 0]
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.insert {{.*}}[0]
-  // CHECK-NEXT: vector.insert {{.*}}[1]
-  // CHECK-NEXT: vector.insert {{.*}}[2]
-  // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.insert {{.*}}[4]
-  // CHECK-NEXT: vector.insert {{.*}}[5]
-  // CHECK-NEXT: vector.insert {{.*}}[6]
-  // CHECK-NEXT: vector.insert {{.*}}[7]
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<1x8x8xf32>
-  %0 = vector.transpose %arg0, [1, 2, 0] : vector<8x1x8xf32> to vector<1x8x8xf32>
-  return %0 : vector<1x8x8xf32>
-}
-
-// CHECK-LABEL: func @transpose120_8x8x1
-func.func @transpose120_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<8x1x8xf32> {
-  //      CHECK: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0]
-  // CHECK-NEXT: vector.extract {{.*}}[1]
-  // CHECK-NEXT: vector.extract {{.*}}[2]
-  // CHECK-NEXT: vector.extract {{.*}}[3]
-  // CHECK-NEXT: vector.extract {{.*}}[4]
-  // CHECK-NEXT: vector.extract {{.*}}[5]
-  // CHECK-NEXT: vector.extract {{.*}}[6]
-  // CHECK-NEXT: vector.extract {{.*}}[7]
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.shuffle {{.*}} [2, 10, 3, 11, 6, 14, 7, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [2, 3, 8, 9, 6, 7, 12, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0xcc", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-NEXT: llvm.inline_asm asm_dialect = intel "vblendps $0, $1, $2, 0x33", "=x,x,x" {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
-  // CHECK-COUNT-4: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vector.insert {{.*}}[0]
-  // CHECK-NEXT: vector.insert {{.*}}[1]
-  // CHECK-NEXT: vector.insert {{.*}}[2]
-  // CHECK-NEXT: vector.insert {{.*}}[3]
-  // CHECK-NEXT: vector.insert {{.*}}[4]
-  // CHECK-NEXT: vector.insert {{.*}}[5]
-  // CHECK-NEXT: vector.insert {{.*}}[6]
-  // CHECK-NEXT: vector.insert {{.*}}[7]
-  // CHECK-NEXT: vector.shape_cast %{{.*}} : vector<8x8xf32> to vector<8x1x8xf32>
-  %0 = vector.transpose %arg0, [1, 2, 0] : vector<8x8x1xf32> to vector<8x1x8xf32>
-  return %0 : vector<8x1x8xf32>
-}
-
-// CHECK-LABEL: func @transpose102_8x8x1
-func.func @transpose102_8x8x1xf32(%arg0: vector<8x8x1xf32>) -> vector<8x8x1xf32> {
-  //      CHECK: vector.shape_cast %{{.*}} : vector<8x8x1xf32> to vector<8x8xf32>
-  // CHECK-NEXT: vector.extract {{.*}}[0]
-  // CHECK-NEXT: vector.extract {{.*}}[1]
-  // CHECK-NEXT: vector.extract {{.*}}[2]
-  // CHECK-NEXT: vector.extract {{.*}}[3]
-  // CHECK-NEXT: vector.extract {{.*}}[4]
-  // CHECK-NEXT: vector.extract {{.*}}[5]
-  // CHECK-NEXT: vector.extract {{.*}}[6]
-  // CHECK-NEXT: vector.extract {{.*}}[7]
-  // CHECK-NEXT: vector.shuffle {{.*}} [0, 8, 1, 9, 4, 12, 5, 13] : vector<8xf32>, vector<8xf32>
-  // CHECK-NEXT: vect...
[truncated]

Pretty much all logic that we have today for lowering vector.transpose
assumes fixed length vectors (it's done via vector.shuffle that don't
support scalable vectors). This patch updates related tests and patterns
to capture and document that limitation more explicitly.

Note that `vector.transpose` is a valid operation in the context of
scalable vectors, but we are yet to implement the missing lowerings.

Most changes are implemented in the test file. Here's a summary:
* @transpose_nx8x2xf32 is renamed as @transpose_scalable and move near
  other test using `lowering_strategy = "shuffle_1d"` (to avoid
  duplicating TD sequences)
* tests specific to X86  (`avx2_lowering_strategy = true`) are moved to
  a dedicated file (to seperate generic tests from target-specific
  tests)
* `@transpose10_nx4xnx1xf32` duplicated `@transpose10_4xnx1xf32` and was
  deleted (the latter is renamed as `@transpose10_4x1xf32_scalable` to
  match its fixed-width counterpart: `@transpose10_4x1xf32`)
* The changes in LowerVectorTranspose.cpp are NFCs - they just make sure
  that "scalable" vectors are filtered out at the very beginning
@banach-space banach-space force-pushed the andrzej/audit_vector_transpose branch from 952f78a to 819daf8 Compare May 7, 2024 21:21
@hanhanW
Copy link
Contributor

hanhanW commented May 7, 2024

tests specific to X86 (avx2_lowering_strategy = true) are moved to a dedicated file (to seperate generic tests from target-specific tests)

Where is the file? I don't see it in the PR. It looks like you're removing bunch of lit tests from vector-transpose-lowering.mlir.

@banach-space
Copy link
Contributor Author

tests specific to X86 (avx2_lowering_strategy = true) are moved to a dedicated file (to seperate generic tests from target-specific tests)

Where is the file? I don't see it in the PR. It looks like you're removing bunch of lit tests from vector-transpose-lowering.mlir.

Sorry about that, forgot to attach it. Fixed in the latest commit.

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nit, otherwise LGTM.

Comment on lines 3 to 4
// NOTE: This file tests lowering from the X86 dialect. Since X86 does not
// support scalable vectors, all examples in this file use fixed-width vectors.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not lowering from the x86 dialect (also to be clear there's no x86 dialect, only x86vector), the input is the vector dialect. However, this does not actually lower to x86vector dialect operations, it lowers directly to inline assembly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I should've been more specific. What I meant is that these lowerings (should be "lowerings" rather than "lowering") are implemented in the X86Vector (rather than X86) Dialect:

Let me update.

// support scalable vectors, all examples in this file use fixed-width vectors.

// CHECK-LABEL: func @transpose4x8
func.func @transpose4x8xf32(%arg0: vector<4x8xf32>) -> vector<8x4xf32> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I diffed these tests with the deleted tests -- identical match.

@banach-space banach-space force-pushed the andrzej/audit_vector_transpose branch from 65babe1 to e27d11c Compare May 9, 2024 13:55
@banach-space banach-space merged commit 0bacffb into llvm:main May 13, 2024
4 checks passed
@banach-space banach-space deleted the andrzej/audit_vector_transpose branch May 13, 2024 07:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants