Skip to content

Commit ffbcc7a

Browse files
authored
small fix: Index validator enable int64 (#2642)
1 parent e38a7f3 commit ffbcc7a

File tree

4 files changed

+9
-10
lines changed

4 files changed

+9
-10
lines changed

examples/dynamo/torch_compile_advanced_usage.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
4343
# For the default settings, we can simply call torch.compile
4444
# with the backend "torch_tensorrt", and run the model on an
4545
# input to cause compilation, as so:
46-
optimized_model = torch.compile(model, backend="torch_tensorrt")
46+
optimized_model = torch.compile(model, backend="torch_tensorrt", dynamic=False)
4747
optimized_model(*sample_inputs)
4848

4949
# %%
@@ -81,7 +81,10 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
8181

8282
# Run the model on an input to cause compilation, as so:
8383
optimized_model_custom = torch.compile(
84-
model_half, backend="torch_tensorrt", options=backend_kwargs
84+
model_half,
85+
backend="torch_tensorrt",
86+
options=backend_kwargs,
87+
dynamic=False,
8588
)
8689
optimized_model_custom(*sample_inputs_half)
8790

examples/dynamo/torch_compile_transformers_example.py

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
optimized_model = torch.compile(
6262
model,
6363
backend="torch_tensorrt",
64+
dynamic=False,
6465
options=compilation_kwargs,
6566
)
6667
optimized_model(*inputs)

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def index_dtype_validator(node: Node) -> bool:
397397
for ind in index:
398398
if ind is not None:
399399
val = ind.meta.get("val")
400-
if val is not None and val.dtype != torch.int32:
400+
if val is not None and val.dtype not in (torch.int32, torch.int64):
401401
return False
402402
return True
403403

tests/py/dynamo/conversion/test_index_aten.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import operator
2-
31
import torch
42
import torch.nn as nn
5-
from .harness import DispatchTestCase
63
from torch.testing._internal.common_utils import run_tests
7-
from torch_tensorrt import Input
4+
5+
from .harness import DispatchTestCase
86

97

108
class TestIndexConverter(DispatchTestCase):
@@ -15,7 +13,6 @@ def __init__(self):
1513
super().__init__()
1614

1715
def forward(self, x):
18-
index0 = torch.randint(0, 1, (1, 1))
1916
indices = [None, self.index0]
2017
out = torch.ops.aten.index.Tensor(x, indices)
2118
return out
@@ -158,8 +155,6 @@ def __init__(self):
158155
super().__init__()
159156

160157
def forward(self, x):
161-
index0 = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7])
162-
index1 = index0.unsqueeze(0).T.long()
163158
indices = [None, None, self.index0, self.index1]
164159
out = torch.ops.aten.index.Tensor(x, indices)
165160
return out

0 commit comments

Comments
 (0)