Skip to content

Commit 1b25542

Browse files
committed
feat(//cpp/api): Adding max batch size setting
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent fc70267 commit 1b25542

File tree

4 files changed

+34
-16
lines changed

4 files changed

+34
-16
lines changed

Diff for: core/conversion/conversionctx/ConversionCtx.cpp

+23-12
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,24 @@ namespace core {
99
namespace conversion {
1010

1111
std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
12-
os << "Settings requested for TensorRT engine:" \
13-
<< "\n Operating Precision: " << s.op_precision \
14-
<< "\n Make Refittable Engine: " << s.refit \
15-
<< "\n Debuggable Engine: " << s.debug \
16-
<< "\n Strict Type: " << s.strict_type \
17-
<< "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \
18-
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
19-
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
20-
<< "\n Max Workspace Size: " << s.workspace_size \
21-
<< "\n Device Type: " << s.device \
22-
<< "\n Engine Capability: " << s.capability \
12+
os << "Settings requested for TensorRT engine:" \
13+
<< "\n Operating Precision: " << s.op_precision \
14+
<< "\n Make Refittable Engine: " << s.refit \
15+
<< "\n Debuggable Engine: " << s.debug \
16+
<< "\n Strict Type: " << s.strict_types \
17+
<< "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \
18+
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
19+
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
20+
<< "\n Max Workspace Size: " << s.workspace_size;
21+
22+
if (s.max_batch_size != 0) {
23+
os << "\n Max Batch Size: " << s.max_batch_size;
24+
} else {
25+
os << "\n Max Batch Size: Not set";
26+
}
27+
28+
os << "\n Device Type: " << s.device \
29+
<< "\n Engine Capability: " << s.capability \
2330
<< "\n Calibrator Created: " << (s.calibrator != nullptr);
2431
return os;
2532
}
@@ -62,14 +69,18 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
6269
cfg->setFlag(nvinfer1::BuilderFlag::kDEBUG);
6370
}
6471

65-
if (settings.strict_type) {
72+
if (settings.strict_types) {
6673
cfg->setFlag(nvinfer1::BuilderFlag::kSTRICT_TYPES);
6774
}
6875

6976
if (settings.allow_gpu_fallback) {
7077
cfg->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
7178
}
7279

80+
if (settings.max_batch_size != 0) {
81+
builder->setMaxBatchSize(settings.max_batch_size);
82+
}
83+
7384
cfg->setMinTimingIterations(settings.num_min_timing_iters);
7485
cfg->setAvgTimingIterations(settings.num_avg_timing_iters);
7586
cfg->setMaxWorkspaceSize(settings.workspace_size);

Diff for: core/conversion/conversionctx/ConversionCtx.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@ struct BuilderSettings {
2020
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
2121
bool refit = false;
2222
bool debug = false;
23-
bool strict_type = false;
23+
bool strict_types = false;
2424
bool allow_gpu_fallback = true;
2525
nvinfer1::DeviceType device = nvinfer1::DeviceType::kGPU;
2626
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
2727
nvinfer1::IInt8Calibrator* calibrator = nullptr;
2828
uint64_t num_min_timing_iters = 2;
2929
uint64_t num_avg_timing_iters = 1;
3030
uint64_t workspace_size = 0;
31+
uint64_t max_batch_size = 0;
3132

3233
BuilderSettings() = default;
3334
BuilderSettings(const BuilderSettings& other) = default;

Diff for: cpp/api/include/trtorch/trtorch.h

+7-2
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ struct TRTORCH_API ExtraInfo {
175175
/**
176176
* Restrict operating type to only set default operation precision (op_precision)
177177
*/
178-
bool strict_type = false;
178+
bool strict_types = false;
179179

180180
/**
181181
* (Only used when targeting DLA (device))
@@ -205,7 +205,12 @@ struct TRTORCH_API ExtraInfo {
205205
/**
206206
* Maximum size of workspace given to TensorRT
207207
*/
208-
uint64_t workspace_size = 1 << 20;
208+
uint64_t workspace_size = 0;
209+
210+
/**
211+
* Maximum batch size (must be >= 1 to be set, 0 means not set)
212+
*/
213+
uint64_t max_batch_size = 0;
209214

210215
/**
211216
* Calibration dataloaders for each input for post training quantization

Diff for: cpp/api/src/extra_info.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ core::ExtraInfo to_internal_extra_info(ExtraInfo external) {
9191

9292
internal.convert_info.engine_settings.refit = external.refit;
9393
internal.convert_info.engine_settings.debug = external.debug;
94-
internal.convert_info.engine_settings.strict_type = external.strict_type;
94+
internal.convert_info.engine_settings.strict_types = external.strict_types;
9595
internal.convert_info.engine_settings.allow_gpu_fallback = external.allow_gpu_fallback;
96+
internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;
9697

9798
switch(external.device) {
9899
case ExtraInfo::DeviceType::kDLA:

0 commit comments

Comments
 (0)