
Commit a6ba374

fix bugs and update tests

1 parent 04d0ccf

File tree

3 files changed: +127 -1 lines changed

py/torch_tensorrt/dynamo/_compiler.py (+6)

@@ -397,6 +397,12 @@ def convert_method_to_trt_engine(
             f"Conversion of module {module} not currently fully supported or convertible!",
             exc_info=True,
         )
+    except Exception as e:
+        logger.error(
+            f"While interpreting the module got an error: {e}",
+            exc_info=True,
+        )
+
     import io
 
     with io.BytesIO() as engine_bytes:
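The added handler means any exception raised while interpreting the module is logged together with its traceback (exc_info=True) rather than going unreported at this point. A minimal sketch of the same pattern, with hypothetical names (run_with_logging and interpret are illustrations, not helpers from this module):

import logging

logger = logging.getLogger("torch_tensorrt.dynamo")

def run_with_logging(interpret, module):
    try:
        return interpret(module)
    except Exception as e:
        # Same shape as the commit's handler: log the message and the
        # full traceback, then continue instead of re-raising.
        logger.error(f"While interpreting the module got an error: {e}", exc_info=True)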

py/torch_tensorrt/dynamo/conversion/_conversion.py (+2 -1)

@@ -12,7 +12,7 @@
     TRTInterpreterResult,
 )
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_torch_inputs
+from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
 
 
 def interpret_module(
@@ -22,6 +22,7 @@ def interpret_module(
     name: str = "",
 ) -> TRTInterpreterResult:
     torch_inputs = get_torch_inputs(inputs, settings.device)
+    module.to(to_torch_device(settings.device))
     module_outputs = module(*torch_inputs)
 
     if not isinstance(module_outputs, (list, tuple)):
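The one-line fix in interpret_module moves the module onto the device that settings.device resolves to, matching where get_torch_inputs places the example inputs; without it, the warm-up call module(*torch_inputs) can fail when the module's parameters live on a different device. A standalone illustration of that failure mode (assumes a CUDA-capable GPU; the Linear module is arbitrary):

import torch

module = torch.nn.Linear(4, 4)              # parameters start on the CPU
inputs = torch.randn(2, 4, device="cuda")   # example inputs on the GPU

try:
    module(inputs)
except RuntimeError as e:
    print(e)  # device-mismatch error: inputs on cuda:0, weights on cpu

module.to("cuda")            # what module.to(to_torch_device(settings.device)) achieves
print(module(inputs).shape)  # torch.Size([2, 4])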
(new test file, +119; path not shown in this extract)

@@ -0,0 +1,119 @@
+import unittest
+
+import numpy as np
+import pycuda.driver as cuda
+import tensorrt as trt
+import torch
+import torch_tensorrt
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
+
+try:
+    import pycuda.autoprimaryctx
+except ModuleNotFoundError:
+    import pycuda.autoinit
+
+
+class HostDeviceMem(object):
+    def __init__(self, host_mem, device_mem):
+        self.host = host_mem
+        self.device = device_mem
+
+    def __str__(self):
+        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
+
+    def __repr__(self):
+        return self.__str__()
+
+
+def allocate_buffers(engine):
+    inputs = []
+    outputs = []
+    bindings = []
+    stream = cuda.Stream()
+    for binding in engine:
+        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
+        dtype = trt.nptype(engine.get_binding_dtype(binding))
+        # Allocate host and device buffers
+        host_mem = cuda.pagelocked_empty(size, dtype)
+        device_mem = cuda.mem_alloc(host_mem.nbytes)
+        # Append the device buffer to device bindings.
+        bindings.append(int(device_mem))
+        # Append to the appropriate list.
+        if engine.binding_is_input(binding):
+            inputs.append(HostDeviceMem(host_mem, device_mem))
+        else:
+            outputs.append(HostDeviceMem(host_mem, device_mem))
+    return inputs, outputs, bindings, stream
+
+
+def do_inference_v2(context, bindings, inputs, outputs, stream):
+    # Transfer input data to the GPU.
+    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
+    # Run inference.
+    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
+    # Transfer predictions back from the GPU.
+    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
+    # Synchronize the stream
+    stream.synchronize()
+    # Return only the host outputs.
+    return [out.host for out in outputs]
+
+
+def cosine_similarity(gt_tensor, pred_tensor):
+    gt_tensor = gt_tensor.flatten().to(torch.float32)
+    pred_tensor = pred_tensor.flatten().to(torch.float32)
+    if torch.sum(gt_tensor) == 0.0 or torch.sum(pred_tensor) == 0.0:
+        if torch.allclose(gt_tensor, pred_tensor, atol=1e-4, rtol=1e-4, equal_nan=True):
+            return 1.0
+    res = torch.nn.functional.cosine_similarity(gt_tensor, pred_tensor, dim=0, eps=1e-6)
+    res = res.cpu().detach().item()
+
+    return res
+
+
+class TestConvertMethodToTrtEngine(unittest.TestCase):
+    def test_convert_module(self):
+        class Test(torch.nn.Module):
+            def forward(self, a, b):
+                return torch.add(a, b)
+
+        # Prepare the input data
+        input_data_0, input_data_1 = torch.randn((2, 4)), torch.randn((2, 4))
+
+        # Create a model
+        model = Test()
+        symbolic_traced_gm = torch.fx.symbolic_trace(model)
+
+        # Convert to TensorRT engine
+        trt_engine_str = torch_tensorrt.dynamo.convert_method_to_trt_engine(
+            symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1]
+        )
+
+        # Deserialize the TensorRT engine
+        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
+            engine = runtime.deserialize_cuda_engine(trt_engine_str)
+
+        # Allocate memory for inputs and outputs
+        inputs, outputs, bindings, stream = allocate_buffers(engine)
+        context = engine.create_execution_context()
+
+        # Copy input data to buffer (need .ravel() here, as the inputs[0] buffer is flat, (8,) not (2, 4))
+        np.copyto(inputs[0].host, input_data_0.ravel())
+        np.copyto(inputs[1].host, input_data_1.ravel())
+
+        # Inference on TRT Engine
+        trt_outputs = do_inference_v2(
+            context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream
+        )
+        trt_output = torch.from_numpy(trt_outputs[0])
+
+        # Inference on PyTorch model
+        model_output = model(input_data_0, input_data_1)
+
+        cos_sim = cosine_similarity(model_output, trt_output)
+        self.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
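The diff does not show the new test file's path; assuming it is saved as test_convert_method_to_trt_engine.py (file name assumed, not from the commit), the test can be run through pytest or directly via its __main__ guard:

python -m pytest test_convert_method_to_trt_engine.py -v
python test_convert_method_to_trt_engine.py

Both require a CUDA-capable GPU plus the tensorrt, pycuda, and torch_tensorrt packages imported at the top of the file.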
