|
1 | 1 | from copy import deepcopy
|
2 | 2 |
|
3 | 3 | import numpy as np
|
| 4 | +import pytest |
4 | 5 | import torch
|
| 6 | +import torch.nn.functional as F |
| 7 | +import torch_tensorrt |
| 8 | +from parameterized import parameterized |
5 | 9 | from torch.testing._internal.common_utils import TestCase, run_tests
|
6 | 10 | from torch_tensorrt.dynamo import partitioning
|
7 | 11 |
|
8 | 12 | from ..testing_utilities import lower_graph_testing
|
9 | 13 |
|
10 | 14 |
|
11 | 15 | class TestGlobalPartitioning(TestCase):
|
| 16 | + @parameterized.expand( |
| 17 | + [ |
| 18 | + ({}, 1), |
| 19 | + ({"torch.ops.aten.relu.default"}, 3), |
| 20 | + ] |
| 21 | + ) |
| 22 | + def test_end2end_global_partition(self, torch_executed_ops, trt_mod_cnt): |
| 23 | + class SimpleCNN(torch.nn.Module): |
| 24 | + def __init__(self): |
| 25 | + super(SimpleCNN, self).__init__() |
| 26 | + self.conv1 = torch.nn.Conv2d(3, 12, 3, padding=1) |
| 27 | + self.bn = torch.nn.BatchNorm2d(12) |
| 28 | + self.conv2 = torch.nn.Conv2d(12, 12, 3, padding=1) |
| 29 | + self.fc1 = torch.nn.Linear(12 * 56 * 56, 10) |
| 30 | + |
| 31 | + def forward(self, x, b=5): |
| 32 | + x = self.conv1(x) |
| 33 | + x = F.relu(x) |
| 34 | + x = self.bn(x) |
| 35 | + x = F.max_pool2d(x, (2, 2)) |
| 36 | + x = self.conv2(x) |
| 37 | + x = F.relu(x) |
| 38 | + x = F.max_pool2d(x, (2, 2)) |
| 39 | + x = torch.flatten(x, 1) |
| 40 | + x = x + b |
| 41 | + return self.fc1(x) |
| 42 | + |
| 43 | + mod = SimpleCNN().to("cuda") |
| 44 | + mod.eval() |
| 45 | + with torch.no_grad(): |
| 46 | + inputs = torch.rand((1, 3, 224, 224)).to("cuda") |
| 47 | + try: |
| 48 | + trt_mod = torch_tensorrt.compile( |
| 49 | + mod, |
| 50 | + ir="dynamo", |
| 51 | + inputs=[inputs], |
| 52 | + min_block_size=1, |
| 53 | + torch_executed_ops=torch_executed_ops, |
| 54 | + use_fast_partitioner=False, |
| 55 | + ) |
| 56 | + cnt = 0 |
| 57 | + for name, _ in trt_mod.named_children(): |
| 58 | + if "_run_on_acc" in name: |
| 59 | + cnt += 1 |
| 60 | + self.assertEqual(cnt, trt_mod_cnt) |
| 61 | + except Exception as e: |
| 62 | + pytest.fail(f"unexpected exception raised: {e}") |
| 63 | + |
12 | 64 | def test_partition_fully_supported_one_op(self):
|
13 | 65 | class FullySupportedOneOp(torch.nn.Module):
|
14 | 66 | def __init__(self, *args, **kwargs) -> None:
|
|
0 commit comments