Skip to content

[mlir][index][spirv] Add conversion for index to spirv #68085

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 20, 2023

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Oct 3, 2023

Due to an issue when lowering from scf to spirv as there was no conversion pass for index to spirv, we are motivated to add a conversion pass from the Index dialect to the SPIR-V dialect. Furthermore, we add the new conversion patterns to the scf-to-spirv conversion.

Fixes #63713

@@ -693,6 +693,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](IndexType /*indexType*/) { return getIndexType(); });

addConversion([this](IntegerType intType) -> std::optional<Type> {
if (!this->options.convertIntAsScalar)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of this hidden option is currently a hack solution, as when we pass in the integer type we fail the linked check, which will then cause us to return a nullptr. Proper implementation would presumably ensure that we are passing the compatibility and extension requirements. So hints for how to solve this are needed.

https://github.com/llvm/llvm-project/blame/77c43e14897a404fbf4a132a1a75d49ba2ec08c1/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp#L232-L235.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we resolve this by adding a module with the right target env around the code in test cases? Have you checked where the const spirv::TargetEnv &targetEnv parameter in comes from in convertScalarType?

@inbelic inbelic marked this pull request as draft October 5, 2023 09:31
@inbelic inbelic marked this pull request as ready for review October 5, 2023 09:32
@llvmbot
Copy link
Member

llvmbot commented Oct 6, 2023

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Changes

Due to an issue when lowering from scf to spirv as there was no conversion pass for index to spirv, we are motivated to add a conversion pass from the Index dialect to the SPIR-V dialect. Furthermore, we add the new conversion patterns to the scf-to-spirv conversion.

Fixes #63713


Patch is 34.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68085.diff

11 Files Affected:

  • (added) mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h (+28)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+22)
  • (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h (+15-3)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt (+15)
  • (added) mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp (+418)
  • (modified) mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp (+2)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+2)
  • (added) mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir (+218)
  • (added) mlir/test/Conversion/SCFToSPIRV/use-indices.mlir (+28)
diff --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
new file mode 100644
index 000000000000000..d1a3c87249508b7
--- /dev/null
+++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
@@ -0,0 +1,28 @@
+//===- IndexToSPIRV.h - Index to SPIRV dialect conversion -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+#define MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+
+#include <memory>
+
+namespace mlir {
+class RewritePatternSet;
+class SPIRVTypeConverter;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTINDEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace index {
+void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
+                                  RewritePatternSet &patterns);
+} // namespace index
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index fc5e9adba114405..9660d89ec23e3be 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -34,6 +34,7 @@
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 11008baa0160efe..1e45abb66880c12 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -629,6 +629,28 @@ def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertIndexToSPIRVPass
+//===----------------------------------------------------------------------===//
+
+def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
+  let summary = "Lower the `index` dialect to the `spirv` dialect.";
+  let description = [{
+    This pass lowers Index dialect operations to SPIR-V dialect operations.
+    Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
+    `ceildivu`, and `floordivs`. The index bitwidth will be 32 or 64 as
+    specified by use-64bit-index.
+  }];
+
+  let dependentDialects = ["::mlir::spirv::SPIRVDialect"];
+
+  let options = [
+    Option<"use64bitIndex", "use-64bit-index",
+           "bool", /*default=*/"false",
+           "Use 64-bit integers to convert index types">
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgToStandard
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 89ded981d38f9f4..4a4e58464a80df7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -55,13 +55,20 @@ struct SPIRVConversionOptions {
   /// values will be packed into one 32-bit value to be memory efficient.
   bool emulateLT32BitScalarTypes{true};
 
-  /// Use 64-bit integers to convert index types.
-  bool use64bitIndex{false};
-
   /// Whether to enable fast math mode during conversion. If true, various
   /// patterns would assume no NaN/infinity numbers as inputs, and thus there
   /// will be no special guards emitted to check and handle such cases.
   bool enableFastMathMode{false};
+
+  /// Use 64-bit integers when converting index types.
+  bool use64bitIndex{false};
+
+  /// Whether we should treat an integer type as a scalar value within the
+  /// SPIR-V type converter. Used when we need to check if the integer type is a
+  /// supported bitwidth, as described above in emulateLT32BitScalarTypes.
+  /// Turned off when we are converting from index to SPIR-V as it will be an
+  /// i32 or i64.
+  bool convertIntAsScalar{true};
 };
 
 /// Type conversion from builtin types to SPIR-V types for shader interface.
@@ -77,6 +84,11 @@ class SPIRVTypeConverter : public TypeConverter {
   /// Gets the SPIR-V correspondence for the standard index type.
   Type getIndexType() const;
 
+  /// Gets the bitwidth of the index type when converted to SPIR-V.
+  unsigned getIndexTypeBitwidth() const {
+    return options.use64bitIndex ? 64 : 32;
+  }
+
   const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
 
   /// Returns the options controlling the SPIR-V type converter.
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 275e095245e89ce..8dad4c5fa25916a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -23,6 +23,7 @@ add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
 add_subdirectory(IndexToLLVM)
+add_subdirectory(IndexToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
diff --git a/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
new file mode 100644
index 000000000000000..1da0e0253501fec
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_conversion_library(MLIRIndexToSPIRV
+  IndexToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIndexDialect
+  )
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
new file mode 100644
index 000000000000000..4290a81ae813824
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -0,0 +1,418 @@
+//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace index;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Trivial Conversions
+//===----------------------------------------------------------------------===//
+
+using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>;
+using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>;
+using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>;
+using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>;
+using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>;
+using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>;
+using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>;
+using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>;
+using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>;
+using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>;
+using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>;
+
+using ConvertIndexShl =
+    spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>;
+using ConvertIndexShrS =
+    spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>;
+using ConvertIndexShrU =
+    spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>;
+
+/// It is the case that when we convert bitwise operations to SPIR-V operations
+/// we must take into account of the special pattern in SPIR-V that if the
+/// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
+/// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
+/// index.add is never a boolean operation so we can directly convert it to the
+/// Bitwise[And|Or]Op
+using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>;
+using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>;
+using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertConstantBool
+//===----------------------------------------------------------------------===//
+
+// Converts index.bool.constant operation to spirv.Constant.
+struct ConvertIndexConstantBoolOpPattern final
+    : OpConversionPattern<BoolConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
+                                                   op->getAttr("value"));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertConstant
+//===----------------------------------------------------------------------===//
+
+// Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
+// when required.
+struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type indexType = typeConverter->getIndexType();
+
+    APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, indexType, IntegerAttr::get(indexType, value));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
+/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
+struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type n_type = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 0));
+    Value posOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 1));
+    Value negOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, -1));
+
+    // Compute `x`.
+    Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
+    Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
+
+    // Compute the positive result.
+    Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
+    Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
+    Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
+
+    // Compute the negative result.
+    Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
+    Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
+    Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
+
+    // Pick the positive result if `n` and `m` have the same sign and `n` is
+    // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
+    Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
+    Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
+    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
+    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivU
+//===----------------------------------------------------------------------===//
+
+/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`.
+struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type n_type = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 0));
+    Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
+                                                   IntegerAttr::get(n_type, 1));
+
+    // Compute the non-zero result.
+    Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
+    Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
+    Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
+
+    // Pick the result
+    Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexFloorDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
+/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
+struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type n_type = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 0));
+    Value posOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 1));
+    Value negOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, -1));
+
+    // Compute `x`.
+    Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
+    Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
+
+    // Compute the negative result
+    Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
+    Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
+    Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
+
+    // Compute the positive result.
+    Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
+
+    // Pick the negative result if `n` and `m` have different signs and `n` is
+    // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
+    Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
+    Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
+    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
+
+    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCast
+//===----------------------------------------------------------------------===//
+
+/// Convert a cast op. If the materialized index type is the same as the other
+/// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
+/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
+/// zero extend when the result bitwidth is larger.
+template <typename CastOp, typename ConvertOp>
+struct ConvertIndexCast : public OpConversionPattern<CastOp> {
+  using OpConversionPattern<CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type indexType = typeConverter->getIndexType();
+
+    Type srcType = adaptor.getInput().getType();
+    Type dstType = op.getType();
+    if (isa<IndexType>(srcType)) {
+      srcType = indexType;
+    }
+    if (isa<IndexType>(dstType)) {
+      dstType = indexType;
+    }
+
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, adaptor.getInput());
+    } else {
+      rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
+                                                      adaptor.getOperands());
+    }
+    return success();
+  }
+};
+
+using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
+using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCmp
+//===----------------------------------------------------------------------===//
+
+// Helper template to replace the operation
+template <typename ICmpOp>
+static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
+  return success();
+}
+
+struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // We must convert the predicates to the corresponding int comparions.
+    switch (op.getPred()) {
+    case IndexCmpPredicate::EQ:
+      return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::NE:
+      return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGE:
+      return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGT:
+      return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLE:
+      return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLT:
+      return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGE:
+      return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGT:
+      return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULE:
+      return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULT:
+      return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
+    }
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexSizeOf
+//===----------------------------------------------------------------------===//
+
+/// Lower `index.sizeof` to a constant with the value of the index bi...
[truncated]

@kuhar
Copy link
Member

kuhar commented Oct 6, 2023

Awesome, I will take a look after I'm done with the LLVM Dev Mtg and the OpenXLA summit. Apologies for the delay.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks solid, I'd just like to confirm if we need the hack for scalar types.

Also, is type converter allowed to fail in getIndexType()? If we yes, we are missing nullptr checks for failed index queries. My guess is that for any reasonable type converter for SPIR-V, i32 should be always available.

@@ -693,6 +693,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](IndexType /*indexType*/) { return getIndexType(); });

addConversion([this](IntegerType intType) -> std::optional<Type> {
if (!this->options.convertIntAsScalar)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we resolve this by adding a module with the right target env around the code in test cases? Have you checked where the const spirv::TargetEnv &targetEnv parameter in comes from in convertScalarType?

@inbelic
Copy link
Contributor Author

inbelic commented Oct 12, 2023

Have removed the use of the hack and the options. It was solved by adding Int64 to the list of capabilities of the SPIRV Target Environment in the test case.

Regarding getIndexType, it will determine if we are using a 32 or 64 and then return IntegerType::get(...). Under the assumption that IntegerType::get(...) is not allowed to fail then the index will not either. I believe that is the case? Is there any pattern that denotes if a function is allowed to fail or not?

@kuhar
Copy link
Member

kuhar commented Oct 19, 2023

Regarding getIndexType, it will determine if we are using a 32 or 64 and then return IntegerType::get(...). Under the assumption that IntegerType::get(...) is not allowed to fail then the index will not either. I believe that is the case? Is there any pattern that denotes if a function is allowed to fail or not?

OK, SGTM, thanks for checking.

@kuhar
Copy link
Member

kuhar commented Oct 19, 2023

I believe that is the case? Is there any pattern that denotes if a function is allowed to fail or not?

In general things that return Value/Type may return nullptr. This is definitely the case for convertType(...). For getIndexType I think it's reasonable to assume success.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Could you rebase the code before we merge this, @inbelic?

Due to an issue when lowering from scf to spirv as there was no
conversion pass for index to spirv, we are motivated to add a
conversion pass from the Index dialect to the SPIR-V dialect.
Furthermore, we add the new conversion patterns to the scf-to-spirv
conversion.

Fixes llvm#63713
@inbelic inbelic force-pushed the inbelic/conv-index-to-spirv branch from ff84061 to 25925c1 Compare October 20, 2023 12:10
@kuhar kuhar merged commit 3c07a21 into llvm:main Oct 20, 2023
@joker-eph
Copy link
Collaborator

Seems like this broke the bot: https://lab.llvm.org/buildbot/#/builders/61/builds/50716

kuhar added a commit that referenced this pull request Oct 20, 2023
@kuhar
Copy link
Member

kuhar commented Oct 20, 2023

Thanks for surfacing, @joker-eph. Reverted in 8b02ceb.

@inbelic could you take a look at this buildbot failure?

@inbelic
Copy link
Contributor Author

inbelic commented Oct 20, 2023

@kuhar Yep, starting to investigate now.

@inbelic
Copy link
Contributor Author

inbelic commented Oct 20, 2023

Was unable to reproduce the build error locally, but my guess is that the commit was missing the SPIR-V dialect in our linked libs of the CMake file, new pull request #69790 adds this. Need some help to verify that this is the case. Is there a way to test the build on a given branch/pr?

@kuhar
Copy link
Member

kuhar commented Oct 21, 2023

Was unable to reproduce the build error locally, but my guess is that the commit was missing the SPIR-V dialect in our linked libs of the CMake file, new pull request #69790 adds this. Need some help to verify that this is the case. Is there a way to test the build on a given branch/pr?

This should be reproducible with -DBUILD_SHARED_LIBS=1: https://lab.llvm.org/buildbot/#/builders/61/builds/50716/steps/4/logs/stdio

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] Support index to spir-v dialect conversion
4 participants