Skip to content

Commit 25925c1

Browse files
committed
[mlir][index][spirv] Add conversion for index to spirv
Due to an issue when lowering from scf to spirv as there was no conversion pass for index to spirv, we are motivated to add a conversion pass from the Index dialect to the SPIR-V dialect. Furthermore, we add the new conversion patterns to the scf-to-spirv conversion. Fixes #63713
1 parent b15b846 commit 25925c1

File tree

10 files changed

+747
-3
lines changed

10 files changed

+747
-3
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- IndexToSPIRV.h - Index to SPIRV dialect conversion -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
10+
#define MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
11+
12+
#include "mlir/Pass/Pass.h"
13+
#include <memory>
14+
15+
namespace mlir {
16+
class RewritePatternSet;
17+
class SPIRVTypeConverter;
18+
class Pass;
19+
20+
#define GEN_PASS_DECL_CONVERTINDEXTOSPIRVPASS
21+
#include "mlir/Conversion/Passes.h.inc"
22+
23+
namespace index {
24+
void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
25+
RewritePatternSet &patterns);
26+
std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
27+
} // namespace index
28+
} // namespace mlir
29+
30+
#endif // MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
3636
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
3737
#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
38+
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
3839
#include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
3940
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4041
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,28 @@ def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> {
644644
];
645645
}
646646

647+
//===----------------------------------------------------------------------===//
648+
// ConvertIndexToSPIRVPass
649+
//===----------------------------------------------------------------------===//
650+
651+
def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
652+
let summary = "Lower the `index` dialect to the `spirv` dialect.";
653+
let description = [{
654+
This pass lowers Index dialect operations to SPIR-V dialect operations.
655+
Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
656+
`ceildivu`, and `floordivs`. The index bitwidth will be 32 or 64 as
657+
specified by use-64bit-index.
658+
}];
659+
660+
let dependentDialects = ["::mlir::spirv::SPIRVDialect"];
661+
662+
let options = [
663+
Option<"use64bitIndex", "use-64bit-index",
664+
"bool", /*default=*/"false",
665+
"Use 64-bit integers to convert index types">
666+
];
667+
}
668+
647669
//===----------------------------------------------------------------------===//
648670
// LinalgToStandard
649671
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ struct SPIRVConversionOptions {
5555
/// values will be packed into one 32-bit value to be memory efficient.
5656
bool emulateLT32BitScalarTypes{true};
5757

58-
/// Use 64-bit integers to convert index types.
59-
bool use64bitIndex{false};
60-
6158
/// Whether to enable fast math mode during conversion. If true, various
6259
/// patterns would assume no NaN/infinity numbers as inputs, and thus there
6360
/// will be no special guards emitted to check and handle such cases.
6461
bool enableFastMathMode{false};
62+
63+
/// Use 64-bit integers when converting index types.
64+
bool use64bitIndex{false};
6565
};
6666

6767
/// Type conversion from builtin types to SPIR-V types for shader interface.
@@ -77,6 +77,11 @@ class SPIRVTypeConverter : public TypeConverter {
7777
/// Gets the SPIR-V correspondence for the standard index type.
7878
Type getIndexType() const;
7979

80+
/// Gets the bitwidth of the index type when converted to SPIR-V.
81+
unsigned getIndexTypeBitwidth() const {
82+
return options.use64bitIndex ? 64 : 32;
83+
}
84+
8085
const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
8186

8287
/// Returns the options controlling the SPIR-V type converter.

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ add_subdirectory(GPUToROCDL)
2424
add_subdirectory(GPUToSPIRV)
2525
add_subdirectory(GPUToVulkan)
2626
add_subdirectory(IndexToLLVM)
27+
add_subdirectory(IndexToSPIRV)
2728
add_subdirectory(LinalgToStandard)
2829
add_subdirectory(LLVMCommon)
2930
add_subdirectory(MathToFuncs)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
add_mlir_conversion_library(MLIRIndexToSPIRV
2+
IndexToSPIRV.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToSPIRV
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRIndexDialect
15+
)

0 commit comments

Comments
 (0)