Skip to content

Commit 4aa4bd7

Browse files
committed
fixed redundant code in TRT Interpreter
1 parent 2368e63 commit 4aa4bd7

File tree

2 files changed

+4
-11
lines changed

2 files changed

+4
-11
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
3434
DYNAMO_CONVERTERS as CONVERTERS,
3535
)
36-
from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention
36+
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
37+
CallingConvention,
38+
)
3739
from torch_tensorrt.dynamo.conversion._TRTBuilderMonitor import TRTBulderMonitor
3840
from torch_tensorrt.dynamo.conversion.converter_utils import (
3941
get_node_io,
@@ -740,10 +742,6 @@ def run(
740742
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
741743
self._cur_node_name = get_node_name(n)
742744
self._cur_node = n
743-
# add "_itensor_to_tensor_meta"
744-
kwargs = dict(n.kwargs)
745-
kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta
746-
n.kwargs = kwargs
747745

748746
if _LOGGER.isEnabledFor(logging.DEBUG):
749747
_LOGGER.debug(
@@ -759,11 +757,6 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
759757
f"Converted node {self._cur_node_name} [{n.target}] ({get_node_io(n, self.const_mapping)})"
760758
)
761759

762-
# remove "_itensor_to_tensor_meta"
763-
kwargs = dict(n.kwargs)
764-
del kwargs["_itensor_to_tensor_meta"]
765-
n.kwargs = kwargs
766-
767760
if isinstance(trt_node, trt.ITensor):
768761
self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta")
769762

py/torch_tensorrt/fx/converters/acc_ops_converters.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3221,7 +3221,7 @@ def acc_ops_dequantize(
32213221
name: str,
32223222
) -> Union[TRTTensor, Sequence[TRTTensor]]:
32233223
input_val = kwargs["input"]
3224-
input_val_tensor_meta = kwargs["_itensor_to_tensor_meta"][input_val] # type: ignore[index]
3224+
input_val_tensor_meta = network._itensor_to_tensor_meta[input_val] # type: ignore[index]
32253225

32263226
if not isinstance(input_val, TRTTensor):
32273227
raise RuntimeError(

0 commit comments

Comments
 (0)