|
7 | 7 |
|
8 | 8 | # pyre-unsafe
|
9 | 9 |
|
10 |
| -import torch |
11 | 10 | from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
|
12 | 11 | AnnotateChannelsLastDimOrder,
|
13 | 12 | )
|
|
47 | 46 | )
|
48 | 47 | from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
|
49 | 48 | from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
|
50 |
| - ConvertMeanDimToAveragePool, |
| 49 | + ConvertMeanDimToAveragePoolPass, |
51 | 50 | )
|
52 | 51 | from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
|
53 | 52 | from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
|
|
61 | 60 | from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
|
62 | 61 | UnsqueezeScalarPlaceholdersPass,
|
63 | 62 | )
|
| 63 | +from executorch.backends.arm.tosa_specification import TosaSpecification |
64 | 64 | from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
|
65 | 65 | from executorch.exir import ExportedProgram
|
66 |
| -from executorch.exir.dialects._ops import ops as exir_ops |
67 | 66 | from executorch.exir.pass_manager import PassManager
|
| 67 | +from torch.fx import GraphModule |
68 | 68 |
|
69 | 69 |
|
70 | 70 | class ArmPassManager(PassManager):
|
71 | 71 |
|
72 |
| - def _transform(self, graph_module: torch.fx.GraphModule): |
| 72 | + def __init__(self, tosa_spec: TosaSpecification) -> None: |
| 73 | + self.tosa_spec = tosa_spec |
| 74 | + super().__init__() |
| 75 | + |
| 76 | + def _transform(self, graph_module: GraphModule): |
73 | 77 | return self(graph_module).graph_module
|
74 | 78 |
|
75 |
| - def transform_to_backend_pipeline(self, exported_program: ExportedProgram): |
76 |
| - """Apply passes before transforming program to backend""" |
| 79 | + def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
77 | 80 | self.add_pass(FuseQuantizedActivationPass())
|
| 81 | + self.add_pass(RemoveGetItemPass()) |
| 82 | + self.add_pass(ConvertSplitToSlicePass()) |
| 83 | + self.add_pass(ConvertMmToBmmPass()) |
78 | 84 | self.add_pass(DecomposeLinearPass())
|
| 85 | + self.add_pass(ConvertMeanDimToAveragePoolPass()) |
| 86 | + |
| 87 | + self.add_pass(AnnotateDecomposedMatmulPass()) |
| 88 | + self.add_pass(QuantizeFullArgument()) |
| 89 | + self.add_pass(FoldAndAnnotateQParamsPass()) |
| 90 | + self.add_pass(RetraceFoldedDtypesPass()) |
| 91 | + self.add_pass(InsertTableOpsPass(exported_program)) |
| 92 | + |
| 93 | + self.add_pass(RemoveClonePass()) |
| 94 | + self.add_pass(SizeAdjustConv2DPass()) |
| 95 | + self.add_pass(ConvertExpandCopyToRepeatPass()) |
| 96 | + self.add_pass(UnsqueezeBeforeRepeatPass()) |
| 97 | + self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program)) |
| 98 | + self.add_pass(CastInt64ToInt32Pass(exported_program)) |
| 99 | + self.add_pass(MatchArgRanksPass(exported_program)) |
| 100 | + self.add_pass(KeepDimsFalseToSqueezePass()) |
| 101 | + self.add_pass(Conv1dUnsqueezePass(exported_program)) |
| 102 | + self.add_pass(DecomposeSelectPass()) |
| 103 | + |
| 104 | + self.add_pass(AnnotateChannelsLastDimOrder()) |
| 105 | + |
| 106 | + return self._transform(exported_program.graph_module) |
| 107 | + |
| 108 | + def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: |
| 109 | + |
| 110 | + self.add_pass(FuseQuantizedActivationPass()) |
79 | 111 | self.add_pass(RemoveGetItemPass())
|
| 112 | + self.add_pass(ConvertSplitToSlicePass()) |
| 113 | + self.add_pass(ConvertMmToBmmPass()) |
| 114 | + self.add_pass(DecomposeLinearPass()) |
80 | 115 | self.add_pass(DecomposeLayerNormPass())
|
81 | 116 | self.add_pass(DecomposeVarPass())
|
82 |
| - self.add_pass(ConvertMeanDimToAveragePool()) |
83 | 117 | self.add_pass(DecomposeMeanDimPass())
|
84 |
| - self.add_pass(ConvertSplitToSlicePass()) |
85 |
| - self.add_pass(ConvertMmToBmmPass()) |
86 |
| - # TODO MLETORCH-558 |
| 118 | + self.add_pass(ConvertMeanDimToAveragePoolPass()) |
| 119 | + self.add_pass(DecomposeDivPass()) |
| 120 | + self.add_pass(DecomposeSoftmaxesPass()) |
| 121 | + |
87 | 122 | self.add_pass(AnnotateDecomposedMatmulPass())
|
88 | 123 | self.add_pass(QuantizeFullArgument())
|
89 |
| - self.add_pass( |
90 |
| - FoldAndAnnotateQParamsPass( |
91 |
| - [ |
92 |
| - exir_ops.edge.aten.minimum.default, |
93 |
| - exir_ops.edge.aten.maximum.default, |
94 |
| - exir_ops.edge.aten.add.Tensor, |
95 |
| - exir_ops.edge.aten.avg_pool2d.default, |
96 |
| - exir_ops.edge.aten.bmm.default, |
97 |
| - exir_ops.edge.aten.cat.default, |
98 |
| - exir_ops.edge.aten.convolution.default, |
99 |
| - exir_ops.edge.aten.clone.default, |
100 |
| - exir_ops.edge.aten.exp.default, |
101 |
| - exir_ops.edge.aten.expand_copy.default, |
102 |
| - exir_ops.edge.aten.full.default, |
103 |
| - exir_ops.edge.aten.hardtanh.default, |
104 |
| - exir_ops.edge.aten.log.default, |
105 |
| - exir_ops.edge.aten.max_pool2d.default, |
106 |
| - exir_ops.edge.aten.mul.Tensor, |
107 |
| - exir_ops.edge.aten.permute_copy.default, |
108 |
| - exir_ops.edge.aten.reciprocal.default, |
109 |
| - exir_ops.edge.aten.relu.default, |
110 |
| - exir_ops.edge.aten.repeat.default, |
111 |
| - exir_ops.edge.aten.rsqrt.default, |
112 |
| - exir_ops.edge.aten.select_copy.int, |
113 |
| - exir_ops.edge.aten.sigmoid.default, |
114 |
| - exir_ops.edge.aten.slice_copy.Tensor, |
115 |
| - exir_ops.edge.aten.squeeze_copy.dims, |
116 |
| - exir_ops.edge.aten.sub.Tensor, |
117 |
| - exir_ops.edge.aten.sum.dim_IntList, |
118 |
| - exir_ops.edge.aten.tanh.default, |
119 |
| - exir_ops.edge.aten.unsqueeze_copy.default, |
120 |
| - exir_ops.edge.aten.upsample_nearest2d.vec, |
121 |
| - exir_ops.edge.aten.view_copy.default, |
122 |
| - ] |
123 |
| - ) |
124 |
| - ) |
| 124 | + self.add_pass(FoldAndAnnotateQParamsPass()) |
125 | 125 | self.add_pass(RetraceFoldedDtypesPass())
|
126 | 126 | self.add_pass(InsertTableOpsPass(exported_program))
|
| 127 | + |
| 128 | + self.add_pass(RemoveClonePass()) |
| 129 | + self.add_pass(SizeAdjustConv2DPass()) |
127 | 130 | self.add_pass(ConvertExpandCopyToRepeatPass())
|
128 | 131 | self.add_pass(UnsqueezeBeforeRepeatPass())
|
129 |
| - self.add_pass(CastInt64ToInt32Pass(exported_program)) |
130 | 132 | self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
|
131 |
| - self.add_pass(SizeAdjustConv2DPass()) |
132 |
| - self.add_pass(RemoveClonePass()) |
| 133 | + self.add_pass(CastInt64ToInt32Pass(exported_program)) |
133 | 134 | self.add_pass(MatchArgRanksPass(exported_program))
|
134 |
| - self.add_pass(DecomposeDivPass()) |
135 | 135 | self.add_pass(KeepDimsFalseToSqueezePass())
|
136 | 136 | self.add_pass(Conv1dUnsqueezePass(exported_program))
|
137 |
| - self.add_pass(DecomposeSoftmaxesPass()) |
138 | 137 | self.add_pass(DecomposeSelectPass())
|
| 138 | + |
139 | 139 | self.add_pass(AnnotateChannelsLastDimOrder())
|
140 | 140 |
|
141 | 141 | return self._transform(exported_program.graph_module)
|
142 | 142 |
|
143 |
| - def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule): |
| 143 | + def transform_to_backend_pipeline(self, exported_program: ExportedProgram): |
| 144 | + """Apply passes before transforming program to backend""" |
| 145 | + if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"): |
| 146 | + return self._tosa_080_BI_pipeline(exported_program) |
| 147 | + elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"): |
| 148 | + return self._tosa_080_MI_pipeline(exported_program) |
| 149 | + else: |
| 150 | + raise NotImplementedError( |
| 151 | + f"No pass pipeline implemented for {self.tosa_spec=}" |
| 152 | + ) |
| 153 | + |
| 154 | + def transform_for_annotation_pipeline(self, graph_module: GraphModule): |
144 | 155 | self.add_pass(ScalarsToAttributePass())
|
145 | 156 | self.add_pass(DecomposeLayerNormPass())
|
146 | 157 | self.add_pass(DecomposeVarPass())
|
|
0 commit comments