Skip to content

Commit 075e3fd

Browse files
[mlir][bufferize] Move arith BufferizableOpInterface impl to arith dialect
Also switch the implementation of `-arith-bufferize` to BufferizableOpInterface. Differential Revision: https://reviews.llvm.org/D118325
1 parent ccce1a0 commit 075e3fd

File tree

15 files changed

+62
-147
lines changed

15 files changed

+62
-147
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
14+
class DialectRegistry;
15+
16+
namespace arith {
17+
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
18+
} // namespace arith
19+
} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_ARITHMETIC_BUFFERIZABLEOPINTERFACEIMPL_H

mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,8 @@
1212
#include "mlir/Pass/Pass.h"
1313

1414
namespace mlir {
15-
namespace bufferization {
16-
class BufferizeTypeConverter;
17-
} // namespace bufferization
18-
1915
namespace arith {
2016

21-
/// Add patterns to bufferize Arithmetic ops.
22-
void populateArithmeticBufferizePatterns(
23-
bufferization::BufferizeTypeConverter &typeConverter,
24-
RewritePatternSet &patterns);
25-
2617
/// Create a pass to bufferize Arithmetic ops.
2718
std::unique_ptr<Pass> createArithmeticBufferizePass();
2819

mlir/include/mlir/Dialect/Arithmetic/Transforms/Passes.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ include "mlir/Pass/PassBase.td"
1414
def ArithmeticBufferize : Pass<"arith-bufferize", "FuncOp"> {
1515
let summary = "Bufferize Arithmetic dialect ops.";
1616
let constructor = "mlir::arith::createArithmeticBufferizePass()";
17-
let dependentDialects = ["bufferization::BufferizationDialect",
18-
"memref::MemRefDialect"];
1917
}
2018

2119
def ArithmeticExpandOps : Pass<"arith-expand", "FuncOp"> {

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h

Lines changed: 0 additions & 27 deletions
This file was deleted.

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp renamed to mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
//===- ArithInterfaceImpl.cpp - Arith Impl. of BufferizableOpInterface ----===//
1+
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
10-
9+
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
1110
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1211
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1312
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
@@ -18,9 +17,8 @@
1817
using namespace mlir::bufferization;
1918

2019
namespace mlir {
21-
namespace linalg {
22-
namespace comprehensive_bufferize {
23-
namespace arith_ext {
20+
namespace arith {
21+
namespace {
2422

2523
/// Bufferization of arith.constant. Replace with memref.get_global.
2624
struct ConstantOpInterface
@@ -100,14 +98,13 @@ struct IndexCastOpInterface
10098
return success();
10199
}
102100
};
103-
} // namespace arith_ext
104-
} // namespace comprehensive_bufferize
105-
} // namespace linalg
101+
102+
} // namespace
103+
} // namespace arith
106104
} // namespace mlir
107105

108-
void mlir::linalg::comprehensive_bufferize::arith_ext::
109-
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
110-
registry.addOpInterface<arith::ConstantOp, arith_ext::ConstantOpInterface>();
111-
registry
112-
.addOpInterface<arith::IndexCastOp, arith_ext::IndexCastOpInterface>();
106+
void mlir::arith::registerBufferizableOpInterfaceExternalModels(
107+
DialectRegistry &registry) {
108+
registry.addOpInterface<ConstantOp, ConstantOpInterface>();
109+
registry.addOpInterface<IndexCastOp, IndexCastOpInterface>();
113110
}

mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,61 +8,37 @@
88

99
#include "PassDetail.h"
1010

11+
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
1112
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
13+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1214
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1315
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
1416
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1517

1618
using namespace mlir;
19+
using namespace bufferization;
1720

1821
namespace {
19-
20-
/// Bufferize arith.index_cast.
21-
struct BufferizeIndexCastOp : public OpConversionPattern<arith::IndexCastOp> {
22-
using OpConversionPattern::OpConversionPattern;
23-
24-
LogicalResult
25-
matchAndRewrite(arith::IndexCastOp op, OpAdaptor adaptor,
26-
ConversionPatternRewriter &rewriter) const override {
27-
auto tensorType = op.getType().cast<RankedTensorType>();
28-
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
29-
op, adaptor.getIn(),
30-
MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
31-
return success();
32-
}
33-
};
34-
3522
/// Pass to bufferize Arithmetic ops.
3623
struct ArithmeticBufferizePass
3724
: public ArithmeticBufferizeBase<ArithmeticBufferizePass> {
3825
void runOnOperation() override {
39-
bufferization::BufferizeTypeConverter typeConverter;
40-
RewritePatternSet patterns(&getContext());
41-
ConversionTarget target(getContext());
42-
43-
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect>();
44-
45-
arith::populateArithmeticBufferizePatterns(typeConverter, patterns);
26+
std::unique_ptr<BufferizationOptions> options =
27+
getPartialBufferizationOptions();
28+
options->addToDialectFilter<arith::ArithmeticDialect>();
4629

47-
target.addDynamicallyLegalOp<arith::IndexCastOp>(
48-
[&](arith::IndexCastOp op) {
49-
return typeConverter.isLegal(op.getType());
50-
});
51-
52-
if (failed(applyPartialConversion(getOperation(), target,
53-
std::move(patterns))))
30+
if (failed(bufferizeOp(getOperation(), *options)))
5431
signalPassFailure();
5532
}
56-
};
5733

34+
void getDependentDialects(DialectRegistry &registry) const override {
35+
registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect,
36+
arith::ArithmeticDialect>();
37+
arith::registerBufferizableOpInterfaceExternalModels(registry);
38+
}
39+
};
5840
} // namespace
5941

