Skip to content

Commit 8dd0e34

Browse files
committed
feat(//py): Use TensorRT to fill in .so libraries automatically if
possible Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent e9e824c commit 8dd0e34

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

py/setup.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
JETPACK_VERSION = None
2424

2525
__version__ = '1.2.0a0'
26-
26+
__cuda_version__ = '11.3'
27+
__cudnn_version__ = '8.2'
28+
__tensorrt_version__ = '8.2'
2729

2830
def get_git_revision_short_hash() -> str:
2931
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip()
@@ -114,7 +116,10 @@ def gen_version_file():
114116

115117
with open(dir_path + '/torch_tensorrt/_version.py', 'w') as f:
116118
print("creating version file")
117-
f.write("__version__ = \"" + __version__ + '\"')
119+
f.write("__version__ = \"" + __version__ + '\"\n')
120+
f.write("__cuda_version__ = \"" + __cuda_version__ + '\"\n')
121+
f.write("__cudnn_version__ = \"" + __cudnn_version__ + '\"\n')
122+
f.write("__tensorrt_version__ = \"" + __tensorrt_version__ + '\"\n')
118123

119124

120125
def copy_libtorchtrt(multilinux=False):

py/torch_tensorrt/__init__.py

+67-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,78 @@
1+
import ctypes
2+
import glob
13
import os
24
import sys
5+
import warnings
6+
from torch_tensorrt._version import __version__, __cuda_version__, __cudnn_version__, __tensorrt_version__
7+
38

49
if sys.version_info < (3,):
510
raise Exception("Python 2 has reached end-of-life and is not supported by Torch-TensorRT")
611

7-
import ctypes
12+
def _parse_semver(version):
13+
split = version.split(".")
14+
if len(split) < 3:
15+
split.append("")
16+
17+
return {
18+
"major": split[0],
19+
"minor": split[1],
20+
"patch": split[2]
21+
}
22+
23+
def _find_lib(name, paths):
24+
for path in paths:
25+
libpath = os.path.join(path, name)
26+
if os.path.isfile(libpath):
27+
return libpath
28+
29+
raise FileNotFoundError(
30+
f"Could not find {name}\n Search paths: {paths}"
31+
)
32+
33+
try:
34+
import tensorrt
35+
except:
36+
cuda_version = _parse_semver(__cuda_version__)
37+
cudnn_version = _parse_semver(__cudnn_version__)
38+
tensorrt_version = _parse_semver(__tensorrt_version__)
39+
40+
CUDA_MAJOR = cuda_version["major"]
41+
CUDNN_MAJOR = cudnn_version["major"]
42+
TENSORRT_MAJOR = tensorrt_version["major"]
43+
44+
if sys.platform.startswith("win"):
45+
WIN_LIBS = [
46+
f"cublas64_{CUDA_MAJOR}.dll",
47+
f"cublasLt64_{CUDA_MAJOR}.dll",
48+
f"cudnn64_{CUDNN_MAJOR}.dll",
49+
"nvinfer.dll",
50+
"nvinfer_plugin.dll",
51+
]
52+
53+
WIN_PATHS = os.environ["PATH"].split(os.path.pathsep)
54+
55+
56+
for lib in WIN_LIBS:
57+
ctypes.CDLL(_find_lib(lib, WIN_PATHS))
58+
59+
elif sys.platform.startswith("linux"):
60+
LINUX_PATHS = [
61+
"/usr/lib/x86_64-linux-gnu",
62+
"/usr/local/cuda/lib64",
63+
] + os.environ["LD_LIBRARY_PATH"].split(os.path.pathsep)
64+
65+
LINUX_LIBS = [
66+
f"libcudnn.so.{CUDNN_MAJOR}",
67+
f"libnvinfer.so.{TENSORRT_MAJOR}",
68+
f"libnvinfer_plugin.so.{TENSORRT_MAJOR}",
69+
]
70+
71+
for lib in LINUX_LIBS:
72+
ctypes.CDLL(_find_lib(lib, LINUX_PATHS))
73+
874
import torch
975

10-
from torch_tensorrt._version import __version__
1176
from torch_tensorrt._compile import *
1277
from torch_tensorrt._util import *
1378
from torch_tensorrt import ts

0 commit comments

Comments
 (0)