
Commit 781b082

Fixes to_edge_transform_and_lower when unsupported ops are asked for preservation (#8776)
* init
* up
* up
* up
1 parent 7e0a446 commit 781b082

3 files changed: +78 −6 lines


backends/apple/coreml/partition/coreml_partitioner.py

+11-5
@@ -111,10 +111,16 @@ def ops_to_not_decompose(
         do_not_decompose = []
         op_support = OperatorsSupportedForCoreMLBackend()
         for node in ep.graph.nodes:
-            if (
-                node.op == "call_function"
-                and isinstance(node.target, torch._ops.OpOverload)
-                and op_support.is_node_supported(None, node)
+            if node.op == "call_function" and isinstance(
+                node.target, torch._ops.OpOverload
             ):
-                do_not_decompose.append(node.target)
+                try:
+                    if op_support.is_node_supported(None, node):
+                        do_not_decompose.append(node.target)
+                except Exception as e:
+                    # CoreML's op_support.is_node_supported will sometimes throw
+                    # for unsupported ops, rather than returning False
+                    logger.warning(
+                        f"Encountered exception when checking node support: {e}"
+                    )
         return do_not_decompose, None
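For context, a minimal sketch of how the patched ops_to_not_decompose path is exercised end to end. The CoreMLPartitioner import path, the toy module, and the example inputs are assumptions for illustration only (the CoreML backend also needs coremltools installed); none of this comes from the commit itself.

# Hypothetical usage sketch, not part of the commit.
import torch

from executorch.backends.apple.coreml.partition.coreml_partitioner import (
    CoreMLPartitioner,
)


class AddModel(torch.nn.Module):
    def forward(self, x, y):
        return x + y


ep = torch.export.export(AddModel().eval(), (torch.randn(2), torch.randn(2)))

# With this change, a node whose support check raises is skipped with a
# warning instead of aborting the whole call.
do_not_decompose, _ = CoreMLPartitioner().ops_to_not_decompose(ep)
print(do_not_decompose)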

backends/apple/coreml/test/test_coreml_partitioner.py

+18-1
@@ -82,11 +82,28 @@ def test_vit_skip_conv(self):
 
     def test_ops_to_not_decompose(self):
         class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+
             def forward(self, q, k, v, mask):
-                return torch.ops.aten.scaled_dot_product_attention.default(
+                out = torch.ops.aten.scaled_dot_product_attention.default(
                     q, k, v, attn_mask=mask
                 )
 
+                # Add non-functional and alias ops
+                # These will be removed by ExecuTorch in non-decomposition
+                # table because they cannot be functionalized
+                out = out.transpose(1, 2)
+                out = out.view(1, -1)
+                out = out.permute(0, 1)
+                out = out.add_(1.0)
+                out = out.mul_(2.0)
+                out = out.div_(3.0)
+                out = out.sub_(4.0)
+                out = torch.ops.aten.view_copy.default(out, (-1,))
+                out = out.select(0, 0)
+                return out
+
         model = Model()
         model.eval()
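The hunk stops before the rest of the test body, so the lowering step is not shown. Below is a rough, self-contained sketch of how a module like Model is typically exported and lowered; the shapes, the partitioner argument, and the to_edge_transform_and_lower call are assumptions for illustration, not lines from this commit.

# Hypothetical continuation sketch; the commit's actual test body is not
# shown in the hunk above.
import torch

from executorch.backends.apple.coreml.partition.coreml_partitioner import (
    CoreMLPartitioner,
)
from executorch.exir import to_edge_transform_and_lower


class SDPAModel(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return torch.ops.aten.scaled_dot_product_attention.default(
            q, k, v, attn_mask=mask
        )


q = k = v = torch.randn(1, 1, 32, 64)
mask = torch.zeros(1, 1, 32, 32)
ep = torch.export.export(SDPAModel().eval(), (q, k, v, mask))

# The in-place and alias ops added in the test above are expected to be
# dropped from the non-decomposition table rather than break lowering.
edge = to_edge_transform_and_lower(ep, partitioner=[CoreMLPartitioner()])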

exir/program/_program.py

+49
@@ -26,6 +26,7 @@
 from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
 from executorch.exir.error import ExportError
 from executorch.exir.graph_module import get_control_flow_submodules
+from executorch.exir.operator.convert import _pybind_schema_to_native_schema
 from executorch.exir.pass_base import PassBase
 from executorch.exir.pass_manager import PassType
 from executorch.exir.passes import (
@@ -836,6 +837,9 @@ def _replace_aten_ops_with_transformed_ops(
         ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
             program
         )
+        ops_set_to_not_decompose = _remove_invalid_ops_for_not_decompose(
+            ops_set_to_not_decompose
+        )
 
         for op_aten in ops_set_to_not_decompose:
             _register_no_decomp_op(op_aten)
@@ -965,6 +969,47 @@ def _sanity_check_graph_for_non_decomp_ops(
         logging.warning(warning_str)
 
 
+def _remove_invalid_ops_for_not_decompose(
+    ops_to_not_decompose: List[torch._ops.OpOverload],
+) -> List[torch._ops.OpOverload]:
+    # To address https://github.com/pytorch/executorch/issues/8781
+    def keep(op):
+        schema = op._schema
+        native_schema = _pybind_schema_to_native_schema(schema)
+        if native_schema.is_mutable:
+            logging.warn(
+                f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable."
+            )
+            return False
+
+        if native_schema.aliased_return_names() != [None]:
+            logging.warn(
+                f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output."
+            )
+            return False
+
+        # Explicit block list of ops that don't work if asked for
+        # preservation
+        if op in [
+            # Hits infinite recursion error when op is in
+            # EDGE_DO_NOT_DECOMP namespace
+            torch.ops.aten._to_copy.default,
+            # scalar to tensor type promotion does not work on ops
+            # in EDGE_DO_NOT_DECOMP namespace
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.sub.Tensor,
+            torch.ops.aten.div.Tensor,
+        ]:
+            logging.warn(
+                f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist."
+            )
+            return False
+        return True
+
+    return list(filter(keep, ops_to_not_decompose))
+
+
 def _gen_edge_manager_for_partitioners(
     partitioner: Dict[str, List[Partitioner]],
     aten_programs: Dict[str, ExportedProgram],
@@ -992,6 +1037,9 @@ def _gen_edge_manager_for_partitioners(
             all_ops_no_decomp = set()
             for curr_partitioner in partitioner.get(name, []):
                 curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
+                curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
+                    curr_ops_no_decomp
+                )
                 all_ops_no_decomp |= set(curr_ops_no_decomp)
 
             table = _default_decomposition_table()
@@ -1113,6 +1161,7 @@ def to_edge_transform_and_lower(
                 curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
                     program
                 )
+                curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
                 ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
                 _sanity_check_graph_for_non_decomp_ops(
                     name,
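To make the filtering behavior concrete, here is a minimal usage sketch of the new helper. _remove_invalid_ops_for_not_decompose is private, so importing it directly from executorch.exir.program._program is an assumption made purely for illustration.

# Hypothetical usage sketch: which requested ops survive the new filter.
import torch

from executorch.exir.program._program import _remove_invalid_ops_for_not_decompose

requested = [
    torch.ops.aten.scaled_dot_product_attention.default,  # functional, kept
    torch.ops.aten.add_.Tensor,  # mutable schema -> dropped
    torch.ops.aten.view.default,  # output aliases input -> dropped
    torch.ops.aten.mul.Tensor,  # on the explicit blocklist -> dropped
]

kept = _remove_invalid_ops_for_not_decompose(requested)
assert kept == [torch.ops.aten.scaled_dot_product_attention.default]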
