|
26 | 26 | from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
|
27 | 27 | from executorch.exir.error import ExportError
|
28 | 28 | from executorch.exir.graph_module import get_control_flow_submodules
|
| 29 | +from executorch.exir.operator.convert import _pybind_schema_to_native_schema |
29 | 30 | from executorch.exir.pass_base import PassBase
|
30 | 31 | from executorch.exir.pass_manager import PassType
|
31 | 32 | from executorch.exir.passes import (
|
@@ -836,6 +837,9 @@ def _replace_aten_ops_with_transformed_ops(
|
836 | 837 | ops_set_to_not_decompose, check_op_support = partitioner.ops_to_not_decompose(
|
837 | 838 | program
|
838 | 839 | )
|
| 840 | + ops_set_to_not_decompose = _remove_invalid_ops_for_not_decompose( |
| 841 | + ops_set_to_not_decompose |
| 842 | + ) |
839 | 843 |
|
840 | 844 | for op_aten in ops_set_to_not_decompose:
|
841 | 845 | _register_no_decomp_op(op_aten)
|
@@ -965,6 +969,47 @@ def _sanity_check_graph_for_non_decomp_ops(
|
965 | 969 | logging.warning(warning_str)
|
966 | 970 |
|
967 | 971 |
|
| 972 | +def _remove_invalid_ops_for_not_decompose( |
| 973 | + ops_to_not_decompose: List[torch._ops.OpOverload], |
| 974 | +) -> List[torch._ops.OpOverload]: |
| 975 | + # To address https://github.com/pytorch/executorch/issues/8781 |
| 976 | + def keep(op): |
| 977 | + schema = op._schema |
| 978 | + native_schema = _pybind_schema_to_native_schema(schema) |
| 979 | + if native_schema.is_mutable: |
| 980 | + logging.warn( |
| 981 | + f"Op {op} was requested for preservation by partitioner. This request is ignored because it is mutable." |
| 982 | + ) |
| 983 | + return False |
| 984 | + |
| 985 | + if native_schema.aliased_return_names() != [None]: |
| 986 | + logging.warn( |
| 987 | + f"Op {op} was requested for preservation by partitioner. This request is ignored because it aliases output." |
| 988 | + ) |
| 989 | + return False |
| 990 | + |
| 991 | + # Explicit block list of ops that don't work if asked for |
| 992 | + # preservation |
| 993 | + if op in [ |
| 994 | + # Hits infinte recursion error when op is in |
| 995 | + # EDGE_DO_NOT_DECOMP namespace |
| 996 | + torch.ops.aten._to_copy.default, |
| 997 | + # scalar to tensor type promotion does not work on ops |
| 998 | + # in EDGE_DO_NOT_DECOMP namespace |
| 999 | + torch.ops.aten.mul.Tensor, |
| 1000 | + torch.ops.aten.add.Tensor, |
| 1001 | + torch.ops.aten.sub.Tensor, |
| 1002 | + torch.ops.aten.div.Tensor, |
| 1003 | + ]: |
| 1004 | + logging.warn( |
| 1005 | + f"Op {op} was requested for preservation by partitioner. This request is ignored because it is in a blocklist." |
| 1006 | + ) |
| 1007 | + return False |
| 1008 | + return True |
| 1009 | + |
| 1010 | + return list(filter(keep, ops_to_not_decompose)) |
| 1011 | + |
| 1012 | + |
968 | 1013 | def _gen_edge_manager_for_partitioners(
|
969 | 1014 | partitioner: Dict[str, List[Partitioner]],
|
970 | 1015 | aten_programs: Dict[str, ExportedProgram],
|
@@ -992,6 +1037,9 @@ def _gen_edge_manager_for_partitioners(
|
992 | 1037 | all_ops_no_decomp = set()
|
993 | 1038 | for curr_partitioner in partitioner.get(name, []):
|
994 | 1039 | curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
|
| 1040 | + curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( |
| 1041 | + curr_ops_no_decomp |
| 1042 | + ) |
995 | 1043 | all_ops_no_decomp |= set(curr_ops_no_decomp)
|
996 | 1044 |
|
997 | 1045 | table = _default_decomposition_table()
|
@@ -1113,6 +1161,7 @@ def to_edge_transform_and_lower(
|
1113 | 1161 | curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
|
1114 | 1162 | program
|
1115 | 1163 | )
|
| 1164 | + curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set) |
1116 | 1165 | ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
|
1117 | 1166 | _sanity_check_graph_for_non_decomp_ops(
|
1118 | 1167 | name,
|
|
0 commit comments