Skip to content

Commit a813e9b

Browse files
committed
[MLIR][TOSA] Added Tosa to Standard/SCF Lowerings (const, if, while)
Includes a lowering for tosa.const, tosa.if, and tosa.while to Standard/SCF dialects. TosaToStandard is used for constant lowerings and TosaToSCF handles the if/while ops. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D97352
1 parent 95d0d8e commit a813e9b

File tree

14 files changed

+472
-0
lines changed

14 files changed

+472
-0
lines changed

mlir/include/mlir/Conversion/Passes.h

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
3232
#include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
3333
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
34+
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
35+
#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
3436
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
3537
#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"
3638
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

mlir/include/mlir/Conversion/Passes.td

+30
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,36 @@ def TosaToLinalgOnTensors : FunctionPass<"tosa-to-linalg-on-tensors"> {
440440
let constructor = "tosa::createTosaToLinalgOnTensors()";
441441
}
442442

443+
//===----------------------------------------------------------------------===//
444+
// TosaToSCF
445+
//===----------------------------------------------------------------------===//
446+
447+
def TosaToSCF : Pass<"tosa-to-scf"> {
448+
let summary = "Lower TOSA to the SCF dialect";
449+
let dependentDialects = ["tensor::TensorDialect, scf::SCFDialect"];
450+
let description = [{
451+
Pass that converts TOSA's control flow operations to the equivalent SCF
452+
operations.
453+
}];
454+
455+
let constructor = "tosa::createTosaToSCF()";
456+
}
457+
458+
//===----------------------------------------------------------------------===//
459+
// TosaToStandard
460+
//===----------------------------------------------------------------------===//
461+
462+
def TosaToStandard : Pass<"tosa-to-standard"> {
463+
let summary = "Lower TOSA to the Standard dialect";
464+
let dependentDialects = ["StandardOpsDialect"];
465+
let description = [{
466+
Pass that converts TOSA operations to the equivalent operations using the
467+
operations in the Standard dialect.
468+
}];
469+
470+
let constructor = "tosa::createTosaToStandard()";
471+
}
472+
443473
//===----------------------------------------------------------------------===//
444474
// VectorToSCF
445475
//===----------------------------------------------------------------------===//
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===-- TosaToSCF.h - TOSA to SCF dialect lowerings -------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the passes for the TOSA to SCF Dialect conversion.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
14+
#define MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
15+
16+
#include "mlir/Pass/Pass.h"
17+
18+
namespace mlir {
19+
namespace tosa {
20+
21+
std::unique_ptr<Pass> createTosaToSCF();
22+
23+
void populateTosaToSCFConversionPatterns(MLIRContext *context,
24+
OwningRewritePatternList *patterns);
25+
26+
/// Populates passes to convert from TOSA to SCF.
27+
void addTosaToSCFPasses(OpPassManager &pm);
28+
29+
} // namespace tosa
30+
} // namespace mlir
31+
32+
#endif // MLIR_CONVERSION_TOSATOSCF_TOSATOSCF_H
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===-- TosaToStandard.h - TOSA optimization pass declarations --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the passes for the TOSA to Standard Dialect conversion.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
14+
#define MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H
15+
16+
#include "mlir/Pass/Pass.h"
17+
18+
namespace mlir {
19+
namespace tosa {
20+
21+
std::unique_ptr<Pass> createTosaToStandard();
22+
23+
void populateTosaToStandardConversionPatterns(
24+
MLIRContext *context, OwningRewritePatternList *patterns);
25+
26+
/// Populates passes to convert from TOSA to Standard.
27+
void addTosaToStandardPasses(OpPassManager &pm);
28+
29+
} // namespace tosa
30+
} // namespace mlir
31+
32+
#endif // MLIR_CONVERSION_TOSATOSTANDARD_TOSATOSTANDARD_H

