Skip to content

Commit 42798cc

Browse files
committed
Adding test case for index
1 parent 4058533 commit 42798cc

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt import Input
5+
6+
from .harness import DispatchTestCase
7+
8+
class TestIndexConverter(DispatchTestCase):
9+
def test_index(self):
10+
class TestModule(nn.Module):
11+
def forward(self, x):
12+
input = torch.randn(2, 1280, 8, 8)
13+
index0 = torch.randint(0, 16, (1, 16))
14+
index1 = torch.randint(0, 16, (1, 16))
15+
out = torch.ops.aten.index(None, None, index0, index1)
16+
17+
inputs = [torch.randn(1, 10)]
18+
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.index.Tensor})

0 commit comments

Comments
 (0)