Skip to content

Commit cc9eb3a

Browse files
committed
[mlir][spirv] Add a generic "convert-to-spirv" pass
This commit implements a MVP version of an MLIR lowering pipeline to SPIR-V. The goal is to have a better test coverage of SPIR-V compilation upstream, and enable writing simple kernels by hand. The dialects supported in this version include arith, vector (only 1-D vectors with size 2,3,4,8 or 16), scf, ub, index, func and math.
1 parent 79e668f commit cc9eb3a

File tree

13 files changed

+1067
-0
lines changed

13 files changed

+1067
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- ConvertToSPIRVPass.h - Conversion to SPIR-V pass ---*- C++ -*-=========//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
10+
#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
18+
#include "mlir/Conversion/Passes.h.inc"
19+
20+
} // namespace mlir
21+
22+
#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
3131
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
3232
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
33+
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
3334
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
3435
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
3536
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
3131
];
3232
}
3333

34+
//===----------------------------------------------------------------------===//
35+
// ToSPIRV
36+
//===----------------------------------------------------------------------===//
37+
38+
def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
39+
let summary = "Convert to SPIR-V";
40+
let description = [{
41+
This is a generic pass to convert to SPIR-V.
42+
}];
43+
let dependentDialects = ["spirv::SPIRVDialect"];
44+
}
45+
3446
//===----------------------------------------------------------------------===//
3547
// AffineToStandard
3648
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_subdirectory(ControlFlowToLLVM)
1919
add_subdirectory(ControlFlowToSCF)
2020
add_subdirectory(ControlFlowToSPIRV)
2121
add_subdirectory(ConvertToLLVM)
22+
add_subdirectory(ConvertToSPIRV)
2223
add_subdirectory(FuncToEmitC)
2324
add_subdirectory(FuncToLLVM)
2425
add_subdirectory(FuncToSPIRV)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
set(LLVM_OPTIONAL_SOURCES
2+
ConvertToSPIRVPass.cpp
3+
)
4+
5+
add_mlir_conversion_library(MLIRConvertToSPIRVPass
6+
ConvertToSPIRVPass.cpp
7+
8+
ADDITIONAL_HEADER_DIRS
9+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV
10+
11+
DEPENDS
12+
MLIRConversionPassIncGen
13+
14+
LINK_LIBS PUBLIC
15+
MLIRIR
16+
MLIRPass
17+
MLIRRewrite
18+
MLIRSPIRVConversion
19+
MLIRSPIRVDialect
20+
MLIRSupport
21+
MLIRTransformUtils
22+
)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
10+
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
11+
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
12+
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
13+
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14+
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
15+
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
16+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
17+
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19+
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
21+
#include "mlir/IR/PatternMatch.h"
22+
#include "mlir/Pass/Pass.h"
23+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
24+
#include "mlir/Transforms/DialectConversion.h"
25+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26+
#include <memory>
27+
28+
#define DEBUG_TYPE "convert-to-spirv"
29+
30+
namespace mlir {
31+
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
32+
#include "mlir/Conversion/Passes.h.inc"
33+
} // namespace mlir
34+
35+
using namespace mlir;
36+
37+
namespace {
38+
39+
/// A pass to perform the SPIR-V conversion.
40+
struct ConvertToSPIRVPass final
41+
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
42+
43+
public:
44+
void runOnOperation() final {
45+
MLIRContext *context = &getContext();
46+
Operation *op = getOperation();
47+
48+
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
49+
SPIRVTypeConverter typeConverter(targetAttr);
50+
51+
RewritePatternSet patterns(context);
52+
ScfToSPIRVContext scfToSPIRVContext;
53+
54+
// Populate patterns.
55+
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
56+
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
57+
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
58+
populateFuncToSPIRVPatterns(typeConverter, patterns);
59+
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
60+
populateVectorToSPIRVPatterns(typeConverter, patterns);
61+
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
62+
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
63+
64+
std::unique_ptr<ConversionTarget> target =
65+
SPIRVConversionTarget::get(targetAttr);
66+
67+
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
68+
return signalPassFailure();
69+
}
70+
};
71+
72+
} // namespace

0 commit comments

Comments
 (0)