60-
void mlir::arith::populateArithmeticBufferizePatterns(
61-
bufferization::BufferizeTypeConverter &typeConverter,
62-
RewritePatternSet &patterns) {
63-
patterns.add<BufferizeIndexCastOp>(typeConverter, patterns.getContext());
64-
}
65-
6642
std::unique_ptr<Pass> mlir::arith::createArithmeticBufferizePass() {
6743
return std::make_unique<ArithmeticBufferizePass>();
6844
}

mlir/lib/Dialect/Arithmetic/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRArithmeticTransforms
2+
BufferizableOpInterfaceImpl.cpp
23
Bufferize.cpp
34
ExpandOps.cpp
45

@@ -10,6 +11,7 @@ add_mlir_dialect_library(MLIRArithmeticTransforms
1011

1112
LINK_LIBS PUBLIC
1213
MLIRArithmetic
14+
MLIRBufferization
1315
MLIRBufferizationTransforms
1416
MLIRIR
1517
MLIRMemRef

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
set(LLVM_OPTIONAL_SOURCES
22
AffineInterfaceImpl.cpp
3-
ArithInterfaceImpl.cpp
43
LinalgInterfaceImpl.cpp
54
ModuleBufferization.cpp
65
SCFInterfaceImpl.cpp
@@ -16,17 +15,6 @@ add_mlir_dialect_library(MLIRAffineBufferizableOpInterfaceImpl
1615
MLIRBufferization
1716
)
1817

19-
add_mlir_dialect_library(MLIRArithBufferizableOpInterfaceImpl
20-
ArithInterfaceImpl.cpp
21-
22-
LINK_LIBS PUBLIC
23-
MLIRArithmetic
24-
MLIRBufferization
25-
MLIRIR
26-
MLIRMemRef
27-
MLIRStandardOpsTransforms
28-
)
29-
3018
add_mlir_dialect_library(MLIRLinalgBufferizableOpInterfaceImpl
3119
LinalgInterfaceImpl.cpp
3220

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ add_mlir_dialect_library(MLIRLinalgTransforms
3434
MLIRAffineBufferizableOpInterfaceImpl
3535
MLIRAffineUtils
3636
MLIRAnalysis
37-
MLIRArithBufferizableOpInterfaceImpl
3837
MLIRArithmetic
38+
MLIRArithmeticTransforms
3939
MLIRBufferization
4040
MLIRComplex
4141
MLIRInferTypeOpInterface

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88

99
#include "PassDetail.h"
1010

11+
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
1112
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1213
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1314
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
1415
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
15-
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
1616
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
1717
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
1818
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
@@ -52,7 +52,7 @@ struct LinalgComprehensiveModuleBufferize
5252
vector::VectorDialect, scf::SCFDialect,
5353
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
5454
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
55-
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
55+
arith::registerBufferizableOpInterfaceExternalModels(registry);
5656
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
5757
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
5858
std_ext::registerModuleBufferizationExternalModels(registry);

mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,3 @@ func @rank_reducing(
9696
}
9797
return %5: tensor<?x1x6x8xf32>
9898
}
99-
100-
// -----
101-
102-
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)>
103-
// CHECK-LABEL: func @index_cast(
104-
// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
105-
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
106-
%index_tensor = arith.index_cast %tensor : tensor<i32> to tensor<index>
107-
%index_scalar = arith.index_cast %scalar : i32 to index
108-
return %index_tensor, %index_scalar : tensor<index>, index
109-
}
110-
// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<i32, #[[$MAP]]>
111-
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]]
112-
// CHECK-SAME: memref<i32, #[[$MAP]]> to memref<index, #[[$MAP]]>
113-
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]]
114-
// CHECK: return %[[INDEX_TENSOR]]

