Skip to content

Commit 01d525d

Browse files
committed
feat(//py): Allow example tensors from torch to set shape
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 15e6863 commit 01d525d

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

Diff for: py/trtorch/Input.py

+7
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,10 @@ def _parse_format(format: Any) -> _types.TensorFormat:
196196
else:
197197
raise TypeError(
198198
"Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
199+
200+
@classmethod
201+
def _from_tensor(cls, t: torch.Tensor):
202+
if not any([t.is_contiguous(memory_format=torch.contiguous_format), t.is_contiguous(memory_format=torch.channels_last)]):
203+
raise ValueError("Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last")
204+
frmt = torch.contiguous_format if t.is_contiguous(memory_format=torch.contiguous_format) else torch.channels_last
205+
return cls(shape=t.shape, dtype=t.dtype, format=frmt)

Diff for: py/trtorch/_compile_spec.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,11 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
174174
info.inputs = _parse_input_ranges(compile_spec["input_shapes"])
175175

176176
if "inputs" in compile_spec:
177-
info.inputs = [i._to_internal() for i in compile_spec["inputs"]]
177+
if not all([isinstance(i, torch.Tensor) or isinstance(i, trtorch.Input) for i in compile_spec["inputs"]]):
178+
raise KeyError("Input specs should be either trtorch.Input or torch.Tensor, found types: {}".format([typeof(i) for i in compile_spec["inputs"]]))
179+
180+
inputs = [trtorch.Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
181+
info.inputs = [i._to_internal() for i in inputs]
178182

179183
if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
180184
raise KeyError(

Diff for: tests/py/test_api.py

+24
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,30 @@ def test_compile_script(self):
7373
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
7474
self.assertTrue(same < 2e-2)
7575

76+
def test_from_torch_tensor(self):
77+
compile_spec = {
78+
"inputs": [self.input],
79+
"device": {
80+
"device_type": trtorch.DeviceType.GPU,
81+
"gpu_id": 0,
82+
},
83+
"enabled_precisions": {torch.float}
84+
}
85+
86+
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
87+
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
88+
self.assertTrue(same < 2e-2)
89+
90+
def test_device(self):
91+
compile_spec = {
92+
"inputs": [self.input],
93+
"device": trtorch.Device("gpu:0"),
94+
"enabled_precisions": {torch.float}
95+
}
96+
97+
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
98+
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
99+
self.assertTrue(same < 2e-2)
76100

77101
class TestCompileHalf(ModelTestCase):
78102

0 commit comments

Comments
 (0)