diff --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h new file mode 100644 index 0000000000000..58a1c5246eef9 --- /dev/null +++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h @@ -0,0 +1,30 @@ +//===- IndexToSPIRV.h - Index to SPIRV dialect conversion -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H +#define MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H + +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +class RewritePatternSet; +class SPIRVTypeConverter; +class Pass; + +#define GEN_PASS_DECL_CONVERTINDEXTOSPIRVPASS +#include "mlir/Conversion/Passes.h.inc" + +namespace index { +void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter, + RewritePatternSet &patterns); +std::unique_ptr> createConvertIndexToSPIRVPass(); +} // namespace index +} // namespace mlir + +#endif // MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index e714f5070f23d..c13c457fd9749 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -35,6 +35,7 @@ #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 38b05c792d405..9979faed42517 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -644,6 +644,28 @@ def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> { ]; } +//===----------------------------------------------------------------------===// +// ConvertIndexToSPIRVPass +//===----------------------------------------------------------------------===// + +def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> { + let summary = "Lower the `index` dialect to the `spirv` dialect."; + let description = [{ + This pass lowers Index dialect operations to SPIR-V dialect operations. + Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`, + `ceildivu`, and `floordivs`. The index bitwidth will be 32 or 64 as + specified by use-64bit-index. + }]; + + let dependentDialects = ["::mlir::spirv::SPIRVDialect"]; + + let options = [ + Option<"use64bitIndex", "use-64bit-index", + "bool", /*default=*/"false", + "Use 64-bit integers to convert index types"> + ]; +} + //===----------------------------------------------------------------------===// // LinalgToStandard //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 89ded981d38f9..933d62e35fce8 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -55,13 +55,13 @@ struct SPIRVConversionOptions { /// values will be packed into one 32-bit value to be memory efficient. bool emulateLT32BitScalarTypes{true}; - /// Use 64-bit integers to convert index types. - bool use64bitIndex{false}; - /// Whether to enable fast math mode during conversion. If true, various /// patterns would assume no NaN/infinity numbers as inputs, and thus there /// will be no special guards emitted to check and handle such cases. bool enableFastMathMode{false}; + + /// Use 64-bit integers when converting index types. + bool use64bitIndex{false}; }; /// Type conversion from builtin types to SPIR-V types for shader interface. @@ -77,6 +77,11 @@ class SPIRVTypeConverter : public TypeConverter { /// Gets the SPIR-V correspondence for the standard index type. Type getIndexType() const; + /// Gets the bitwidth of the index type when converted to SPIR-V. + unsigned getIndexTypeBitwidth() const { + return options.use64bitIndex ? 64 : 32; + } + const spirv::TargetEnv &getTargetEnv() const { return targetEnv; } /// Returns the options controlling the SPIR-V type converter. diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 35790254be137..7e1c7bcf9a867 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -24,6 +24,7 @@ add_subdirectory(GPUToROCDL) add_subdirectory(GPUToSPIRV) add_subdirectory(GPUToVulkan) add_subdirectory(IndexToLLVM) +add_subdirectory(IndexToSPIRV) add_subdirectory(LinalgToStandard) add_subdirectory(LLVMCommon) add_subdirectory(MathToFuncs) diff --git a/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt new file mode 100644 index 0000000000000..a02d2d40a9a2f --- /dev/null +++ b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRIndexToSPIRV + IndexToSPIRV.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToSPIRV + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIndexDialect + MLIRSPIRVConversion + MLIRSPIRVDialect + ) diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp new file mode 100644 index 0000000000000..b58efc096e2ea --- /dev/null +++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp @@ -0,0 +1,418 @@ +//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" +#include "../SPIRVCommon/Pattern.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace index; + +namespace { + +//===----------------------------------------------------------------------===// +// Trivial Conversions +//===----------------------------------------------------------------------===// + +using ConvertIndexAdd = spirv::ElementwiseOpPattern; +using ConvertIndexSub = spirv::ElementwiseOpPattern; +using ConvertIndexMul = spirv::ElementwiseOpPattern; +using ConvertIndexDivS = spirv::ElementwiseOpPattern; +using ConvertIndexDivU = spirv::ElementwiseOpPattern; +using ConvertIndexRemS = spirv::ElementwiseOpPattern; +using ConvertIndexRemU = spirv::ElementwiseOpPattern; +using ConvertIndexMaxS = spirv::ElementwiseOpPattern; +using ConvertIndexMaxU = spirv::ElementwiseOpPattern; +using ConvertIndexMinS = spirv::ElementwiseOpPattern; +using ConvertIndexMinU = spirv::ElementwiseOpPattern; + +using ConvertIndexShl = + spirv::ElementwiseOpPattern; +using ConvertIndexShrS = + spirv::ElementwiseOpPattern; +using ConvertIndexShrU = + spirv::ElementwiseOpPattern; + +/// It is the case that when we convert bitwise operations to SPIR-V operations +/// we must take into account the special pattern in SPIR-V that if the +/// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise, +/// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However, +/// index.add is never a boolean operation so we can directly convert it to the +/// Bitwise[And|Or]Op. +using ConvertIndexAnd = spirv::ElementwiseOpPattern; +using ConvertIndexOr = spirv::ElementwiseOpPattern; +using ConvertIndexXor = spirv::ElementwiseOpPattern; + +//===----------------------------------------------------------------------===// +// ConvertConstantBool +//===----------------------------------------------------------------------===// + +// Converts index.bool.constant operation to spirv.Constant. +struct ConvertIndexConstantBoolOpPattern final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), + op.getValueAttr()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertConstant +//===----------------------------------------------------------------------===// + +// Converts index.constant op to spirv.Constant. Will truncate from i64 to i32 +// when required. +struct ConvertIndexConstantOpPattern final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *typeConverter = this->template getTypeConverter(); + Type indexType = typeConverter->getIndexType(); + + APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth()); + rewriter.replaceOpWithNewOp( + op, indexType, IntegerAttr::get(indexType, value)); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertIndexCeilDivS +//===----------------------------------------------------------------------===// + +/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then +/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent +/// conversion in IndexToLLVM. +struct ConvertIndexCeilDivSPattern final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value n = adaptor.getLhs(); + Type n_type = n.getType(); + Value m = adaptor.getRhs(); + + // Define the constants + Value zero = rewriter.create( + loc, n_type, IntegerAttr::get(n_type, 0)); + Value posOne = rewriter.create( + loc, n_type, IntegerAttr::get(n_type, 1)); + Value negOne = rewriter.create( + loc, n_type, IntegerAttr::get(n_type, -1)); + + // Compute `x`. + Value mPos = rewriter.create(loc, m, zero); + Value x = rewriter.create(loc, mPos, negOne, posOne); + + // Compute the positive result. + Value nPlusX = rewriter.create(loc, n, x); + Value nPlusXDivM = rewriter.create(loc, nPlusX, m); + Value posRes = rewriter.create(loc, nPlusXDivM, posOne); + + // Compute the negative result. + Value negN = rewriter.create(loc, zero, n); + Value negNDivM = rewriter.create(loc, negN, m); + Value negRes = rewriter.create(loc, zero, negNDivM); + + // Pick the positive result if `n` and `m` have the same sign and `n` is + // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`. + Value nPos = rewriter.create(loc, n, zero); + Value sameSign = rewriter.create(loc, nPos, mPos); + Value nNonZero = rewriter.create(loc, n, zero); + Value cmp = rewriter.create(loc, sameSign, nNonZero); + rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertIndexCeilDivU +//===----------------------------------------------------------------------===// + +/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken +/// from the equivalent conversion in IndexToLLVM. +struct ConvertIndexCeilDivUPattern final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value n = adaptor.getLhs(); + Type n_type = n.getType(); + Value m = adaptor.getRhs(); + + // Define the constants + Value zero = rewriter.create( + loc, n_type, IntegerAttr::get(n_type, 0)); + Value one = rewriter.create(loc, n_type, + IntegerAttr::get(n_type, 1)); + + // Compute the non-zero result. + Value minusOne = rewriter.create(loc, n, one); + Value quotient = rewriter.create(loc, minusOne, m); + Value plusOne = rewriter.create(loc, quotient, one); + + // Pick the result + Value cmp = rewriter.create(loc, n, zero); + rewriter.replaceOpWithNewOp(op, cmp, zero, plusOne); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertIndexFloorDivS +//===----------------------------------------------------------------------===// + +/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then +/// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion +/// in IndexToLLVM. +struct ConvertIndexFloorDivSPattern final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value n = adaptor.getLhs(); + Type n_type = n.getType(); + Value m = adaptor.getRhs(); + + // Define the constants + Value zero = rewriter.create( + loc, n_type, IntegerAttr::get(n_type, 0)); + Value posOne = rewriter.create( + loc, n_type, IntegerAttr::get(n_type, 1)); + Value negOne = rewriter.create( + loc, n_type, IntegerAttr::get(n_type, -1)); + + // Compute `x`. + Value mNeg = rewriter.create(loc, m, zero); + Value x = rewriter.create(loc, mNeg, posOne, negOne); + + // Compute the negative result + Value xMinusN = rewriter.create(loc, x, n); + Value xMinusNDivM = rewriter.create(loc, xMinusN, m); + Value negRes = rewriter.create(loc, negOne, xMinusNDivM); + + // Compute the positive result. + Value posRes = rewriter.create(loc, n, m); + + // Pick the negative result if `n` and `m` have different signs and `n` is + // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`. + Value nNeg = rewriter.create(loc, n, zero); + Value diffSign = rewriter.create(loc, nNeg, mNeg); + Value nNonZero = rewriter.create(loc, n, zero); + + Value cmp = rewriter.create(loc, diffSign, nNonZero); + rewriter.replaceOpWithNewOp(op, cmp, posRes, negRes); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertIndexCast +//===----------------------------------------------------------------------===// + +/// Convert a cast op. If the materialized index type is the same as the other +/// type, fold away the op. Otherwise, use the Convert SPIR-V operation. +/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts +/// zero extend when the result bitwidth is larger. +template +struct ConvertIndexCast final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *typeConverter = this->template getTypeConverter(); + Type indexType = typeConverter->getIndexType(); + + Type srcType = adaptor.getInput().getType(); + Type dstType = op.getType(); + if (isa(srcType)) { + srcType = indexType; + } + if (isa(dstType)) { + dstType = indexType; + } + + if (srcType == dstType) { + rewriter.replaceOp(op, adaptor.getInput()); + } else { + rewriter.template replaceOpWithNewOp(op, dstType, + adaptor.getOperands()); + } + return success(); + } +}; + +using ConvertIndexCastS = ConvertIndexCast; +using ConvertIndexCastU = ConvertIndexCast; + +//===----------------------------------------------------------------------===// +// ConvertIndexCmp +//===----------------------------------------------------------------------===// + +// Helper template to replace the operation +template +static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) { + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), adaptor.getRhs()); + return success(); +} + +struct ConvertIndexCmpPattern final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // We must convert the predicates to the corresponding int comparions. + switch (op.getPred()) { + case IndexCmpPredicate::EQ: + return rewriteCmpOp(op, adaptor, rewriter); + case IndexCmpPredicate::NE: + return rewriteCmpOp(op, adaptor, rewriter); + case IndexCmpPredicate::SGE: + return rewriteCmpOp(op, adaptor, rewriter); + case IndexCmpPredicate::SGT: + return rewriteCmpOp(op, adaptor, rewriter); + case IndexCmpPredicate::SLE: + return rewriteCmpOp(op, adaptor, rewriter); + case IndexCmpPredicate::SLT: + return rewriteCmpOp(op, adaptor, rewriter); + case IndexCmpPredicate::UGE: + return rewriteCmpOp(op, adaptor, rewriter); + case IndexCmpPredicate::UGT: + return rewriteCmpOp(op, adaptor, rewriter); + case IndexCmpPredicate::ULE: + return rewriteCmpOp(op, adaptor, rewriter); + case IndexCmpPredicate::ULT: + return rewriteCmpOp(op, adaptor, rewriter); + } + } +}; + +//===----------------------------------------------------------------------===// +// ConvertIndexSizeOf +//===----------------------------------------------------------------------===// + +/// Lower `index.sizeof` to a constant with the value of the index bitwidth. +struct ConvertIndexSizeOf final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *typeConverter = this->template getTypeConverter(); + Type indexType = typeConverter->getIndexType(); + unsigned bitwidth = typeConverter->getIndexTypeBitwidth(); + rewriter.replaceOpWithNewOp( + op, indexType, IntegerAttr::get(indexType, bitwidth)); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// + +void index::populateIndexToSPIRVPatterns(SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add< + // clang-format off + ConvertIndexAdd, + ConvertIndexSub, + ConvertIndexMul, + ConvertIndexDivS, + ConvertIndexDivU, + ConvertIndexRemS, + ConvertIndexRemU, + ConvertIndexMaxS, + ConvertIndexMaxU, + ConvertIndexMinS, + ConvertIndexMinU, + ConvertIndexShl, + ConvertIndexShrS, + ConvertIndexShrU, + ConvertIndexAnd, + ConvertIndexOr, + ConvertIndexXor, + ConvertIndexConstantBoolOpPattern, + ConvertIndexConstantOpPattern, + ConvertIndexCeilDivSPattern, + ConvertIndexCeilDivUPattern, + ConvertIndexFloorDivSPattern, + ConvertIndexCastS, + ConvertIndexCastU, + ConvertIndexCmpPattern, + ConvertIndexSizeOf + >(typeConverter, patterns.getContext()); +} + +//===----------------------------------------------------------------------===// +// ODS-Generated Definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +#define GEN_PASS_DEF_CONVERTINDEXTOSPIRVPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +namespace { +struct ConvertIndexToSPIRVPass + : public impl::ConvertIndexToSPIRVPassBase { + using Base::Base; + + void runOnOperation() override { + Operation *op = getOperation(); + spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); + std::unique_ptr target = + SPIRVConversionTarget::get(targetAttr); + + SPIRVConversionOptions options; + options.use64bitIndex = this->use64bitIndex; + SPIRVTypeConverter typeConverter(targetAttr, options); + + // Use UnrealizedConversionCast as the bridge so that we don't need to pull + // in patterns for other dialects. + target->addLegalOp(); + + // Allow the spirv operations we are converting to + target->addLegalDialect(); + // Fail hard when there are any remaining 'index' ops. + target->addIllegalDialect(); + + RewritePatternSet patterns(&getContext()); + index::populateIndexToSPIRVPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(op, *target, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace diff --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt index 987a1e84b74ab..e509fb025ee19 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRSCFToSPIRV LINK_LIBS PUBLIC MLIRArithToSPIRV MLIRFuncToSPIRV + MLIRIndexToSPIRV MLIRMemRefToSPIRV MLIRSPIRVDialect MLIRSPIRVConversion diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp index 1e8fe4423a422..3ef1d84ee2647 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" +#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h" #include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -52,6 +53,7 @@ void SCFToSPIRVPass::runOnOperation() { populateFuncToSPIRVPatterns(typeConverter, patterns); populateMemRefToSPIRVPatterns(typeConverter, patterns); populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); + index::populateIndexToSPIRVPatterns(typeConverter, patterns); if (failed(applyPartialConversion(op, *target, std::move(patterns)))) return signalPassFailure(); diff --git a/mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir b/mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir new file mode 100644 index 0000000000000..53dc896e98c7d --- /dev/null +++ b/mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir @@ -0,0 +1,222 @@ +// RUN: mlir-opt %s -convert-index-to-spirv | FileCheck %s +// RUN: mlir-opt %s -convert-index-to-spirv=use-64bit-index=false | FileCheck %s --check-prefix=INDEX32 +// RUN: mlir-opt %s -convert-index-to-spirv=use-64bit-index=true | FileCheck %s --check-prefix=INDEX64 + +module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { +// CHECK-LABEL: @trivial_ops +func.func @trivial_ops(%a: index, %b: index) { + // CHECK: spirv.IAdd + %0 = index.add %a, %b + // CHECK: spirv.ISub + %1 = index.sub %a, %b + // CHECK: spirv.IMul + %2 = index.mul %a, %b + // CHECK: spirv.SDiv + %3 = index.divs %a, %b + // CHECK: spirv.UDiv + %4 = index.divu %a, %b + // CHECK: spirv.SRem + %5 = index.rems %a, %b + // CHECK: spirv.UMod + %6 = index.remu %a, %b + // CHECK: spirv.GL.SMax + %7 = index.maxs %a, %b + // CHECK: spirv.GL.UMax + %8 = index.maxu %a, %b + // CHECK: spirv.GL.SMin + %9 = index.mins %a, %b + // CHECK: spirv.GL.UMin + %10 = index.minu %a, %b + // CHECK: spirv.ShiftLeftLogical + %11 = index.shl %a, %b + // CHECK: spirv.ShiftRightArithmetic + %12 = index.shrs %a, %b + // CHECK: spirv.ShiftRightLogical + %13 = index.shru %a, %b + return +} + +// CHECK-LABEL: @bitwise_ops +func.func @bitwise_ops(%a: index, %b: index) { + // CHECK: spirv.BitwiseAnd + %0 = index.and %a, %b + // CHECK: spirv.BitwiseOr + %1 = index.or %a, %b + // CHECK: spirv.BitwiseXor + %2 = index.xor %a, %b + return +} + +// INDEX32-LABEL: @constant_ops +// INDEX64-LABEL: @constant_ops +func.func @constant_ops() { + // INDEX32: spirv.Constant 42 : i32 + // INDEX64: spirv.Constant 42 : i64 + %0 = index.constant 42 + // INDEX32: spirv.Constant true + // INDEX64: spirv.Constant true + %1 = index.bool.constant true + // INDEX32: spirv.Constant false + // INDEX64: spirv.Constant false + %2 = index.bool.constant false + return +} + +// CHECK-LABEL: @ceildivs +// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index +func.func @ceildivs(%n: index, %m: index) -> index { + // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]] + // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 + // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1 + // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1 + + // CHECK: %[[M_POS:.*]] = spirv.SGreaterThan %[[M]], %[[ZERO]] + // CHECK: %[[X:.*]] = spirv.Select %[[M_POS]], %[[NEG_ONE]], %[[POS_ONE]] + + // CHECK: %[[N_PLUS_X:.*]] = spirv.IAdd %[[N]], %[[X]] + // CHECK: %[[N_PLUS_X_DIV_M:.*]] = spirv.SDiv %[[N_PLUS_X]], %[[M]] + // CHECK: %[[POS_RES:.*]] = spirv.IAdd %[[N_PLUS_X_DIV_M]], %[[POS_ONE]] + + // CHECK: %[[NEG_N:.*]] = spirv.ISub %[[ZERO]], %[[N]] + // CHECK: %[[NEG_N_DIV_M:.*]] = spirv.SDiv %[[NEG_N]], %[[M]] + // CHECK: %[[NEG_RES:.*]] = spirv.ISub %[[ZERO]], %[[NEG_N_DIV_M]] + + // CHECK: %[[N_POS:.*]] = spirv.SGreaterThan %[[N]], %[[ZERO]] + // CHECK: %[[SAME_SIGN:.*]] = spirv.LogicalEqual %[[N_POS]], %[[M_POS]] + // CHECK: %[[N_NON_ZERO:.*]] = spirv.INotEqual %[[N]], %[[ZERO]] + // CHECK: %[[CMP:.*]] = spirv.LogicalAnd %[[SAME_SIGN]], %[[N_NON_ZERO]] + // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[POS_RES]], %[[NEG_RES]] + %result = index.ceildivs %n, %m + + // %[[RESULTI:.*] = builtin.unrealized_conversion_cast %[[RESULT]] + // return %[[RESULTI]] + return %result : index +} + +// CHECK-LABEL: @ceildivu +// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index +func.func @ceildivu(%n: index, %m: index) -> index { + // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]] + // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 + // CHECK: %[[ONE:.*]] = spirv.Constant 1 + + // CHECK: %[[N_MINUS_ONE:.*]] = spirv.ISub %[[N]], %[[ONE]] + // CHECK: %[[N_MINUS_ONE_DIV_M:.*]] = spirv.UDiv %[[N_MINUS_ONE]], %[[M]] + // CHECK: %[[N_MINUS_ONE_DIV_M_PLUS_ONE:.*]] = spirv.IAdd %[[N_MINUS_ONE_DIV_M]], %[[ONE]] + + // CHECK: %[[CMP:.*]] = spirv.IEqual %[[N]], %[[ZERO]] + // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[ZERO]], %[[N_MINUS_ONE_DIV_M_PLUS_ONE]] + %result = index.ceildivu %n, %m + + // %[[RESULTI:.*] = builtin.unrealized_conversion_cast %[[RESULT]] + // return %[[RESULTI]] + return %result : index +} + +// CHECK-LABEL: @floordivs +// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index +func.func @floordivs(%n: index, %m: index) -> index { + // CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]] + // CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]] + // CHECK: %[[ZERO:.*]] = spirv.Constant 0 + // CHECK: %[[POS_ONE:.*]] = spirv.Constant 1 + // CHECK: %[[NEG_ONE:.*]] = spirv.Constant -1 + + // CHECK: %[[M_NEG:.*]] = spirv.SLessThan %[[M]], %[[ZERO]] + // CHECK: %[[X:.*]] = spirv.Select %[[M_NEG]], %[[POS_ONE]], %[[NEG_ONE]] + + // CHECK: %[[X_MINUS_N:.*]] = spirv.ISub %[[X]], %[[N]] + // CHECK: %[[X_MINUS_N_DIV_M:.*]] = spirv.SDiv %[[X_MINUS_N]], %[[M]] + // CHECK: %[[NEG_RES:.*]] = spirv.ISub %[[NEG_ONE]], %[[X_MINUS_N_DIV_M]] + + // CHECK: %[[POS_RES:.*]] = spirv.SDiv %[[N]], %[[M]] + + // CHECK: %[[N_NEG:.*]] = spirv.SLessThan %[[N]], %[[ZERO]] + // CHECK: %[[DIFF_SIGN:.*]] = spirv.LogicalNotEqual %[[N_NEG]], %[[M_NEG]] + // CHECK: %[[N_NON_ZERO:.*]] = spirv.INotEqual %[[N]], %[[ZERO]] + + // CHECK: %[[CMP:.*]] = spirv.LogicalAnd %[[DIFF_SIGN]], %[[N_NON_ZERO]] + // CHECK: %[[RESULT:.*]] = spirv.Select %[[CMP]], %[[POS_RES]], %[[NEG_RES]] + %result = index.floordivs %n, %m + + // %[[RESULTI:.*] = builtin.unrealized_conversion_cast %[[RESULT]] + // return %[[RESULTI]] + return %result : index +} + +// CHECK-LABEL: @index_cmp +func.func @index_cmp(%a : index, %b : index) { + // CHECK: spirv.IEqual + %0 = index.cmp eq(%a, %b) + // CHECK: spirv.INotEqual + %1 = index.cmp ne(%a, %b) + + // CHECK: spirv.SLessThan + %2 = index.cmp slt(%a, %b) + // CHECK: spirv.SLessThanEqual + %3 = index.cmp sle(%a, %b) + // CHECK: spirv.SGreaterThan + %4 = index.cmp sgt(%a, %b) + // CHECK: spirv.SGreaterThanEqual + %5 = index.cmp sge(%a, %b) + + // CHECK: spirv.ULessThan + %6 = index.cmp ult(%a, %b) + // CHECK: spirv.ULessThanEqual + %7 = index.cmp ule(%a, %b) + // CHECK: spirv.UGreaterThan + %8 = index.cmp ugt(%a, %b) + // CHECK: spirv.UGreaterThanEqual + %9 = index.cmp uge(%a, %b) + return +} + +// CHECK-LABEL: @index_sizeof +func.func @index_sizeof() { + // CHECK: spirv.Constant 32 : i32 + %0 = index.sizeof + return +} + +// INDEX32-LABEL: @index_cast_from +// INDEX64-LABEL: @index_cast_from +// INDEX32-SAME: %[[AI:.*]]: index +// INDEX64-SAME: %[[AI:.*]]: index +func.func @index_cast_from(%a: index) -> (i64, i32, i64, i32) { + // INDEX32: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i32 + // INDEX64: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i64 + + // INDEX32: %[[V0:.*]] = spirv.SConvert %[[A]] : i32 to i64 + %0 = index.casts %a : index to i64 + // INDEX64: %[[V1:.*]] = spirv.SConvert %[[A]] : i64 to i32 + %1 = index.casts %a : index to i32 + // INDEX32: %[[V2:.*]] = spirv.UConvert %[[A]] : i32 to i64 + %2 = index.castu %a : index to i64 + // INDEX64: %[[V3:.*]] = spirv.UConvert %[[A]] : i64 to i32 + %3 = index.castu %a : index to i32 + + // INDEX32: return %[[V0]], %[[A]], %[[V2]], %[[A]] + // INDEX64: return %[[A]], %[[V1]], %[[A]], %[[V3]] + return %0, %1, %2, %3 : i64, i32, i64, i32 +} + +// INDEX32-LABEL: @index_cast_to +// INDEX64-LABEL: @index_cast_to +// INDEX32-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64 +// INDEX64-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64 +func.func @index_cast_to(%a: i32, %b: i64) -> (index, index, index, index) { + // INDEX64: %[[V0:.*]] = spirv.SConvert %[[A]] : i32 to i64 + %0 = index.casts %a : i32 to index + // INDEX32: %[[V1:.*]] = spirv.SConvert %[[B]] : i64 to i32 + %1 = index.casts %b : i64 to index + // INDEX64: %[[V2:.*]] = spirv.UConvert %[[A]] : i32 to i64 + %2 = index.castu %a : i32 to index + // INDEX32: %[[V3:.*]] = spirv.UConvert %[[B]] : i64 to i32 + %3 = index.castu %b : i64 to index + return %0, %1, %2, %3 : index, index, index, index +} +} diff --git a/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir b/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir new file mode 100644 index 0000000000000..68a825fbd93eb --- /dev/null +++ b/mlir/test/Conversion/SCFToSPIRV/use-indices.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s + +// CHECK-LABEL: @forward +func.func @forward() { + // CHECK: %[[LB:.*]] = spirv.Constant 0 : i32 + %c0 = arith.constant 0 : index + // CHECK: %[[UB:.*]] = spirv.Constant 32 : i32 + %c32 = arith.constant 32 : index + // CHECK: %[[STEP:.*]] = spirv.Constant 1 : i32 + %c1 = arith.constant 1 : index + + // CHECK: spirv.mlir.loop { + // CHECK-NEXT: spirv.Branch ^[[HEADER:.*]](%[[LB]] : i32) + // CHECK: ^[[HEADER]](%[[INDVAR:.*]]: i32): + // CHECK: %[[CMP:.*]] = spirv.SLessThan %[[INDVAR]], %[[UB]] : i32 + // CHECK: spirv.BranchConditional %[[CMP]], ^[[BODY:.*]], ^[[MERGE:.*]] + // CHECK: ^[[BODY]]: + // CHECK: %[[X:.*]] = spirv.IAdd %[[INDVAR]], %[[INDVAR]] : i32 + // CHECK: %[[INDNEXT:.*]] = spirv.IAdd %[[INDVAR]], %[[STEP]] : i32 + // CHECK: spirv.Branch ^[[HEADER]](%[[INDNEXT]] : i32) + // CHECK: ^[[MERGE]]: + // CHECK: spirv.mlir.merge + // CHECK: } + scf.for %arg2 = %c0 to %c32 step %c1 { + %1 = index.add %arg2, %arg2 + } + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 51ea4a28cc8fa..8179f19bb3bd1 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7350,6 +7350,7 @@ cc_library( ":FuncDialect", ":FuncToSPIRV", ":IR", + ":IndexToSPIRV", ":MemRefToSPIRV", ":Pass", ":SCFDialect", @@ -9766,6 +9767,31 @@ cc_library( ], ) +cc_library( + name = "IndexToSPIRV", + srcs = glob([ + "lib/Conversion/IndexToSPIRV/*.cpp", + "lib/Conversion/IndexToSPIRV/*.h", + ]), + hdrs = glob([ + "include/mlir/Conversion/IndexToSPIRV/*.h", + ]), + includes = ["include"], + deps = [ + ":ConversionPassIncGen", + ":IR", + ":IndexDialect", + ":Pass", + ":SPIRVCommonConversion", + ":SPIRVConversion", + ":SPIRVDialect", + ":Support", + ":Transforms", + "//llvm:Core", + "//llvm:Support", + ], +) + cc_library( name = "IndexDialect", srcs = glob(["lib/Dialect/Index/IR/*.cpp"]),