|
22 | 22 | #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
23 | 23 | #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
24 | 24 | #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
| 25 | +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" |
25 | 26 | #include <iostream>
|
26 | 27 | #include <numeric>
|
27 | 28 |
|
@@ -628,6 +629,35 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
628 | 629 |
|
629 | 630 | } // namespace
|
630 | 631 |
|
| 632 | +// Convert a Aten::GELU to HLO |
| 633 | +// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] |
| 634 | +namespace { |
| 635 | +template <> |
| 636 | +LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite( |
| 637 | + AtenGeluOp op, |
| 638 | + OpAdaptor adaptor, |
| 639 | + ConversionPatternRewriter& rewriter) const { |
| 640 | + Location loc = op.getLoc(); |
| 641 | + Value input = adaptor.self(); |
| 642 | + auto inputTy = input.getType().template dyn_cast<RankedTensorType>(); |
| 643 | + if (!inputTy) { |
| 644 | + return op.emitError("only ranked tensor type is supported."); |
| 645 | + } |
| 646 | + |
| 647 | + Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); |
| 648 | + Value two = chlo::getConstantLike(rewriter, loc, 2.0, input); |
| 649 | + Value half = chlo::getConstantLike(rewriter, loc, 0.5, input); |
| 650 | + auto rsqrtTwo = rewriter.create<mlir::mhlo::RsqrtOp>(loc, two); |
| 651 | + auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo); |
| 652 | + auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement); |
| 653 | + auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one); |
| 654 | + auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half); |
| 655 | + rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul); |
| 656 | + return success(); |
| 657 | +} |
| 658 | +} // namespace |
| 659 | + |
| 660 | + |
631 | 661 | // AtenErfOp
|
632 | 662 | namespace {
|
633 | 663 | template <>
|
@@ -984,6 +1014,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
|
984 | 1014 | INSERT_ATENOP_PATTERN(AtenContiguousOp);
|
985 | 1015 |
|
986 | 1016 | INSERT_ATENOP_PATTERN(AtenReluOp);
|
| 1017 | + INSERT_ATENOP_PATTERN(AtenGeluOp); |
987 | 1018 | INSERT_ATENOP_PATTERN(AtenErfOp);
|
988 | 1019 |
|
989 | 1020 | INSERT_ATENOP_PATTERN(AtenBatchNormOp);
|
|
0 commit comments