10 | 10 |
11 | 11 | import executorch.exir as exir
12 | 12 | import torch
| 13 | +from executorch.exir import to_edge |
13 | 14 | from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
| 15 | +from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( |
| 16 | + AllNodePartitioner, |
| 17 | +) |
14 | 18 | from executorch.exir.backend.compile_spec_schema import CompileSpec
15 | 19 | from executorch.exir.backend.partitioner import (
16 | 20 | DelegationSpec,
@@ -1266,3 +1270,178 @@ def forward(self, x: List[torch.Tensor]):
1266 | 1270 |
1267 | 1271 | gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge()
1268 | 1272 | gm(*inputs)
| 1273 | + |
| 1274 | + def test_to_backend_delegation_spec(self): |
| 1275 | + class SinModule(torch.nn.Module): |
| 1276 | + def __init__(self): |
| 1277 | + super().__init__() |
| 1278 | + |
| 1279 | + def forward(self, x): |
| 1280 | + return [torch.sin(x)] |
| 1281 | + |
| 1282 | + sin_module = SinModule() |
| 1283 | + model_inputs = (torch.ones(1),) |
| 1284 | + max_value = model_inputs[0].shape[0] |
| 1285 | + |
| 1286 | + partitioner = AllNodePartitioner( |
| 1287 | + "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))] |
| 1288 | + ) |
| 1289 | + |
| 1290 | + edgeir_m = to_edge(torch.export.export(sin_module, model_inputs)) |
| 1291 | + edgeir_m = edgeir_m.to_backend(partitioner) |
| 1292 | + exec_prog = edgeir_m.to_executorch() |
| 1293 | + graph_module = exec_prog.exported_program().graph_module |
| 1294 | + # Check that there is not an aten.sin node. |
| 1295 | + self.assertTrue( |
| 1296 | + exir_ops.edge.aten.sin |
| 1297 | + not in {node.target for node in graph_module.graph.nodes} |
| 1298 | + ) |
| 1299 | + |
| 1300 | + # Check that there exists a call_delegate, representing the call to the |
| 1301 | + # delegated function |
| 1302 | + FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run( |
| 1303 | + graph_module.code |
| 1304 | + ) |
| 1305 | + lowered_submodules = get_lowered_submodules(graph_module) |
| 1306 | + self.assertEqual(len(lowered_submodules), 1) |
| 1307 | + |
| 1308 | + for node in graph_module.graph.nodes: |
| 1309 | + if node.op == "call_function" and node.target == executorch_call_delegate: |
| 1310 | + # Check that first arg is lowered_module_{unique_id} |
| 1311 | + self.assertEqual(node.args[0].target, "lowered_module_0") |
| 1312 | + |
| 1313 | + program = exec_prog.executorch_program |
| 1314 | + |
| 1315 | + # Check the program can be printed |
| 1316 | + print_program(program) |
| 1317 | + |
| 1318 | + # Check the backend delegate |
| 1319 | + self.check_backend_delegate( |
| 1320 | + program=program, |
| 1321 | + delegate=program.execution_plan[0].delegates[0], |
| 1322 | + expected_id=BackendWithCompilerDemo.__name__, |
| 1323 | + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#", |
| 1324 | + ) |
| 1325 | + |
| 1326 | + # Check the delegate instruction |
| 1327 | + self.assertTrue( |
| 1328 | + isinstance( |
| 1329 | + program.execution_plan[0].chains[0].instructions[0].instr_args, |
| 1330 | + DelegateCall, |
| 1331 | + ) |
| 1332 | + ) |
| 1333 | + buff = exec_prog.buffer |
| 1334 | + |
| 1335 | + executorch_module = _load_for_executorch_from_buffer(buff) |
| 1336 | + model_inputs = torch.ones(1) |
| 1337 | + model_outputs = executorch_module.forward([model_inputs]) |
| 1338 | + self.assertEqual(  # the input tensor should be unchanged after execution |
| 1339 | + model_inputs, |
| 1340 | + torch.ones(1), |
| 1341 | + ) |
| 1342 | + expected_output = 0.8333 * torch.ones(1)  # the demo backend computes sin via a Taylor approximation |
| 1343 | + |
| 1344 | + self.assertTrue( |
| 1345 | + torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03) |
| 1346 | + ) |
| 1347 | + |
| 1348 | + def test_to_backend_multimethod_delegation_spec(self): |
| 1349 | + class SinModule(torch.nn.Module): |
| 1350 | + def __init__(self): |
| 1351 | + super().__init__() |
| 1352 | + |
| 1353 | + def forward(self, x): |
| 1354 | + return torch.sin(x) |
| 1355 | + |
| 1356 | + def inputs(self): |
| 1357 | + return (torch.ones(1),) |
| 1358 | + |
| 1359 | + class AddMulModule(torch.nn.Module): |
| 1360 | + def __init__(self): |
| 1361 | + super().__init__() |
| 1362 | + |
| 1363 | + def forward(self, a, x, b): |
| 1364 | + y = torch.mm(a, x) |
| 1365 | + z = torch.add(y, b) |
| 1366 | + return z |
| 1367 | + |
| 1368 | + def inputs(self): |
| 1369 | + return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)) |
| 1370 | + |
| 1371 | + sin_module = SinModule() |
| 1372 | + max_value_sin = sin_module.inputs()[0].shape[0] |
| 1373 | + sin_partitioner = AllNodePartitioner( |
| 1374 | + "BackendWithCompilerDemo", |
| 1375 | + [CompileSpec("max_value", bytes([max_value_sin]))], |
| 1376 | + ) |
| 1377 | + |
| 1378 | + add_mul_module = AddMulModule() |
| 1379 | + max_value_add_mul = add_mul_module.inputs()[0].shape[0] |
| 1380 | + add_mul_partitioner = AllNodePartitioner( |
| 1381 | + "BackendWithCompilerDemo", |
| 1382 | + [CompileSpec("max_value", bytes([max_value_add_mul]))], |
| 1383 | + ) |
| 1384 | + |
| 1385 | + edgeir_m = to_edge( |
| 1386 | + { |
| 1387 | + "sin": torch.export.export(sin_module, sin_module.inputs()), |
| 1388 | + "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()), |
| 1389 | + } |
| 1390 | + ) |
| 1391 | + edgeir_m = edgeir_m.to_backend( |
| 1392 | + { |
| 1393 | + "sin": sin_partitioner, |
| 1394 | + "add_mul": add_mul_partitioner, |
| 1395 | + } |
| 1396 | + ) |
| 1397 | + exec_prog = edgeir_m.to_executorch() |
| 1398 | + |
| 1399 | + for method_name in ["sin", "add_mul"]: |
| 1400 | + graph_module = exec_prog.exported_program(method_name).graph_module |
| 1401 | + # Check delegated nodes are gone |
| 1402 | + self.assertTrue( |
| 1403 | + exir_ops.edge.aten.sin |
| 1404 | + not in {node.target for node in graph_module.graph.nodes} |
| 1405 | + ) |
| 1406 | + self.assertTrue( |
| 1407 | + exir_ops.edge.aten.add |
| 1408 | + not in {node.target for node in graph_module.graph.nodes} |
| 1409 | + ) |
| 1410 | + self.assertTrue( |
| 1411 | + exir_ops.edge.aten.mm |
| 1412 | + not in {node.target for node in graph_module.graph.nodes} |
| 1413 | + ) |
| 1414 | + # Check that there exists a call_delegate, representing the call to the |
| 1415 | + # delegated function |
| 1416 | + FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run( |
| 1417 | + graph_module.code |
| 1418 | + ) |
| 1419 | + lowered_submodules = get_lowered_submodules(graph_module) |
| 1420 | + self.assertEqual(len(lowered_submodules), 1) |
| 1421 | + |
| 1422 | + program = exec_prog.executorch_program |
| 1423 | + |
| 1424 | + # Check the program can be printed |
| 1425 | + print_program(program) |
| 1426 | + |
| 1427 | + buff = exec_prog.buffer |
| 1428 | + |
| 1429 | + executorch_module = _load_for_executorch_from_buffer(buff) |
| 1430 | + |
| 1431 | + for method_name, module in { |
| 1432 | + "sin": sin_module, |
| 1433 | + "add_mul": add_mul_module, |
| 1434 | + }.items(): |
| 1435 | + inputs_flattened, _ = tree_flatten(module.inputs()) |
| 1436 | + model_outputs = executorch_module.run_method( |
| 1437 | + method_name, tuple(inputs_flattened) |
| 1438 | + ) |
| 1439 | + |
| 1440 | + if method_name == "sin": |
| 1441 | + # BackendWithCompilerDemo does a Taylor approximation of sin |
| 1442 | + ref_output = 0.8333 * torch.ones(1) |
| 1443 | + else: |
| 1444 | + ref_output = module(*module.inputs()) |
| 1445 | + self.assertTrue( |
| 1446 | + torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03) |
| 1447 | + ) |