File tree 5 files changed +60
-6
lines changed
5 files changed +60
-6
lines changed Original file line number Diff line number Diff line change 32
32
CompileSpec ,
33
33
PreprocessResult ,
34
34
)
35
+
36
+ from executorch .exir .passes .memory_format_ops_pass import DimOrderOpsRevertPass
37
+ from executorch .exir .program ._program import _transform
35
38
from torch .export .exported_program import ExportedProgram
36
39
37
40
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -83,6 +86,9 @@ def preprocess(
83
86
# FlatBuffer graph, process the `output` nodes and add their id to
84
87
# the `output_ids` array in the schema.
85
88
89
+ # TODO: Remove this once we have a better support for the dim-order ops.
90
+ edge_program = _transform (edge_program , DimOrderOpsRevertPass ())
91
+
86
92
mps_graph = MPSGraph (
87
93
version = "0" ,
88
94
mps_nodes = [],
Original file line number Diff line number Diff line change @@ -79,6 +79,25 @@ def define_node(
79
79
)
80
80
81
81
82
+ @register_node_visitor
83
+ class ToDimOrderEmptyVisitor (NodeVisitor ):
84
+ target = ["dim_order_ops._empty_dim_order.default" ]
85
+
86
+ def __init__ (self , * args ) -> None :
87
+ super ().__init__ (* args )
88
+
89
+ def define_node (
90
+ self ,
91
+ node : torch .fx .Node ,
92
+ mps_graph : MPSGraph ,
93
+ ) -> None :
94
+ # We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op
95
+ # But if we do, we can't handle it ATM, so raise an exception
96
+ raise NotImplementedError (
97
+ "dim_order_ops._empty_dim_order.default is not supported yet"
98
+ )
99
+
100
+
82
101
@register_node_visitor
83
102
class FullLikeVisitor (NodeVisitor ):
84
103
target = "aten.full_like.default"
Original file line number Diff line number Diff line change @@ -33,3 +33,22 @@ def define_node(
33
33
)
34
34
input_id = self .define_tensor (get_input_node (node , 0 ), mps_graph )
35
35
self .tensor_to_id [node ] = input_id
36
+
37
+
38
+ @register_node_visitor
39
+ class ToDimOrderCopyVisitor (NodeVisitor ):
40
+ target = ["dim_order_ops._to_dim_order_copy.default" ]
41
+
42
+ def __init__ (self , * args ) -> None :
43
+ super ().__init__ (* args )
44
+
45
+ def define_node (
46
+ self ,
47
+ node : torch .fx .Node ,
48
+ mps_graph : MPSGraph ,
49
+ ) -> None :
50
+ # We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op
51
+ # But if we do, we can't handle it ATM, so raise an exception
52
+ raise NotImplementedError (
53
+ "dim_order_ops._to_dim_order_copy.default is not supported yet"
54
+ )
Original file line number Diff line number Diff line change @@ -1829,6 +1829,21 @@ def forward(self, x):
1829
1829
Clone (), model_inputs , func_name = inspect .stack ()[0 ].function [5 :]
1830
1830
)
1831
1831
1832
+ def test_mps_backend_to_copy (self ):
1833
+ class Copy (torch .nn .Module ):
1834
+ def forward (self , x ):
1835
+ return (
1836
+ torch .ops .aten ._to_copy .default (
1837
+ x + 2 , memory_format = torch .contiguous_format
1838
+ )
1839
+ + x
1840
+ )
1841
+
1842
+ model_inputs = (torch .randn (1 , 3 , 3 ),)
1843
+ self .lower_and_test_with_partitioner (
1844
+ Copy (), model_inputs , func_name = inspect .stack ()[0 ].function [5 :]
1845
+ )
1846
+
1832
1847
def test_mps_backend_floor (self ):
1833
1848
class Floor (torch .nn .Module ):
1834
1849
def forward (self , x ):
Original file line number Diff line number Diff line change 26
26
27
27
# Config for Capturing the weights, will be moved in the future
28
28
29
- # TODO(T182928844): Delegate dim order op to backend.
30
- _EDGE_COMPILE_CONFIG = exir .EdgeCompileConfig (
31
- _check_ir_validity = False , _skip_dim_order = True
32
- )
29
+ _EDGE_COMPILE_CONFIG = exir .EdgeCompileConfig (_check_ir_validity = False )
33
30
34
31
35
32
class ansi_colors :
@@ -219,7 +216,6 @@ def lower_module_and_test_output(
219
216
dynamic_shapes = dynamic_shapes ,
220
217
edge_compile_config = EdgeCompileConfig (
221
218
_check_ir_validity = False ,
222
- _skip_dim_order = True , # TODO(T182928844): Delegate dim order op to backend.
223
219
),
224
220
)
225
221
@@ -253,7 +249,6 @@ def lower_module_and_test_output(
253
249
),
254
250
compile_config = exir .EdgeCompileConfig (
255
251
_check_ir_validity = False ,
256
- _skip_dim_order = True , # TODO(T182928844): Delegate dim order op to backend.
257
252
),
258
253
).to_executorch (
259
254
config = ExecutorchBackendConfig (extract_delegate_segments = False )
You can’t perform that action at this time.
0 commit comments