@@ -8,10 +8,17 @@ namespace passes {

void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string gelu_pattern = R"IR(
-    graph(%x):
+    graph(%x : Tensor):
        %out : Tensor = aten::gelu(%x)
        return (%out))IR";

+  // This gelu_approximate_pattern schema exists in the 21.11, 21.12, and 22.01 PyTorch containers. These container
+  // versions use an unmerged PR in pytorch: https://github.com/pytorch/pytorch/pull/61439. We reduce this to regular GELU.
+  std::string gelu_approximate_pattern = R"IR(
+    graph(%x : Tensor, %approx):
+        %out : Tensor = aten::gelu(%x, %approx)
+        return (%out))IR";
+
  std::string gelu_reduce_pattern = R"IR(
    graph(%x.1 : Tensor):
        %6 : float = prim::Constant[value=0.044714999999999998]()
@@ -30,11 +37,36 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
        %15 : Tensor = aten::mul(%7, %14)
        return (%15))IR";

+  // This is the same as gelu_reduce_pattern except for an additional input %approx.
+  // SubgraphRewriter only works as expected if the number of inputs to gelu_approximate_pattern
+  // and gelu_reduce_multi_input_pattern are the same.
+  std::string gelu_reduce_multi_input_pattern = R"IR(
+    graph(%x.1 : Tensor, %approx):
+        %6 : float = prim::Constant[value=0.044714999999999998]()
+        %5 : float = prim::Constant[value=0.79788456080000003]()
+        %4 : float = prim::Constant[value=1.]()
+        %3 : float = prim::Constant[value=0.5]()
+        %2 : int = prim::Constant[value=1]()
+        %7 : Tensor = aten::mul(%x.1, %3)
+        %8 : Tensor = aten::mul(%x.1, %5)
+        %9 : Tensor = aten::mul(%x.1, %6)
+        %10 : Tensor = aten::mul(%9, %x.1)
+        %11 : Tensor = aten::add(%10, %4, %2)
+        %12 : Tensor = aten::mul(%8, %11)
+        %13 : Tensor = aten::tanh(%12)
+        %14 : Tensor = aten::add(%13, %4, %2)
+        %15 : Tensor = aten::mul(%7, %14)
+        return (%15))IR";
+
  // replace aten::gelu with pointwise operations
  torch::jit::SubgraphRewriter map_gelu_to_pointwise_ops;
  map_gelu_to_pointwise_ops.RegisterRewritePattern(gelu_pattern, gelu_reduce_pattern);
  map_gelu_to_pointwise_ops.runOnGraph(graph);

+  torch::jit::SubgraphRewriter map_gelu_approximate_to_pointwise_ops;
+  map_gelu_approximate_to_pointwise_ops.RegisterRewritePattern(gelu_approximate_pattern, gelu_reduce_multi_input_pattern);
+  map_gelu_approximate_to_pointwise_ops.runOnGraph(graph);
+
  LOG_GRAPH("Post lowering of [aten::gelu] -> " << *graph);
}
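
For reference, the pointwise chain above (%7 through %15) computes the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), with sqrt(2/pi) ~= 0.7978845608 matching the %5 constant. A minimal standalone sanity check of that correspondence (plain C++, no LibTorch required; the function names are illustrative, not from the patch):

#include <cassert>
#include <cmath>

// Closed-form tanh approximation of GELU.
double gelu_tanh_closed_form(double x) {
  return 0.5 * x * (1.0 + std::tanh(0.7978845608028654 * (x + 0.044715 * x * x * x)));
}

// The same value computed step by step, mirroring %7..%15 in gelu_reduce_pattern.
double gelu_pointwise(double x) {
  double v7  = x * 0.5;                   // %7  = aten::mul(%x.1, %3)
  double v8  = x * 0.79788456080000003;   // %8  = aten::mul(%x.1, %5)
  double v9  = x * 0.044714999999999998;  // %9  = aten::mul(%x.1, %6)
  double v10 = v9 * x;                    // %10 = aten::mul(%9, %x.1)
  double v11 = v10 + 1.0;                 // %11 = aten::add(%10, %4, %2)
  double v12 = v8 * v11;                  // %12 = aten::mul(%8, %11)
  double v13 = std::tanh(v12);            // %13 = aten::tanh(%12)
  double v14 = v13 + 1.0;                 // %14 = aten::add(%13, %4, %2)
  return v7 * v14;                        // %15 = aten::mul(%7, %14)
}

int main() {
  for (double x = -5.0; x <= 5.0; x += 0.25) {
    assert(std::abs(gelu_tanh_closed_form(x) - gelu_pointwise(x)) < 1e-9);
  }
  return 0;
}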
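And a hedged usage sketch: parse a small graph containing aten::gelu and run the pass over it. This assumes a LibTorch build and linkage against the object file from this patch; the bare passes namespace is taken from the hunk header and may need the repository's fully qualified namespace instead.

#include <memory>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

// Declaration of the pass from this patch; the enclosing namespace is an
// assumption based on the hunk header.
namespace passes {
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
}

int main() {
  auto graph = std::make_shared<torch::jit::Graph>();
  // Build a graph that matches gelu_pattern.
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor):
        %out : Tensor = aten::gelu(%x)
        return (%out))IR",
      graph.get());

  passes::ReduceGelu(graph);
  graph->dump(); // the aten::gelu node should now be a mul/add/tanh chain
  return 0;
}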