Skip to content

Commit 704efdc

Browse files
authored
[MHLO] add aten::gelu op pattern (llvm#1127)
add aten::gelu op pattern, and moved some unit tests from basic.mlir to elementwise.mlir
1 parent 76c9766 commit 704efdc

File tree

5 files changed

+464
-403
lines changed

5 files changed

+464
-403
lines changed

lib/Conversion/TorchToMhlo/BasicOp.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
2323
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
2424
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
25+
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
2526
#include <iostream>
2627
#include <numeric>
2728

@@ -628,6 +629,35 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
628629

629630
} // namespace
630631

632+
// Convert a Aten::GELU to HLO
633+
// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
634+
namespace {
635+
template <>
636+
LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
637+
AtenGeluOp op,
638+
OpAdaptor adaptor,
639+
ConversionPatternRewriter& rewriter) const {
640+
Location loc = op.getLoc();
641+
Value input = adaptor.self();
642+
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
643+
if (!inputTy) {
644+
return op.emitError("only ranked tensor type is supported.");
645+
}
646+
647+
Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
648+
Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
649+
Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
650+
auto rsqrtTwo = rewriter.create<mlir::mhlo::RsqrtOp>(loc, two);
651+
auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo);
652+
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
653+
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
654+
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
655+
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
656+
return success();
657+
}
658+
} // namespace
659+
660+
631661
// AtenErfOp
632662
namespace {
633663
template <>
@@ -984,6 +1014,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
9841014
INSERT_ATENOP_PATTERN(AtenContiguousOp);
9851015

9861016
INSERT_ATENOP_PATTERN(AtenReluOp);
1017+
INSERT_ATENOP_PATTERN(AtenGeluOp);
9871018
INSERT_ATENOP_PATTERN(AtenErfOp);
9881019

9891020
INSERT_ATENOP_PATTERN(AtenBatchNormOp);

lib/Conversion/TorchToMhlo/PopulatePatterns.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
2525
void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
2626
RewritePatternSet &patterns,
2727
ConversionTarget &target);
28+
2829
} // namespace torch_to_mhlo
2930
} // namespace torch
3031
} // namespace mlir

lib/Conversion/TorchToMhlo/TorchToMhlo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
2424
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
2525
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
26+
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
2627

2728
using namespace mlir;
2829
using namespace mlir::torch;

0 commit comments

Comments
 (0)