|
20 | 20 | #include "mlir/Dialect/Traits.h"
|
21 | 21 | #include "mlir/IR/Matchers.h"
|
22 | 22 | #include "mlir/Transforms/DialectConversion.h"
|
| 23 | +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" |
23 | 24 | #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
24 | 25 | #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
25 | 26 | #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
@@ -163,6 +164,37 @@ static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
|
163 | 164 | b.getStringAttr("mismatching contracting dimension"));
|
164 | 165 | }
|
165 | 166 |
|
| 167 | +template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred, |
| 168 | + arith::CmpIPredicate ispred> |
| 169 | +static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type, |
| 170 | + Value lhs, Value rhs) { |
| 171 | + if (type.isa<mlir::FloatType>()) |
| 172 | + return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs); |
| 173 | + if (IntegerType intType = type.dyn_cast<mlir::IntegerType>()) { |
| 174 | + if (intType.isUnsigned()) |
| 175 | + return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs); |
| 176 | + if (intType.isSigned()) |
| 177 | + return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs); |
| 178 | + } |
| 179 | + assert(false && "Unhandled element type for comparison"); |
| 180 | +} |
| 181 | + |
| 182 | +static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType, |
| 183 | + Value lhs, Value rhs) { |
| 184 | + return createComparisonTemplate<arith::CmpFPredicate::UGT, |
| 185 | + arith::CmpIPredicate::ugt, |
| 186 | + arith::CmpIPredicate::sgt>( |
| 187 | + b, loc, elementalType, lhs, rhs); |
| 188 | +} |
| 189 | + |
| 190 | +static Value createLessThan(OpBuilder &b, Location loc, Type elementalType, |
| 191 | + Value lhs, Value rhs) { |
| 192 | + return createComparisonTemplate<arith::CmpFPredicate::ULT, |
| 193 | + arith::CmpIPredicate::ult, |
| 194 | + arith::CmpIPredicate::slt>( |
| 195 | + b, loc, elementalType, lhs, rhs); |
| 196 | +} |
| 197 | + |
166 | 198 | static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
|
167 | 199 | Value tensor, int dim) {
|
168 | 200 | RankedTensorType type = tensor.getType().cast<RankedTensorType>();
|
@@ -2072,20 +2104,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
2072 | 2104 |
|
2073 | 2105 | Type elementalType =
|
2074 | 2106 | gtTensor.self().getType().cast<BaseTensorType>().getDtype();
|
2075 |
| - |
2076 |
| - if (elementalType.isa<mlir::FloatType>()) |
2077 |
| - return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, |
2078 |
| - payloadArgs[0], payloadArgs[1]); |
2079 |
| - if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) { |
2080 |
| - if (intType.isUnsigned()) |
2081 |
| - return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt, |
2082 |
| - payloadArgs[0], payloadArgs[1]); |
2083 |
| - if (intType.isSigned()) |
2084 |
| - return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, |
2085 |
| - payloadArgs[0], payloadArgs[1]); |
2086 |
| - } |
2087 |
| - gtTensor.emitError("unimplemented: dtype isn't supported."); |
2088 |
| - return nullptr; |
| 2107 | + return createGreaterThan(b, loc, elementalType, payloadArgs[0], |
| 2108 | + payloadArgs[1]); |
2089 | 2109 | }
|
2090 | 2110 | if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
|
2091 | 2111 | AtenEqTensorOp::Adaptor adaptor(operands);
|
@@ -2126,20 +2146,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
2126 | 2146 |
|
2127 | 2147 | Type elementalType =
|
2128 | 2148 | ltTensor.self().getType().cast<BaseTensorType>().getDtype();
|
2129 |
| - |
2130 |
| - if (elementalType.isa<mlir::FloatType>()) |
2131 |
| - return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, |
2132 |
| - payloadArgs[0], payloadArgs[1]); |
2133 |
| - if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) { |
2134 |
| - if (intType.isUnsigned()) |
2135 |
| - return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, |
2136 |
| - payloadArgs[0], payloadArgs[1]); |
2137 |
| - if (intType.isSigned()) |
2138 |
| - return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, |
2139 |
| - payloadArgs[0], payloadArgs[1]); |
2140 |
| - } |
2141 |
| - ltTensor.emitError("unimplemented: dtype isn't supported."); |
2142 |
| - return nullptr; |
| 2149 | + return createLessThan(b, loc, elementalType, payloadArgs[0], |
| 2150 | + payloadArgs[1]); |
2143 | 2151 | }
|
2144 | 2152 | if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
2145 | 2153 | AtenDivTensorOp::Adaptor adaptor(operands);
|
@@ -2329,28 +2337,24 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
2329 | 2337 | return b.create<arith::AddFOp>(loc, start, weightedDelta);
|
2330 | 2338 | }
|
2331 | 2339 | if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
|
2332 |
| - if (!minimum.getType() |
2333 |
| - .cast<ValueTensorType>() |
2334 |
| - .getDtype() |
2335 |
| - .isa<mlir::FloatType>()) { |
2336 |
| - minimum.emitError("unimplemented: non-floating point dtype"); |
2337 |
| - return nullptr; |
2338 |
| - } |
2339 |
| - Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, |
2340 |
| - payloadArgs[0], payloadArgs[1]); |
2341 |
| - return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]); |
| 2340 | + Type dtype = minimum.getType().cast<BaseTensorType>().getDtype(); |
| 2341 | + Type elemTy = converter->convertType(minimum.getType()) |
| 2342 | + .cast<RankedTensorType>() |
| 2343 | + .getElementType(); |
| 2344 | + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); |
| 2345 | + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); |
| 2346 | + Value pred = createLessThan(b, loc, dtype, lhs, rhs); |
| 2347 | + return b.create<arith::SelectOp>(loc, pred, lhs, rhs); |
2342 | 2348 | }
|
2343 | 2349 | if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
|
2344 |
| - if (!maximum.getType() |
2345 |
| - .cast<ValueTensorType>() |
2346 |
| - .getDtype() |
2347 |
| - .isa<mlir::FloatType>()) { |
2348 |
| - maximum.emitError("unimplemented: non-floating point dtype"); |
2349 |
| - return nullptr; |
2350 |
| - } |
2351 |
| - Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT, |
2352 |
| - payloadArgs[0], payloadArgs[1]); |
2353 |
| - return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]); |
| 2350 | + Type dtype = maximum.getType().cast<BaseTensorType>().getDtype(); |
| 2351 | + Type elemTy = converter->convertType(maximum.getType()) |
| 2352 | + .cast<RankedTensorType>() |
| 2353 | + .getElementType(); |
| 2354 | + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); |
| 2355 | + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); |
| 2356 | + Value pred = createGreaterThan(b, loc, dtype, lhs, rhs); |
| 2357 | + return b.create<arith::SelectOp>(loc, pred, lhs, rhs); |
2354 | 2358 | }
|
2355 | 2359 | if (auto clamp = dyn_cast<AtenClampOp>(op)) {
|
2356 | 2360 | Type dtype = converter->convertType(clamp.getType())
|
|
0 commit comments