Skip to content

Commit 0ca9ad3

Browse files
angelz913keith
authored andcommitted
Reland "[mlir][spirv] Add a generic convert-to-spirv pass" (llvm#96359)
This PR relands llvm#95942, which was reverted in llvm#96332 due to link failures. It fixes the issue by updating CMake dependencies. The bazel support, originally introduced in llvm#96334, is also included in this PR. --------- Co-authored-by: Keith Smiley <[email protected]>
1 parent 5862570 commit 0ca9ad3

File tree

14 files changed

+1005
-0
lines changed

14 files changed

+1005
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- ConvertToSPIRVPass.h - Conversion to SPIR-V pass ---*- C++ -*-=========//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
10+
#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_CONVERTTOSPIRVPASS
18+
#include "mlir/Conversion/Passes.h.inc"
19+
20+
} // namespace mlir
21+
22+
#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
3131
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
3232
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
33+
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
3334
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
3435
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
3536
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
3131
];
3232
}
3333

34+
//===----------------------------------------------------------------------===//
35+
// ToSPIRV
36+
//===----------------------------------------------------------------------===//
37+
38+
def ConvertToSPIRVPass : Pass<"convert-to-spirv"> {
39+
let summary = "Convert to SPIR-V";
40+
let description = [{
41+
This is a generic pass to convert to SPIR-V.
42+
}];
43+
let dependentDialects = ["spirv::SPIRVDialect"];
44+
}
45+
3446
//===----------------------------------------------------------------------===//
3547
// AffineToStandard
3648
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_subdirectory(ControlFlowToLLVM)
1919
add_subdirectory(ControlFlowToSCF)
2020
add_subdirectory(ControlFlowToSPIRV)
2121
add_subdirectory(ConvertToLLVM)
22+
add_subdirectory(ConvertToSPIRV)
2223
add_subdirectory(FuncToEmitC)
2324
add_subdirectory(FuncToLLVM)
2425
add_subdirectory(FuncToSPIRV)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
set(LLVM_OPTIONAL_SOURCES
2+
ConvertToSPIRVPass.cpp
3+
)
4+
5+
add_mlir_conversion_library(MLIRConvertToSPIRVPass
6+
ConvertToSPIRVPass.cpp
7+
8+
ADDITIONAL_HEADER_DIRS
9+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV
10+
11+
DEPENDS
12+
MLIRConversionPassIncGen
13+
14+
LINK_LIBS PUBLIC
15+
MLIRArithToSPIRV
16+
MLIRArithTransforms
17+
MLIRFuncToSPIRV
18+
MLIRIndexToSPIRV
19+
MLIRIR
20+
MLIRPass
21+
MLIRRewrite
22+
MLIRSCFToSPIRV
23+
MLIRSPIRVConversion
24+
MLIRSPIRVDialect
25+
MLIRSPIRVTransforms
26+
MLIRSupport
27+
MLIRTransforms
28+
MLIRTransformUtils
29+
MLIRUBToSPIRV
30+
MLIRVectorToSPIRV
31+
MLIRVectorTransforms
32+
)
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
10+
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
11+
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
12+
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
13+
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
14+
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
15+
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
16+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
17+
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
18+
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19+
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
21+
#include "mlir/IR/PatternMatch.h"
22+
#include "mlir/Pass/Pass.h"
23+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
24+
#include "mlir/Transforms/DialectConversion.h"
25+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26+
#include <memory>
27+
28+
#define DEBUG_TYPE "convert-to-spirv"
29+
30+
namespace mlir {
31+
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
32+
#include "mlir/Conversion/Passes.h.inc"
33+
} // namespace mlir
34+
35+
using namespace mlir;
36+
37+
namespace {
38+
39+
/// A pass to perform the SPIR-V conversion.
40+
struct ConvertToSPIRVPass final
41+
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
42+
43+
void runOnOperation() override {
44+
MLIRContext *context = &getContext();
45+
Operation *op = getOperation();
46+
47+
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
48+
SPIRVTypeConverter typeConverter(targetAttr);
49+
50+
RewritePatternSet patterns(context);
51+
ScfToSPIRVContext scfToSPIRVContext;
52+
53+
// Populate patterns.
54+
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
55+
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
56+
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
57+
populateFuncToSPIRVPatterns(typeConverter, patterns);
58+
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
59+
populateVectorToSPIRVPatterns(typeConverter, patterns);
60+
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
61+
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
62+
63+
std::unique_ptr<ConversionTarget> target =
64+
SPIRVConversionTarget::get(targetAttr);
65+
66+
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
67+
return signalPassFailure();
68+
}
69+
};
70+
71+
} // namespace
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s
2+
3+
//===----------------------------------------------------------------------===//
4+
// arithmetic ops
5+
//===----------------------------------------------------------------------===//
6+
7+
// CHECK-LABEL: @int32_scalar
8+
func.func @int32_scalar(%lhs: i32, %rhs: i32) {
9+
// CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32
10+
%0 = arith.addi %lhs, %rhs: i32
11+
// CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32
12+
%1 = arith.subi %lhs, %rhs: i32
13+
// CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32
14+
%2 = arith.muli %lhs, %rhs: i32
15+
// CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32
16+
%3 = arith.divsi %lhs, %rhs: i32
17+
// CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32
18+
%4 = arith.divui %lhs, %rhs: i32
19+
// CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32
20+
%5 = arith.remui %lhs, %rhs: i32
21+
return
22+
}
23+
24+
// CHECK-LABEL: @int32_scalar_srem
25+
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
26+
func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) {
27+
// CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32
28+
// CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32
29+
// CHECK: %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32
30+
// CHECK: %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32
31+
// CHECK: %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32
32+
// CHECK: %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32
33+
%0 = arith.remsi %lhs, %rhs: i32
34+
return
35+
}
36+
37+
// -----
38+
39+
//===----------------------------------------------------------------------===//
40+
// arith bit ops
41+
//===----------------------------------------------------------------------===//
42+
43+
// CHECK-LABEL: @bitwise_scalar
44+
func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) {
45+
// CHECK: spirv.BitwiseAnd
46+
%0 = arith.andi %arg0, %arg1 : i32
47+
// CHECK: spirv.BitwiseOr
48+
%1 = arith.ori %arg0, %arg1 : i32
49+
// CHECK: spirv.BitwiseXor
50+
%2 = arith.xori %arg0, %arg1 : i32
51+
return
52+
}
53+
54+
// CHECK-LABEL: @bitwise_vector
55+
func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
56+
// CHECK: spirv.BitwiseAnd
57+
%0 = arith.andi %arg0, %arg1 : vector<4xi32>
58+
// CHECK: spirv.BitwiseOr
59+
%1 = arith.ori %arg0, %arg1 : vector<4xi32>
60+
// CHECK: spirv.BitwiseXor
61+
%2 = arith.xori %arg0, %arg1 : vector<4xi32>
62+
return
63+
}
64+
65+
// CHECK-LABEL: @logical_scalar
66+
func.func @logical_scalar(%arg0 : i1, %arg1 : i1) {
67+
// CHECK: spirv.LogicalAnd
68+
%0 = arith.andi %arg0, %arg1 : i1
69+
// CHECK: spirv.LogicalOr
70+
%1 = arith.ori %arg0, %arg1 : i1
71+
// CHECK: spirv.LogicalNotEqual
72+
%2 = arith.xori %arg0, %arg1 : i1
73+
return
74+
}
75+
76+
// CHECK-LABEL: @logical_vector
77+
func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
78+
// CHECK: spirv.LogicalAnd
79+
%0 = arith.andi %arg0, %arg1 : vector<4xi1>
80+
// CHECK: spirv.LogicalOr
81+
%1 = arith.ori %arg0, %arg1 : vector<4xi1>
82+
// CHECK: spirv.LogicalNotEqual
83+
%2 = arith.xori %arg0, %arg1 : vector<4xi1>
84+
return
85+
}
86+
87+
// CHECK-LABEL: @shift_scalar
88+
func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
89+
// CHECK: spirv.ShiftLeftLogical
90+
%0 = arith.shli %arg0, %arg1 : i32
91+
// CHECK: spirv.ShiftRightArithmetic
92+
%1 = arith.shrsi %arg0, %arg1 : i32
93+
// CHECK: spirv.ShiftRightLogical
94+
%2 = arith.shrui %arg0, %arg1 : i32
95+
return
96+
}
97+
98+
// CHECK-LABEL: @shift_vector
99+
func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
100+
// CHECK: spirv.ShiftLeftLogical
101+
%0 = arith.shli %arg0, %arg1 : vector<4xi32>
102+
// CHECK: spirv.ShiftRightArithmetic
103+
%1 = arith.shrsi %arg0, %arg1 : vector<4xi32>
104+
// CHECK: spirv.ShiftRightLogical
105+
%2 = arith.shrui %arg0, %arg1 : vector<4xi32>
106+
return
107+
}
108+
109+
// -----
110+
111+
//===----------------------------------------------------------------------===//
112+
// arith.cmpf
113+
//===----------------------------------------------------------------------===//
114+
115+
// CHECK-LABEL: @cmpf
116+
func.func @cmpf(%arg0 : f32, %arg1 : f32) {
117+
// CHECK: spirv.FOrdEqual
118+
%1 = arith.cmpf oeq, %arg0, %arg1 : f32
119+
return
120+
}
121+
122+
// CHECK-LABEL: @vec1cmpf
123+
func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) {
124+
// CHECK: spirv.FOrdGreaterThan
125+
%0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32>
126+
// CHECK: spirv.FUnordLessThan
127+
%1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32>
128+
return
129+
}
130+
131+
// -----
132+
133+
//===----------------------------------------------------------------------===//
134+
// arith.cmpi
135+
//===----------------------------------------------------------------------===//
136+
137+
// CHECK-LABEL: @cmpi
138+
func.func @cmpi(%arg0 : i32, %arg1 : i32) {
139+
// CHECK: spirv.IEqual
140+
%0 = arith.cmpi eq, %arg0, %arg1 : i32
141+
return
142+
}
143+
144+
// CHECK-LABEL: @indexcmpi
145+
func.func @indexcmpi(%arg0 : index, %arg1 : index) {
146+
// CHECK: spirv.IEqual
147+
%0 = arith.cmpi eq, %arg0, %arg1 : index
148+
return
149+
}
150+
151+
// CHECK-LABEL: @vec1cmpi
152+
func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) {
153+
// CHECK: spirv.ULessThan
154+
%0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32>
155+
// CHECK: spirv.SGreaterThan
156+
%1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32>
157+
return
158+
}
159+
160+
// CHECK-LABEL: @boolcmpi_equality
161+
func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
162+
// CHECK: spirv.LogicalEqual
163+
%0 = arith.cmpi eq, %arg0, %arg1 : i1
164+
// CHECK: spirv.LogicalNotEqual
165+
%1 = arith.cmpi ne, %arg0, %arg1 : i1
166+
return
167+
}
168+
169+
// CHECK-LABEL: @boolcmpi_unsigned
170+
func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
171+
// CHECK-COUNT-2: spirv.Select
172+
// CHECK: spirv.UGreaterThanEqual
173+
%0 = arith.cmpi uge, %arg0, %arg1 : i1
174+
// CHECK-COUNT-2: spirv.Select
175+
// CHECK: spirv.ULessThan
176+
%1 = arith.cmpi ult, %arg0, %arg1 : i1
177+
return
178+
}
179+
180+
// CHECK-LABEL: @vec1boolcmpi_equality
181+
func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
182+
// CHECK: spirv.LogicalEqual
183+
%0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
184+
// CHECK: spirv.LogicalNotEqual
185+
%1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1>
186+
return
187+
}
188+
189+
// CHECK-LABEL: @vec1boolcmpi_unsigned
190+
func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
191+
// CHECK-COUNT-2: spirv.Select
192+
// CHECK: spirv.UGreaterThanEqual
193+
%0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
194+
// CHECK-COUNT-2: spirv.Select
195+
// CHECK: spirv.ULessThan
196+
%1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
197+
return
198+
}
199+
200+
// CHECK-LABEL: @vecboolcmpi_equality
201+
func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
202+
// CHECK: spirv.LogicalEqual
203+
%0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
204+
// CHECK: spirv.LogicalNotEqual
205+
%1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1>
206+
return
207+
}
208+
209+
// CHECK-LABEL: @vecboolcmpi_unsigned
210+
func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
211+
// CHECK-COUNT-2: spirv.Select
212+
// CHECK: spirv.UGreaterThanEqual
213+
%0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
214+
// CHECK-COUNT-2: spirv.Select
215+
// CHECK: spirv.ULessThan
216+
%1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
217+
return
218+
}

0 commit comments

Comments
 (0)