|
| 1 | +import copy |
1 | 2 | import unittest
|
2 |
| -import torch_tensorrt as torchtrt |
| 3 | +from typing import Dict |
| 4 | + |
| 5 | +import custom_models as cm |
| 6 | +import timm |
3 | 7 | import torch
|
| 8 | +import torch_tensorrt as torchtrt |
4 | 9 | import torchvision.models as models
|
5 |
| -import copy |
6 |
| -import timm |
7 |
| -import custom_models as cm |
8 |
| -from typing import Dict |
9 |
| -from utils import cosine_similarity, COSINE_THRESHOLD |
| 10 | +from utils import COSINE_THRESHOLD, cosine_similarity |
10 | 11 |
|
11 | 12 |
|
12 | 13 | class TestModels(unittest.TestCase):
|
@@ -152,6 +153,47 @@ def test_resnet18_half(self):
|
152 | 153 | msg=f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
|
153 | 154 | )
|
154 | 155 |
|
| 156 | + def test_aten_unbind_dynamic(self): |
| 157 | + class ATenUnbindDynamic(torch.nn.Module): |
| 158 | + def __init__(self) -> None: |
| 159 | + super().__init__() |
| 160 | + |
| 161 | + def forward(self, x): |
| 162 | + x1, x2, x3 = x.unbind(1) |
| 163 | + y = torch.cat([x1, x2, x3], dim=0) |
| 164 | + return y |
| 165 | + |
| 166 | + self.model = ATenUnbindDynamic().eval().to("cuda") |
| 167 | + self.input = torch.randn((5, 3, 7, 64)).to("cuda") |
| 168 | + self.scripted_model = torch.jit.script(self.model) |
| 169 | + |
| 170 | + compile_spec = { |
| 171 | + "inputs": [ |
| 172 | + torchtrt.Input( |
| 173 | + min_shape=[1, 3, 1, 64], |
| 174 | + opt_shape=[5, 3, 32, 64], |
| 175 | + max_shape=[10, 3, 64, 64], |
| 176 | + dtype=torch.float, |
| 177 | + format=torch.contiguous_format, |
| 178 | + ) |
| 179 | + ], |
| 180 | + "device": { |
| 181 | + "device_type": torchtrt.DeviceType.GPU, |
| 182 | + "gpu_id": 0, |
| 183 | + }, |
| 184 | + "enabled_precisions": {torch.float}, |
| 185 | + "ir": "ts", |
| 186 | + } |
| 187 | + |
| 188 | + trt_mod = torchtrt.compile(self.scripted_model, **compile_spec) |
| 189 | + cos_sim = cosine_similarity( |
| 190 | + self.model.half()(self.input.half()), trt_mod(self.input.half()) |
| 191 | + ) |
| 192 | + self.assertTrue( |
| 193 | + cos_sim > COSINE_THRESHOLD, |
| 194 | + msg=f"ATen Unbind Dynamic TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", |
| 195 | + ) |
| 196 | + |
155 | 197 |
|
156 | 198 | if __name__ == "__main__":
|
157 | 199 | unittest.main()
|
0 commit comments