Skip to content

Commit 2deab59

Browse files
committed
fix: Issue in TS dimension-squeeze utility
1 parent 5de208f commit 2deab59

File tree

2 files changed

+47
-7
lines changed

2 files changed

+47
-7
lines changed

core/util/trt_util.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros, boo
216216
// Replace all instances of -1, indicating dynamic dimension
217217
// with 0, indicating copy the dimension from another tensor
218218
// (Generally used for reshape operations)
219-
if (use_zeros && d.d[i] == -1) {
219+
if (use_zeros && d.d[i] == -1 && i < pos) {
220220
dims.d[j] = 0;
221221
// If zeros already exist in the dimensions (empty tensor),
222222
// Replace all instances of 0, indicating empty dimension

tests/py/ts/models/test_models.py

+46-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
import copy
12
import unittest
2-
import torch_tensorrt as torchtrt
3+
from typing import Dict
4+
5+
import custom_models as cm
6+
import timm
37
import torch
8+
import torch_tensorrt as torchtrt
49
import torchvision.models as models
5-
import copy
6-
import timm
7-
import custom_models as cm
8-
from typing import Dict
9-
from utils import cosine_similarity, COSINE_THRESHOLD
10+
from utils import COSINE_THRESHOLD, cosine_similarity
1011

1112

1213
class TestModels(unittest.TestCase):
@@ -152,6 +153,45 @@ def test_resnet18_half(self):
152153
msg=f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
153154
)
154155

156+
def test_aten_unbind_dynamic(self):
    """Check ``unbind`` followed by ``cat`` under a dynamic-shape compile.

    A small module is scripted, compiled with ``ir="ts"`` against a
    min/opt/max input profile, and its output is compared with the eager
    module's output via cosine similarity.
    """

    class ATenUnbindDynamic(torch.nn.Module):
        # Splits the input along dim 1 into three slices and concatenates
        # those slices along dim 0.
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x):
            first, second, third = x.unbind(1)
            return torch.cat([first, second, third], dim=0)

    self.model = ATenUnbindDynamic().eval().to("cuda")
    self.input = torch.randn((5, 3, 7, 64)).to("cuda")
    self.scripted_model = torch.jit.script(self.model)

    # Dims 0 and 2 differ across the min/opt/max shapes; dims 1 and 3 are
    # the same in all three.
    dynamic_input = torchtrt.Input(
        min_shape=[1, 3, 1, 64],
        opt_shape=[5, 3, 32, 64],
        max_shape=[10, 3, 64, 64],
        dtype=torch.float,
        format=torch.contiguous_format,
    )

    trt_mod = torchtrt.compile(
        self.scripted_model,
        inputs=[dynamic_input],
        device={
            "device_type": torchtrt.DeviceType.GPU,
            "gpu_id": 0,
        },
        enabled_precisions={torch.float},
        ir="ts",
    )

    # Run the eager module first, then the compiled one, on the same input.
    reference_out = self.model(self.input)
    trt_out = trt_mod(self.input)
    cos_sim = cosine_similarity(reference_out, trt_out)

    within_threshold = cos_sim > COSINE_THRESHOLD
    self.assertTrue(
        within_threshold,
        msg=f"ATen Unbind Dynamic TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
    )
194+
155195

156196
if __name__ == "__main__":
157197
unittest.main()

0 commit comments

Comments
 (0)