Skip to content

Commit 22fa250

Browse files
committed
Add dim_order compat support
Differential Revision: D67542995
1 parent fb1cc93 commit 22fa250

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

backends/apple/mps/mps_preprocess.py

+5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
CompileSpec,
3333
PreprocessResult,
3434
)
35+
36+
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
3537
from torch.export.exported_program import ExportedProgram
3638

3739
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -83,6 +85,9 @@ def preprocess(
8385
# FlatBuffer graph, process the `output` nodes and add their id to
8486
# the `output_ids` array in the schema.
8587

88+
# TODO: Remove this once we have a better support for the dim-order ops.
89+
edge_program = DimOrderOpsRevertPass()(edge_program)
90+
8691
mps_graph = MPSGraph(
8792
version="0",
8893
mps_nodes=[],

backends/apple/mps/operators/constant_ops.py

+9
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,15 @@ def define_node(
7878
)
7979
)
8080

81+
@register_node_visitor
82+
class ToDimOrderEmptyVisitor(NodeVisitor):
83+
target = ["exir_ops.edge.dim_order_ops._to_dim_order_copy.default"]
84+
85+
def __init__(self, *args) -> None:
86+
# We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op
87+
# But if we do, we can't handle it ATM, so raise an exception
88+
raise NotImplementedError("exir_ops.edge.dim_order_ops._to_dim_order_copy.default is not supported yet")
89+
8190

8291
@register_node_visitor
8392
class FullLikeVisitor(NodeVisitor):

backends/apple/mps/operators/op_clone.py

+9
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,12 @@ def define_node(
3333
)
3434
input_id = self.define_tensor(get_input_node(node, 0), mps_graph)
3535
self.tensor_to_id[node] = input_id
36+
37+
@register_node_visitor
38+
class ToDimOrderCopyVisitor(NodeVisitor):
39+
target = ["exir_ops.edge.dim_order_ops._to_dim_order_copy.default"]
40+
41+
def __init__(self, *args) -> None:
42+
# We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op
43+
# But if we do, we can't handle it ATM, so raise an exception
44+
raise NotImplementedError("exir_ops.edge.dim_order_ops._to_dim_order_copy.default is not supported yet")

0 commit comments

Comments
 (0)