mlir/lib/Conversion/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ add_subdirectory(SPIRVToLLVM)
2222
add_subdirectory(StandardToLLVM)
2323
add_subdirectory(StandardToSPIRV)
2424
add_subdirectory(TosaToLinalg)
25+
add_subdirectory(TosaToSCF)
26+
add_subdirectory(TosaToStandard)
2527
add_subdirectory(ArmSVEToLLVM)
2628
add_subdirectory(VectorToROCDL)
2729
add_subdirectory(VectorToLLVM)

mlir/lib/Conversion/PassDetail.h

+8
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,14 @@ namespace spirv {
5959
class SPIRVDialect;
6060
} // end namespace spirv
6161

62+
namespace tensor {
63+
class TensorDialect;
64+
} // end namespace tensor
65+
66+
namespace tosa {
67+
class TosaDialect;
68+
} // end namespace tosa
69+
6270
namespace vector {
6371
class VectorDialect;
6472
} // end namespace vector
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
add_mlir_conversion_library(MLIRTosaToSCF
2+
TosaToSCF.cpp
3+
TosaToSCFPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
7+
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
8+
9+
DEPENDS
10+
MLIRConversionPassIncGen
11+
12+
LINK_LIBS PUBLIC
13+
MLIRIR
14+
MLIRSCF
15+
MLIRStandard
16+
MLIRPass
17+
MLIRTensor
18+
MLIRTosa
19+
MLIRTosaTransforms
20+
MLIRSupport
21+
)
+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//===- TosaToSCF.cpp - Lowering Tosa to SCF Dialect -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// These rewriters lower from the Tosa to the SCF dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
14+
#include "mlir/Dialect/SCF/SCF.h"
15+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
16+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
17+
#include "mlir/IR/BlockAndValueMapping.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20+
21+
using namespace mlir;
22+
using namespace tosa;
23+
24+
static void inlineIfCase(Region &srcRegion, Region &dstRegion,
25+
OperandRange operands, PatternRewriter &rewriter) {
26+
BlockAndValueMapping mapper;
27+
dstRegion.takeBody(srcRegion);
28+
Block *headBlock = &dstRegion.front();
29+
for (auto it : llvm::zip(headBlock->getArguments(), operands))
30+
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
31+
32+
for (auto &block : dstRegion) {
33+
llvm::SmallVector<Operation *> toDelete;
34+
block.walk([&](tosa::YieldOp yield) {
35+
rewriter.setInsertionPoint(yield);
36+
rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
37+
toDelete.push_back(yield);
38+
});
39+
for (Operation *val : toDelete)
40+
rewriter.eraseOp(val);
41+
}
42+
43+
headBlock->eraseArguments(
44+
llvm::to_vector<4>(llvm::seq<unsigned>(0, headBlock->getNumArguments())));
45+
}
46+
47+
static void inlineWhileCase(Region &srcRegion, Region &dstRegion,
48+
OperandRange operands, PatternRewriter &rewriter,
49+
bool isCond) {
50+
BlockAndValueMapping mapper;
51+
dstRegion.takeBody(srcRegion);
52+
53+
for (auto &block : dstRegion) {
54+
llvm::SmallVector<Operation *> toDelete;
55+
block.walk([&](tosa::YieldOp yield) {
56+
rewriter.setInsertionPoint(yield);
57+
if (isCond) {
58+
auto condition = rewriter.create<tensor::ExtractOp>(
59+
yield.getLoc(), yield.getOperand(0));
60+
rewriter.create<scf::ConditionOp>(yield.getLoc(), condition,
61+
block.getArguments());
62+
} else {
63+
rewriter.setInsertionPoint(yield);
64+
rewriter.create<scf::YieldOp>(yield.getLoc(), yield.inputs());
65+
}
66+
toDelete.push_back(yield);
67+
});
68+
for (Operation *val : toDelete)
69+
rewriter.eraseOp(val);
70+
}
71+
}
72+
73+
namespace {
74+
75+
class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
76+
public:
77+
using OpRewritePattern<tosa::IfOp>::OpRewritePattern;
78+
79+
LogicalResult matchAndRewrite(tosa::IfOp op,
80+
PatternRewriter &rewriter) const final {
81+
auto condition = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.cond());
82+
auto newIf = rewriter.replaceOpWithNewOp<scf::IfOp>(op, op.getResultTypes(),
83+
condition, true);
84+
85+
inlineIfCase(op.then_branch(), newIf.thenRegion(), op.inputs(), rewriter);
86+
inlineIfCase(op.else_branch(), newIf.elseRegion(), op.inputs(), rewriter);
87+
return success();
88+
}
89+
};
90+
91+
class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
92+
public:
93+
using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
94+
95+
LogicalResult matchAndRewrite(tosa::WhileOp op,
96+
PatternRewriter &rewriter) const final {
97+
auto newWhile = rewriter.replaceOpWithNewOp<scf::WhileOp>(
98+
op, op.getResultTypes(), op.inputs());
99+
100+
inlineWhileCase(op.cond(), newWhile.before(), op.inputs(), rewriter, true);
101+
inlineWhileCase(op.body(), newWhile.after(), op.inputs(), rewriter, false);
102+
103+
return success();
104+
}
105+
};
106+
107+
} // namespace
108+
109+
void mlir::tosa::populateTosaToSCFConversionPatterns(
110+
MLIRContext *context, OwningRewritePatternList *patterns) {
111+
patterns->insert<IfOpConverter>(context);
112+
patterns->insert<WhileOpConverter>(context);
113+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===- TosaToSCFPass.cpp - Lowering Tosa to SCF Dialect -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This transformation pass legalizes Tosa operations to the SCF dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "../PassDetail.h"
14+
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
15+
#include "mlir/Dialect/SCF/SCF.h"
16+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
17+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
18+
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
19+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
20+
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/Pass/PassManager.h"
22+
#include "mlir/Transforms/DialectConversion.h"
23+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24+
25+
using namespace mlir;
26+
using namespace tosa;
27+
28+
namespace {
29+
struct TosaToSCF : public TosaToSCFBase<TosaToSCF> {
30+
public:
31+
void runOnOperation() override {
32+
OwningRewritePatternList patterns;
33+
TypeConverter typeConverter;
34+
ConversionTarget target(getContext());
35+
target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
36+
target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
37+
38+
auto op = getOperation();
39+
mlir::tosa::populateTosaToSCFConversionPatterns(op->getContext(),
40+
&patterns);
41+
if (failed(applyPartialConversion(op, target, std::move(patterns))))
42+
signalPassFailure();
43+
}
44+
};
45+
} // namespace
46+
47+
std::unique_ptr<Pass> mlir::tosa::createTosaToSCF() {
48+
return std::make_unique<TosaToSCF>();
49+
}
50+
51+
void mlir::tosa::addTosaToSCFPasses(OpPassManager &pm) {
52+
pm.addNestedPass<FuncOp>(createTosaToSCF());
53+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRTosaToStandard
2+
TosaToStandard.cpp
3+
TosaToStandardPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
7+
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
8+
9+
DEPENDS
10+
MLIRConversionPassIncGen
11+
12+
LINK_LIBS PUBLIC
13+
MLIRIR
14+
MLIRStandard
15+
MLIRPass
16+
MLIRTosa
17+
MLIRTosaTransforms
18+
MLIRSupport
19+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===- TosaToStandard.cpp - Lowering Tosa to Standard Dialect -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// These rewriters lower from the Tosa to the Standard dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/TosaToStandard/TosaToStandard.h"
14+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
15+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
19+
using namespace mlir;
20+
using namespace tosa;
21+
22+
namespace {
23+
24+
class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
25+
public:
26+
using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
27+
28+
LogicalResult matchAndRewrite(tosa::ConstOp op,
29+
PatternRewriter &rewriter) const final {
30+
rewriter.replaceOpWithNewOp<::ConstantOp>(op, op.value());
31+
return success();
32+
}
33+
};
34+
35+
} // namespace
36+
37+
void mlir::tosa::populateTosaToStandardConversionPatterns(
38+
MLIRContext *context, OwningRewritePatternList *patterns) {
39+
patterns->insert<ConstOpConverter>(context);
40+
}

0 commit comments

Comments
 (0)