Skip to content

Commit 450c869

Browse files
author
Wei
authored
enable direct call to fx.compile() (#1344)
* enable direct call to fx.compile() * Update lower_example.py * Update _compile.py
1 parent 837a85c commit 450c869

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

examples/fx/lower_example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66
import torchvision
7-
from torch_tensorrt.fx.lower import compile
7+
from torch_tensorrt.fx import compile
88
from torch_tensorrt.fx.utils import LowerPrecision
99

1010

py/torch_tensorrt/_compile.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from enum import Enum
88

99
import torch_tensorrt.fx
10-
import torch_tensorrt.fx.lower
1110
from torch_tensorrt.fx.utils import LowerPrecision
1211

1312

@@ -140,7 +139,7 @@ def compile(
140139
else:
141140
raise ValueError(f"Precision {enabled_precisions} not supported on FX")
142141

143-
return torch_tensorrt.fx.lower.compile(
142+
return torch_tensorrt.fx.compile(
144143
module,
145144
inputs,
146145
lower_precision=lower_precision,

py/torch_tensorrt/fx/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa
1212
from .lower_setting import LowerSetting # noqa
1313
from .trt_module import TRTModule # noqa
14+
from .lower import compile # usort: skip #noqa
1415

1516
logging.basicConfig(level=logging.INFO)

0 commit comments

Comments
 (0)