4
4
from typing import Any , Callable , Dict , List , NamedTuple , Optional , Sequence , Set
5
5
6
6
import numpy as np
7
+ import tensorrt as trt
7
8
import torch
8
9
import torch .fx
9
10
from torch .fx .node import _get_qualified_name
23
24
from torch_tensorrt .fx .observer import Observer
24
25
from torch_tensorrt .fx .utils import Frameworks , unified_dtype_converter
25
26
26
- # @manual=//deeplearning/trt/python:py_tensorrt
27
- import tensorrt as trt
28
27
from packaging import version
29
28
30
29
_LOGGER : logging .Logger = logging .getLogger (__name__ )
@@ -96,6 +95,7 @@ def __init__(
96
95
self ._itensor_to_tensor_meta : Dict [
97
96
trt .tensorrt .ITensor , TensorMetadata
98
97
] = dict ()
98
+ self .compilation_settings = compilation_settings
99
99
100
100
# Data types for TRT Module output Tensors
101
101
self .output_dtypes = output_dtypes
@@ -118,40 +118,25 @@ def validate_conversion(self) -> Set[str]:
118
118
119
119
def run (
120
120
self ,
121
- workspace_size : int = 0 ,
122
- precision : torch .dtype = torch .float32 , # TODO: @peri044 Needs to be expanded to set
123
- sparse_weights : bool = False ,
124
- disable_tf32 : bool = False ,
125
121
force_fp32_output : bool = False ,
126
122
strict_type_constraints : bool = False ,
127
123
algorithm_selector : Optional [trt .IAlgorithmSelector ] = None ,
128
124
timing_cache : Optional [trt .ITimingCache ] = None ,
129
- profiling_verbosity : Optional [trt .ProfilingVerbosity ] = None ,
130
125
tactic_sources : Optional [int ] = None ,
131
- max_aux_streams : Optional [int ] = None ,
132
- version_compatible : bool = False ,
133
- optimization_level : Optional [int ] = None ,
134
126
) -> TRTInterpreterResult :
135
127
"""
136
128
Build TensorRT engine with some configs.
137
129
Args:
138
- workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
139
- precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
140
- sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
141
130
force_fp32_output: force output to be fp32
142
131
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
143
132
algorithm_selector: set up algorithm selection for certain layer
144
133
timing_cache: enable timing cache for TensorRT
145
- profiling_verbosity: TensorRT logging level
146
- max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
147
- version_compatible: Provide version forward-compatibility for engine plan files
148
- optimization_level: Builder optimization 0-5, higher levels imply longer build time,
149
- searching for more optimization options. TRT defaults to 3
150
134
Return:
151
135
TRTInterpreterResult
152
136
"""
153
137
TRT_INTERPRETER_CALL_PRE_OBSERVER .observe (self .module )
154
138
139
+ precision = self .compilation_settings .precision
155
140
# For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
156
141
# force_fp32_output=False. Overriden by specifying output_dtypes
157
142
self .output_fp16 = not force_fp32_output and precision == torch .float16
@@ -172,9 +157,9 @@ def run(
172
157
173
158
builder_config = self .builder .create_builder_config ()
174
159
175
- if workspace_size != 0 :
160
+ if self . compilation_settings . workspace_size != 0 :
176
161
builder_config .set_memory_pool_limit (
177
- trt .MemoryPoolType .WORKSPACE , workspace_size
162
+ trt .MemoryPoolType .WORKSPACE , self . compilation_settings . workspace_size
178
163
)
179
164
180
165
cache = None
@@ -187,34 +172,66 @@ def run(
187
172
188
173
if version .parse (trt .__version__ ) >= version .parse ("8.2" ):
189
174
builder_config .profiling_verbosity = (
190
- profiling_verbosity
191
- if profiling_verbosity
175
+ trt . ProfilingVerbosity . VERBOSE
176
+ if self . compilation_settings . debug
192
177
else trt .ProfilingVerbosity .LAYER_NAMES_ONLY
193
178
)
194
179
195
180
if version .parse (trt .__version__ ) >= version .parse ("8.6" ):
196
- if max_aux_streams is not None :
197
- _LOGGER .info (f"Setting max aux streams to { max_aux_streams } " )
198
- builder_config .max_aux_streams = max_aux_streams
199
- if version_compatible :
181
+ if self .compilation_settings .max_aux_streams is not None :
182
+ _LOGGER .info (
183
+ f"Setting max aux streams to { self .compilation_settings .max_aux_streams } "
184
+ )
185
+ builder_config .max_aux_streams = (
186
+ self .compilation_settings .max_aux_streams
187
+ )
188
+ if self .compilation_settings .version_compatible :
200
189
_LOGGER .info ("Using version compatible" )
201
190
builder_config .set_flag (trt .BuilderFlag .VERSION_COMPATIBLE )
202
- if optimization_level is not None :
203
- _LOGGER .info (f"Using optimization level { optimization_level } " )
204
- builder_config .builder_optimization_level = optimization_level
191
+ if self .compilation_settings .optimization_level is not None :
192
+ _LOGGER .info (
193
+ f"Using optimization level { self .compilation_settings .optimization_level } "
194
+ )
195
+ builder_config .builder_optimization_level = (
196
+ self .compilation_settings .optimization_level
197
+ )
198
+
199
+ builder_config .engine_capability = self .compilation_settings .engine_capability
200
+ builder_config .avg_timing_iterations = (
201
+ self .compilation_settings .num_avg_timing_iters
202
+ )
203
+
204
+ if self .compilation_settings .device .device_type == trt .DeviceType .DLA :
205
+ builder_config .DLA_core = self .compilation_settings .device .dla_core
206
+ _LOGGER .info (f"Using DLA core { self .compilation_settings .device .dla_core } " )
207
+ builder_config .set_memory_pool_limit (
208
+ trt .MemoryPoolType .DLA_MANAGED_SRAM ,
209
+ self .compilation_settings .dla_sram_size ,
210
+ )
211
+ builder_config .set_memory_pool_limit (
212
+ trt .MemoryPoolType .DLA_LOCAL_DRAM ,
213
+ self .compilation_settings .dla_local_dram_size ,
214
+ )
215
+ builder_config .set_memory_pool_limit (
216
+ trt .MemoryPoolType .DLA_GLOBAL_DRAM ,
217
+ self .compilation_settings .dla_global_dram_size ,
218
+ )
205
219
206
220
if precision == torch .float16 :
207
221
builder_config .set_flag (trt .BuilderFlag .FP16 )
208
222
209
223
if precision == torch .int8 :
210
224
builder_config .set_flag (trt .BuilderFlag .INT8 )
211
225
212
- if sparse_weights :
226
+ if self . compilation_settings . sparse_weights :
213
227
builder_config .set_flag (trt .BuilderFlag .SPARSE_WEIGHTS )
214
228
215
- if disable_tf32 :
229
+ if self . compilation_settings . disable_tf32 :
216
230
builder_config .clear_flag (trt .BuilderFlag .TF32 )
217
231
232
+ if self .compilation_settings .refit :
233
+ builder_config .set_flag (trt .BuilderFlag .REFIT )
234
+
218
235
if strict_type_constraints :
219
236
builder_config .set_flag (trt .BuilderFlag .STRICT_TYPES )
220
237
0 commit comments