Skip to content

Commit 4d273b9

Browse files
authored
[mlir][sparse] ensure [dis]assembler wrapper methods properly inline (#81907)
1 parent 3e004d1 commit 4d273b9

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
6161
}
6262
}
6363

64-
// Convert input and output values to [dis[assemble ops for sparse tensors.
64+
// Convert input and output values to [dis]assemble ops for sparse tensors.
6565
void convVals(OpBuilder &builder, Location loc, TypeRange types,
6666
ValueRange fromVals, ValueRange extraVals,
6767
SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
@@ -161,8 +161,6 @@ namespace {
161161
//
162162
// TODO: refine output sparse tensors to work well with external framework
163163
//
164-
// TODO: use "inlining" instead of a wrapper?
165-
//
166164
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
167165
using OpRewritePattern::OpRewritePattern;
168166

@@ -211,7 +209,8 @@ struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
211209
convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
212210
ValueRange(), inputs, 0, /*isIn=*/true);
213211

214-
// Call original, now internal method.
212+
// Call the original, now private method. A subsequent inlining pass can
213+
// determine whether cloning the method body in place is worthwhile.
215214
auto org = SymbolRefAttr::get(context, wrapper);
216215
auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
217216
inputs);

mlir/test/Dialect/SparseTensor/torch_linalg.mlir

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: mlir-opt %s --sparse-assembler | FileCheck %s --check-prefix=CHECK-HI
22
// RUN: mlir-opt %s --sparse-assembler \
3+
// RUN: --inline | FileCheck %s --check-prefix=CHECK-INL
4+
// RUN: mlir-opt %s --sparse-assembler \
35
// RUN: --linalg-generalize-named-ops \
46
// RUN: --linalg-fuse-elementwise-ops \
57
// RUN: --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-MID
@@ -20,7 +22,13 @@
2022
// CHECK-HI: func.func private @_internal_main
2123
// CHECK-HI: linalg.matmul
2224
// CHECK-HI: return
23-
//
25+
26+
// CHECK-INL-LABEL: func.func @main
27+
// CHECK-INL: sparse_tensor.assemble
28+
// CHECK-INL: linalg.matmul
29+
// CHECK-INL: return
30+
// CHECK-INL-NOT: func.func private @_internal_main
31+
2432
// CHECK-MID-LABEL: func.func @main
2533
// CHECK-MID: memref.load
2634
// CHECK-MID: call @_internal_main

0 commit comments

Comments
 (0)