Skip to content

Commit 5aee156

Browse files
inbelicj2kun
andauthored
Reland: "[mlir][index][spirv] Add conversion for index to spirv" (#69790)
Due to an issue when lowering from scf to spirv as there was no conversion pass for index to spirv, we are motivated to add a conversion pass from the Index dialect to the SPIR-V dialect. Furthermore, we add the new conversion patterns to the scf-to-spirv conversion. Fixes #63713 --------- Co-authored-by: Jeremy Kun <[email protected]>
1 parent 5e458f5 commit 5aee156

File tree

12 files changed

+776
-3
lines changed

12 files changed

+776
-3
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- IndexToSPIRV.h - Index to SPIRV dialect conversion -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
10+
#define MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
11+
12+
#include "mlir/Pass/Pass.h"
13+
#include <memory>
14+
15+
namespace mlir {
16+
class RewritePatternSet;
17+
class SPIRVTypeConverter;
18+
class Pass;
19+
20+
#define GEN_PASS_DECL_CONVERTINDEXTOSPIRVPASS
21+
#include "mlir/Conversion/Passes.h.inc"
22+
23+
namespace index {
24+
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
25+
RewritePatternSet &patterns);
26+
std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
27+
} // namespace index
28+
} // namespace mlir
29+
30+
#endif // MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
3636
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
3737
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
38+
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
3839
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
3940
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4041
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,28 @@ def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> {
644644
];
645645
}
646646

647+
//===----------------------------------------------------------------------===//
648+
// ConvertIndexToSPIRVPass
649+
//===----------------------------------------------------------------------===//
650+
651+
def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
652+
let summary = "Lower the `index` dialect to the `spirv` dialect.";
653+
let description = [{
654+
This pass lowers Index dialect operations to SPIR-V dialect operations.
655+
Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
656+
`ceildivu`, and `floordivs`. The index bitwidth will be 32 or 64 as
657+
specified by use-64bit-index.
658+
}];
659+
660+
let dependentDialects = ["::mlir::spirv::SPIRVDialect"];
661+
662+
let options = [
663+
Option<"use64bitIndex", "use-64bit-index",
664+
"bool", /*default=*/"false",
665+
"Use 64-bit integers to convert index types">
666+
];
667+
}
668+
647669
//===----------------------------------------------------------------------===//
648670
// LinalgToStandard
649671
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ struct SPIRVConversionOptions {
5555
/// values will be packed into one 32-bit value to be memory efficient.
5656
bool emulateLT32BitScalarTypes{true};
5757

58-
/// Use 64-bit integers to convert index types.
59-
bool use64bitIndex{false};
60-
6158
/// Whether to enable fast math mode during conversion. If true, various
6259
/// patterns would assume no NaN/infinity numbers as inputs, and thus there
6360
/// will be no special guards emitted to check and handle such cases.
6461
bool enableFastMathMode{false};
62+
63+
/// Use 64-bit integers when converting index types.
64+
bool use64bitIndex{false};
6565
};
6666

6767
/// Type conversion from builtin types to SPIR-V types for shader interface.
@@ -77,6 +77,11 @@ class SPIRVTypeConverter : public TypeConverter {
7777
/// Gets the SPIR-V correspondence for the standard index type.
7878
Type getIndexType() const;
7979

80+
/// Gets the bitwidth of the index type when converted to SPIR-V.
81+
unsigned getIndexTypeBitwidth() const {
82+
return options.use64bitIndex ? 64 : 32;
83+
}
84+
8085
const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
8186

8287
/// Returns the options controlling the SPIR-V type converter.

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ add_subdirectory(GPUToROCDL)
2424
add_subdirectory(GPUToSPIRV)
2525
add_subdirectory(GPUToVulkan)
2626
add_subdirectory(IndexToLLVM)
27+
add_subdirectory(IndexToSPIRV)
2728
add_subdirectory(LinalgToStandard)
2829
add_subdirectory(LLVMCommon)
2930
add_subdirectory(MathToFuncs)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(MLIRIndexToSPIRV
2+
IndexToSPIRV.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToSPIRV
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRIndexDialect
15+
MLIRSPIRVConversion
16+
MLIRSPIRVDialect
17+
)

0 commit comments

Comments
 (0)