
Commit 91c9417

[ExecuTorch][to_backend] add AllNodePartitioner (#9822)
### Summary
We add an AllNodePartitioner as a canonical partitioner. It unconditionally delegates every node in the graph to a single backend, and is initialized with the backend id and the compile specs to use.

### Testing
Unit tests.
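As a usage sketch (not part of this diff, and mirroring the tests added below; `BackendWithCompilerDemo` is the demo backend from the ExecuTorch test suite), lowering a whole module with the new partitioner might look like this:

```python
import torch

from executorch.exir import to_edge
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
    AllNodePartitioner,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec


class SinModule(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)


model_inputs = (torch.ones(1),)

# Delegate every node in the exported graph to the named backend.
partitioner = AllNodePartitioner(
    "BackendWithCompilerDemo",
    [CompileSpec("max_value", bytes([model_inputs[0].shape[0]]))],
)

edge_prog = to_edge(torch.export.export(SinModule(), model_inputs))
exec_prog = edge_prog.to_backend(partitioner).to_executorch()
```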
1 parent d6e14fc commit 91c9417

File tree

6 files changed: +300 −0 lines changed


exir/backend/canonical_partitioners/TARGETS

+1
@@ -7,6 +7,7 @@ runtime.python_library(
     srcs = [
         "duplicate_dequant_node_pass.py",
         "pattern_op_partitioner.py",
+        "all_node_partitioner.py",
     ],
     visibility = [
         "//executorch/...",

exir/backend/canonical_partitioners/all_node_partitioner.py

+55

@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict, List
+
+import torch
+from executorch.exir.backend.backend_details import ExportedProgram
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.partitioner import (
+    DelegationSpec,
+    Partitioner,
+    PartitionResult,
+)
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
+
+
+def is_non_tensor_placeholder(node: torch.fx.Node, ep: ExportedProgram) -> bool:
+    """
+    Returns true if the node is a placeholder node and it is not a tensor
+    """
+    return node.op == "placeholder" and not (
+        is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
+    )
+
+
+class AllNodePartitioner(Partitioner):
+    def __init__(
+        self,
+        backend_id: str,
+        compile_specs: List[CompileSpec],
+    ):
+        """
+        Partitioner that lowers every single node in the graph module unconditionally
+        to the specified backend_id
+        """
+        super().__init__()
+        self.delegation_spec = DelegationSpec(backend_id, compile_specs)
+
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+        # tag all nodes
+        partition_tags: Dict[str, DelegationSpec] = {}
+        for node in exported_program.graph_module.graph.nodes:
+            if is_non_tensor_placeholder(node, exported_program) or node.op == "output":
+                continue
+
+            delegation_tag = self.delegation_spec.backend_id
+            node.meta["delegation_tag"] = delegation_tag
+            partition_tags[delegation_tag] = self.delegation_spec
+
+        return PartitionResult(
+            tagged_exported_program=exported_program, partition_tags=partition_tags
+        )
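For illustration, a minimal sketch of calling the partitioner directly and inspecting the tags it writes into `node.meta` (the `AddOne` module is a made-up example; `"BackendWithCompilerDemo"` is the demo backend used by the tests below):

```python
import torch

from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
    AllNodePartitioner,
)


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1


ep = torch.export.export(AddOne(), (torch.ones(2),))
result = AllNodePartitioner("BackendWithCompilerDemo", []).partition(ep)

# Every node except non-tensor placeholders and the output node carries the tag.
for node in result.tagged_exported_program.graph_module.graph.nodes:
    print(node.op, node.name, node.meta.get("delegation_tag"))

# partition_tags maps each tag back to its DelegationSpec (backend id + compile specs).
print(result.partition_tags)
```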

exir/backend/test/test_backends.py

+179
@@ -10,7 +10,11 @@
 
 import executorch.exir as exir
 import torch
+from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -1266,3 +1270,178 @@ def forward(self, x: List[torch.Tensor]):
 
         gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
         gm(*inputs)
+
+    def test_to_backend_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return [torch.sin(x)]
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        max_value = model_inputs[0].shape[0]
+
+        partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))]
+        )
+
+        edgeir_m = to_edge(torch.export.export(sin_module, model_inputs))
+        edgeir_m = edgeir_m.to_backend(partitioner)
+        exec_prog = edgeir_m.to_executorch()
+        graph_module = exec_prog.exported_program().graph_module
+        # Check that there is not an aten.sin node.
+        self.assertTrue(
+            exir_ops.edge.aten.sin
+            not in {node.target for node in graph_module.graph.nodes}
+        )
+
+        # Check that there exists a call_delegate, representing the call to the
+        # delegated function
+        FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+            graph_module.code
+        )
+        lowered_submodules = get_lowered_submodules(graph_module)
+        self.assertEqual(len(lowered_submodules), 1)
+
+        for node in graph_module.graph.nodes:
+            if node.op == "call_function" and node.target == executorch_call_delegate:
+                # Check that first arg is lowered_module_{unique_id}
+                self.assertEqual(node.args[0].target, "lowered_module_0")
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        # Check the backend delegate
+        self.check_backend_delegate(
+            program=program,
+            delegate=program.execution_plan[0].delegates[0],
+            expected_id=BackendWithCompilerDemo.__name__,
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
+        )
+
+        # Check the delegate instruction
+        self.assertTrue(
+            isinstance(
+                program.execution_plan[0].chains[0].instructions[0].instr_args,
+                DelegateCall,
+            )
+        )
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        model_outputs = executorch_module.forward([model_inputs])
+        self.assertEqual(
+            model_inputs,
+            torch.ones(1),
+        )
+        expected_output = 0.8333 * torch.ones(1)
+
+        self.assertTrue(
+            torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03)
+        )
+
+    def test_to_backend_multimethod_delegation_spec(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+            def inputs(self):
+                return (torch.ones(1),)
+
+        class AddMulModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, a, x, b):
+                y = torch.mm(a, x)
+                z = torch.add(y, b)
+                return z
+
+            def inputs(self):
+                return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))
+
+        sin_module = SinModule()
+        max_value_sin = sin_module.inputs()[0].shape[0]
+        sin_partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo",
+            [CompileSpec("max_value", bytes([max_value_sin]))],
+        )
+
+        add_mul_module = AddMulModule()
+        max_value_add_mul = add_mul_module.inputs()[0].shape[0]
+        add_mul_partitioner = AllNodePartitioner(
+            "BackendWithCompilerDemo",
+            [CompileSpec("max_value", bytes([max_value_add_mul]))],
+        )
+
+        edgeir_m = to_edge(
+            {
+                "sin": torch.export.export(sin_module, sin_module.inputs()),
+                "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()),
+            }
+        )
+        edgeir_m = edgeir_m.to_backend(
+            {
+                "sin": sin_partitioner,
+                "add_mul": add_mul_partitioner,
+            }
+        )
+        exec_prog = edgeir_m.to_executorch()
+
+        for method_name in ["sin", "add_mul"]:
+            graph_module = exec_prog.exported_program(method_name).graph_module
+            # Check delegated nodes are gone
+            self.assertTrue(
+                exir_ops.edge.aten.sin
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.add
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            self.assertTrue(
+                exir_ops.edge.aten.mm
+                not in {node.target for node in graph_module.graph.nodes}
+            )
+            # Check that there exists a call_delegate, representing the call to the
+            # delegated function
+            FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run(
+                graph_module.code
+            )
+            lowered_submodules = get_lowered_submodules(graph_module)
+            self.assertEqual(len(lowered_submodules), 1)
+
+        program = exec_prog.executorch_program
+
+        # Check the program can be printed
+        print_program(program)
+
+        buff = exec_prog.buffer
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+
+        for method_name, module in {
+            "sin": sin_module,
+            "add_mul": add_mul_module,
+        }.items():
+            inputs_flattened, _ = tree_flatten(module.inputs())
+            model_outputs = executorch_module.run_method(
+                method_name, tuple(inputs_flattened)
+            )
+
+            if method_name == "sin":
+                # backend with compiler demo does a taylor approximation of sin
+                ref_output = 0.8333 * torch.ones(1)
+            else:
+                ref_output = module(*module.inputs())
+            self.assertTrue(
+                torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
+            )

exir/backend/test/test_backends_lifted.py

+15
@@ -11,6 +11,9 @@
 import torch
 from executorch.exir import to_edge
 from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -138,6 +141,18 @@ def forward(self, x):
 
         self.assertTrue(torch.allclose(new_res, expected_res))
 
+        # Test same flow but through edge_program_manager
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        loweredir_m = edgeir_m.to_backend(
+            AllNodePartitioner(BackendWithCompilerDemo.__name__, [])
+        )
+        lowered_sin_module = get_lowered_submodules(
+            loweredir_m.exported_program().graph_module
+        )[0][1]
+
+        new_res = lowered_sin_module(*model_inputs)[0]
+
+        self.assertTrue(torch.allclose(new_res, expected_res))
         # TODO(tkaruturi): emitting single LoweredBackendModule
         # program = to_edge(export(graph_module)).to_exectorch()._emitter_output.program

exir/backend/test/test_compatibility.py

+49
@@ -10,6 +10,9 @@
 from executorch.exir import to_edge
 from executorch.exir._serialize import _serialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
+    AllNodePartitioner,
+)
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.test.backend_with_compiler_demo import (
     BackendWithCompilerDemo,
@@ -65,3 +68,49 @@ def forward(self, x):
             "loading method forward failed with error 0x30",
         ):
             executorch_module = _load_for_executorch_from_buffer(buff)
+
+    def test_compatibility_in_runtime_edge_program_manager(self):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        edgeir_m = to_edge(export(sin_module, model_inputs, strict=True))
+        max_value = model_inputs[0].shape[0]
+        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
+        lowered_edge_irm = edgeir_m.to_backend(
+            AllNodePartitioner("BackendWithCompilerDemo", compile_specs)
+        )
+        exec_prog = lowered_edge_irm.to_executorch()
+
+        buff = exec_prog.buffer
+
+        # The demo backend works well
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        _ = executorch_module.forward([model_inputs])
+
+        prog = exec_prog.executorch_program
+        # Rewrite the delegate version number from 0 to 1.
+        prog.backend_delegate_data[0].data = bytes(
+            "1version:1#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
+            encoding="utf8",
+        )
+
+        # Generate the .pte file with the wrong version.
+        buff = bytes(
+            _serialize_pte_binary(
+                program=prog,
+            )
+        )
+
+        # Throw runtime error with error code 0x30, meaning delegate is incompatible.
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "loading method forward failed with error 0x30",
+        ):
+            executorch_module = _load_for_executorch_from_buffer(buff)

exir/program/TARGETS

+1
@@ -31,6 +31,7 @@ python_library(
        "//executorch/exir/_serialize:lib",
        "//executorch/exir/backend:backend_api",
        "//executorch/exir/backend:partitioner",
+       "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
        "//executorch/exir/capture:config",
        "//executorch/exir/emit:emit",
        "//executorch/exir/emit:lib",