mlir/test/lib/Dialect/Linalg/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ add_mlir_library(MLIRLinalgTestPasses
1414
LINK_LIBS PUBLIC
1515
MLIRAffine
1616
MLIRAffineBufferizableOpInterfaceImpl
17-
MLIRArithBufferizableOpInterfaceImpl
1817
MLIRArithmetic
18+
MLIRArithmeticTransforms
1919
MLIRBufferization
2020
MLIRBufferizationTransforms
2121
MLIRGPUTransforms

mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212

1313
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1414
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15+
#include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h"
1516
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1617
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1718
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
1819
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
19-
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
2020
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
2121
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
2222
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
@@ -59,7 +59,7 @@ struct TestComprehensiveFunctionBufferize
5959
vector::VectorDialect, scf::SCFDialect, StandardOpsDialect,
6060
arith::ArithmeticDialect, AffineDialect>();
6161
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
62-
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
62+
arith::registerBufferizableOpInterfaceExternalModels(registry);
6363
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
6464
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
6565
std_ext::registerBufferizableOpInterfaceExternalModels(registry);

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6580,27 +6580,6 @@ cc_library(
65806580
],
65816581
)
65826582

6583-
cc_library(
6584-
name = "ArithBufferizableOpInterfaceImpl",
6585-
srcs = [
6586-
"lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp",
6587-
],
6588-
hdrs = [
6589-
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h",
6590-
],
6591-
includes = ["include"],
6592-
deps = [
6593-
":ArithmeticDialect",
6594-
":BufferizationDialect",
6595-
":BufferizationTransforms",
6596-
":IR",
6597-
":MemRefDialect",
6598-
":Support",
6599-
":TransformUtils",
6600-
"//llvm:Support",
6601-
],
6602-
)
6603-
66046583
cc_library(
66056584
name = "LinalgBufferizableOpInterfaceImpl",
66066585
srcs = [
@@ -6876,8 +6855,8 @@ cc_library(
68766855
":AffineBufferizableOpInterfaceImpl",
68776856
":AffineUtils",
68786857
":Analysis",
6879-
":ArithBufferizableOpInterfaceImpl",
68806858
":ArithmeticDialect",
6859+
":ArithmeticTransforms",
68816860
":BufferizationDialect",
68826861
":BufferizationTransforms",
68836862
":ComplexDialect",
@@ -7566,7 +7545,10 @@ cc_library(
75667545
"lib/Dialect/Arithmetic/Transforms/*.cpp",
75677546
"lib/Dialect/Arithmetic/Transforms/*.h",
75687547
]),
7569-
hdrs = ["include/mlir/Dialect/Arithmetic/Transforms/Passes.h"],
7548+
hdrs = [
7549+
"include/mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h",
7550+
"include/mlir/Dialect/Arithmetic/Transforms/Passes.h",
7551+
],
75707552
includes = ["include"],
75717553
deps = [
75727554
":ArithmeticDialect",
@@ -7577,7 +7559,10 @@ cc_library(
75777559
":MemRefDialect",
75787560
":Pass",
75797561
":StandardOps",
7562+
":Support",
7563+
":TransformUtils",
75807564
":Transforms",
7565+
"//llvm:Support",
75817566
],
75827567
)
75837568

utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,8 @@ cc_library(
389389
"//llvm:Support",
390390
"//mlir:Affine",
391391
"//mlir:AffineBufferizableOpInterfaceImpl",
392-
"//mlir:ArithBufferizableOpInterfaceImpl",
393392
"//mlir:ArithmeticDialect",
393+
"//mlir:ArithmeticTransforms",
394394
"//mlir:BufferizationDialect",
395395
"//mlir:BufferizationTransforms",
396396
"//mlir:GPUDialect",

0 commit comments

Comments
 (0)