
Commit 2468a42

Cherry-pick the global partitioner bug fix (#3195) from main to the release/2.5 branch (#3209)
1 parent 911279b commit 2468a42

3 files changed: +56 −3 lines

py/torch_tensorrt/dynamo/_compiler.py

+3-1
@@ -366,7 +366,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
-
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
@@ -408,6 +407,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     # Generate the corresponding TRT Module for those
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
+        # filter on the GraphModule
+        if not isinstance(submodule, torch.fx.graph_module.GraphModule):
+            continue
         # Criteria for a module to be convertible to TRT
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
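Context note (not part of the commit): the new isinstance check keeps the conversion loop from handing non-GraphModule children of the partitioned module to the TRT converter. The sketch below is a minimal, generic torch.fx illustration of how such children can show up; the Tiny module is a hypothetical example, not the graph the dynamo partitioner actually produces.

    import torch
    import torch.fx

    class Tiny(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return torch.relu(self.conv(x))

    # Tracing keeps self.conv as a plain nn.Conv2d child of the traced GraphModule,
    # so named_children() yields entries that are not GraphModules and must be
    # skipped rather than treated as convertible partitions.
    gm = torch.fx.symbolic_trace(Tiny())
    for name, submodule in gm.named_children():
        if not isinstance(submodule, torch.fx.GraphModule):
            print(f"skipping non-GraphModule child: {name} ({type(submodule).__name__})")
            continue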

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

+1-2
@@ -228,8 +228,7 @@ def partition(
     # Determine partitions based on user specifications and operator support
     # Then, fuse partitions and display overview of supported/unsupported operators
     partitions = partitioner.propose_partitions()
-    fused_graph = partitioner.fuse_partitions(partitions)
-
+    fused_graph = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
     if verbose:
         supported_ops.print_support_overview(len(partitions))
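Context note (not part of the commit): the compile loop shown above only converts children whose names contain "_run_on_acc", while torch.fx's CapabilityBasedPartitioner (which the global partitioner builds on) names fused submodules with a "fused_" prefix by default, so its partitions were previously never picked up. A minimal sketch of the renaming, assuming a PyTorch build whose fuse_partitions exposes the prefix argument used in this patch; EverythingSupported and fn are hypothetical stand-ins:

    import torch
    import torch.fx
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
    from torch.fx.passes.operator_support import OperatorSupportBase

    class EverythingSupported(OperatorSupportBase):
        # Treat every call_function node as convertible so the whole graph
        # collapses into a single fused partition.
        def is_node_supported(self, submodules, node):
            return node.op == "call_function"

    def fn(x):
        return torch.relu(x) + 1

    gm = torch.fx.symbolic_trace(fn)
    partitioner = CapabilityBasedPartitioner(gm, EverythingSupported())
    partitions = partitioner.propose_partitions()

    # With the default prefix the child would be named "fused_<id>", which the
    # compile loop ignores; "_run_on_acc_<id>" is what it recognizes.
    fused = partitioner.fuse_partitions(partitions, prefix="_run_on_acc_")
    print([name for name, _ in fused.named_children()])  # e.g. ['_run_on_acc_0']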

tests/py/dynamo/partitioning/test_global_partitioning.py

+52
@@ -1,14 +1,66 @@
 from copy import deepcopy
 
 import numpy as np
+import pytest
 import torch
+import torch.nn.functional as F
+import torch_tensorrt
+from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch_tensorrt.dynamo import partitioning
 
 from ..testing_utilities import lower_graph_testing
 
 
 class TestGlobalPartitioning(TestCase):
+    @parameterized.expand(
+        [
+            ({}, 1),
+            ({"torch.ops.aten.relu.default"}, 3),
+        ]
+    )
+    def test_end2end_global_partition(self, torch_executed_ops, trt_mod_cnt):
+        class SimpleCNN(torch.nn.Module):
+            def __init__(self):
+                super(SimpleCNN, self).__init__()
+                self.conv1 = torch.nn.Conv2d(3, 12, 3, padding=1)
+                self.bn = torch.nn.BatchNorm2d(12)
+                self.conv2 = torch.nn.Conv2d(12, 12, 3, padding=1)
+                self.fc1 = torch.nn.Linear(12 * 56 * 56, 10)
+
+            def forward(self, x, b=5):
+                x = self.conv1(x)
+                x = F.relu(x)
+                x = self.bn(x)
+                x = F.max_pool2d(x, (2, 2))
+                x = self.conv2(x)
+                x = F.relu(x)
+                x = F.max_pool2d(x, (2, 2))
+                x = torch.flatten(x, 1)
+                x = x + b
+                return self.fc1(x)
+
+        mod = SimpleCNN().to("cuda")
+        mod.eval()
+        with torch.no_grad():
+            inputs = torch.rand((1, 3, 224, 224)).to("cuda")
+            try:
+                trt_mod = torch_tensorrt.compile(
+                    mod,
+                    ir="dynamo",
+                    inputs=[inputs],
+                    min_block_size=1,
+                    torch_executed_ops=torch_executed_ops,
+                    use_fast_partitioner=False,
+                )
+                cnt = 0
+                for name, _ in trt_mod.named_children():
+                    if "_run_on_acc" in name:
+                        cnt += 1
+                self.assertEqual(cnt, trt_mod_cnt)
+            except Exception as e:
+                pytest.fail(f"unexpected exception raised: {e}")
+
     def test_partition_fully_supported_one_op(self):
         class FullySupportedOneOp(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
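Follow-up note (not part of the diff): the new test compiles SimpleCNN end to end with use_fast_partitioner=False and simply counts the "_run_on_acc" children of the compiled module. With no excluded ops the whole model should form a single accelerated block; excluding torch.ops.aten.relu.default forces the two relu calls back to Torch, which splits the accelerated graph and is presumably why the expected count rises to 3.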
