Skip to content

Commit bf561dd

Browse files
committed
[mlir][Vector] Vectorize integer matmuls
The underlying infrastructure already supports this; just add the pattern matching for linalg.generic. Differential Revision: https://reviews.llvm.org/D84335
1 parent e59778a commit bf561dd

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,17 @@ static bool hasMultiplyAddBody(Region &r) {
5252
auto pattern2 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(a, b)));
5353
auto pattern3 = m_Op<YieldOp>(m_Op<AddFOp>(m_Op<MulFOp>(b, a), c));
5454
auto pattern4 = m_Op<YieldOp>(m_Op<AddFOp>(c, m_Op<MulFOp>(b, a)));
55+
auto pattern5 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(a, b), c));
56+
auto pattern6 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(a, b)));
57+
auto pattern7 = m_Op<YieldOp>(m_Op<AddIOp>(m_Op<MulIOp>(b, a), c));
58+
auto pattern8 = m_Op<YieldOp>(m_Op<AddIOp>(c, m_Op<MulIOp>(b, a)));
5559
return pattern1.match(&r.front().back()) ||
5660
pattern2.match(&r.front().back()) ||
57-
pattern3.match(&r.front().back()) || pattern4.match(&r.front().back());
61+
pattern3.match(&r.front().back()) ||
62+
pattern4.match(&r.front().back()) ||
63+
pattern5.match(&r.front().back()) ||
64+
pattern6.match(&r.front().back()) ||
65+
pattern7.match(&r.front().back()) || pattern8.match(&r.front().back());
5866
}
5967

6068
// TODO: Should be Tablegen'd from a single source that generates the op itself.

mlir/test/Dialect/Linalg/transform-patterns.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,23 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
118118
// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
119119
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
120120

121+
func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
122+
%C: memref<8x32xi32>) {
123+
linalg.generic #matmul_trait %A, %B, %C {
124+
^bb(%a: i32, %b: i32, %c: i32) :
125+
%d = muli %a, %b: i32
126+
%e = addi %c, %d: i32
127+
linalg.yield %e : i32
128+
} : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32>
129+
return
130+
}
131+
// CHECK-LABEL: func @vectorization_test_integer
132+
// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32>
133+
// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32>
134+
// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32>
135+
// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32>
136+
// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32>
137+
121138
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
122139
%C: memref<8x32xf32>) {
123140
linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :

0 commit comments

Comments
 (0)