Skip to content

Commit 6a93da9

Browse files
authored
[mlir][sparse] add ReinterpretMapScopeOption for the pass (#70486)
1 parent 840bf2a commit 6a93da9

File tree

4 files changed

+36
-3
lines changed

4 files changed

+36
-3
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ enum class SparseParallelizationStrategy {
4040
kAnyStorageAnyLoop
4141
};
4242

43+
/// Defines a scope for reinterpret map pass.
44+
enum class ReinterpretMapScope {
45+
kAll, // reinterprets all applicable operations
46+
kGenericOnly, // reinterprets only linalg.generic
47+
kExceptGeneric, // reinterprets operation other than linalg.generic
48+
};
49+
4350
/// Defines data movement strategy between host and device for GPU.
4451
// TODO : Zero copy is disabled due to correctness bugs (tracker #64316)
4552
enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
@@ -51,9 +58,11 @@ enum class GPUDataTransferStrategy { kRegularDMA, kZeroCopy, kPinnedDMA };
5158
// The SparseReinterpretMap pass.
5259
//===----------------------------------------------------------------------===//
5360

54-
void populateSparseReinterpretMap(RewritePatternSet &patterns);
61+
void populateSparseReinterpretMap(RewritePatternSet &patterns,
62+
ReinterpretMapScope scope);
5563

5664
std::unique_ptr<Pass> createSparseReinterpretMapPass();
65+
std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
5766

5867
//===----------------------------------------------------------------------===//
5968
// The PreSparsificationRewriting pass.

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,19 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
2727
"linalg::LinalgDialect",
2828
"sparse_tensor::SparseTensorDialect",
2929
];
30+
let options = [
31+
Option<"scope", "scope", "mlir::ReinterpretMapScope",
32+
"mlir::ReinterpretMapScope::kAll",
33+
"Set the reiterpretation scope", [{llvm::cl::values(
34+
clEnumValN(mlir::ReinterpretMapScope::kAll, "all",
35+
"Run on every applicable operations."),
36+
clEnumValN(mlir::ReinterpretMapScope::kGenericOnly,
37+
"only-generic",
38+
"Run only on linalg.generic operations."),
39+
clEnumValN(mlir::ReinterpretMapScope::kExceptGeneric,
40+
"except-generic",
41+
"Run on operations expect linalg.generic (e.g., foreach)"))}]>,
42+
];
3043
}
3144

3245
def PreSparsificationRewrite : Pass<"pre-sparsification-rewrite", "ModuleOp"> {

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@ namespace {
1919

2020
} // namespace
2121

22-
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns) {}
22+
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
23+
ReinterpretMapScope scope) {}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,14 @@ struct SparseReinterpretMap
4949
: public impl::SparseReinterpretMapBase<SparseReinterpretMap> {
5050
SparseReinterpretMap() = default;
5151
SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
52+
SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
53+
scope = options.scope;
54+
}
5255

5356
void runOnOperation() override {
5457
auto *ctx = &getContext();
5558
RewritePatternSet patterns(ctx);
56-
populateSparseReinterpretMap(patterns);
59+
populateSparseReinterpretMap(patterns, scope);
5760
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
5861
}
5962
};
@@ -372,6 +375,13 @@ std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass() {
372375
return std::make_unique<SparseReinterpretMap>();
373376
}
374377

378+
std::unique_ptr<Pass>
379+
mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
380+
SparseReinterpretMapOptions options;
381+
options.scope = scope;
382+
return std::make_unique<SparseReinterpretMap>(options);
383+
}
384+
375385
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
376386
return std::make_unique<PreSparsificationRewritePass>();
377387
}

0 commit comments

Comments
 (0)