
Commit 547f554

feat(trtorchc): Adding new support for dtypes and formats in trtorchc

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent: d6161e3

5 files changed (+276, -79 lines)

core/ir/Input.cpp (-4)

@@ -128,8 +128,6 @@ Input::Input(std::vector<int64_t> shape, nvinfer1::DataType dtype, nvinfer1::Ten
   max = util::toDims(shape);
   input_shape = util::toDims(shape);
   input_is_dynamic = false;
-  format = nvinfer1::TensorFormat::kLINEAR;
-  dtype = dtype;

   TRTORCH_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
@@ -156,8 +154,6 @@ Input::Input(std::vector<int64_t> min_shape, std::vector<int64_t> opt_shape, std
   min = util::toDims(min_shape);
   opt = util::toDims(opt_shape);
   max = util::toDims(max_shape);
-  format = nvinfer1::TensorFormat::kLINEAR;
-  dtype = nvinfer1::DataType::kFLOAT;

   std::vector<int64_t> dyn_shape;
   for (size_t i = 0; i < opt_shape.size(); i++) {
cpp/api/include/trtorch/trtorch.h (+11, -2)

@@ -9,6 +9,7 @@
 #pragma once

 #include <cuda_runtime.h>
+#include <iostream>
 #include <memory>
 #include <string>
 #include <vector>
@@ -66,10 +67,12 @@ struct TRTORCH_API CompileSpec {
       kHalf,
       /// INT8
       kChar,
-      /// INT32
-      kInt32,
+      /// INT
+      kInt,
       /// Bool
       kBool,
+      /// Sentinel value
+      kUnknown
     };

     /**
@@ -139,6 +142,7 @@ struct TRTORCH_API CompileSpec {
     }

    private:
+    friend std::ostream& operator<<(std::ostream& os, const DataType& dtype);
     Value value;
   };

@@ -278,6 +282,8 @@ struct TRTORCH_API CompileSpec {
       kContiguous,
       /// Channel Last / NHWC
       kChannelsLast,
+      /// Sentinel value
+      kUnknown,
     };

     /**
@@ -346,7 +352,9 @@ struct TRTORCH_API CompileSpec {
       return value != other;
     }

+
    private:
+    friend std::ostream& operator<<(std::ostream& os, const TensorFormat& format);
     Value value;
   };

@@ -472,6 +480,7 @@ struct TRTORCH_API CompileSpec {

     bool get_explicit_set_dtype() {return explicit_set_dtype;}
    private:
+    friend std::ostream& operator<<(std::ostream& os, const Input& input);
     bool input_is_dynamic;
     bool explicit_set_dtype;
   };

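The header changes above add a `kUnknown` sentinel to both `DataType` and `TensorFormat`. As a rough illustration (not part of this commit), a sentinel like this lets calling code tell "never set by the user" apart from a real value. The helper name `resolve_dtype` and the FP32 fallback below are hypothetical, and the sketch assumes the wrapper classes are implicitly constructible from and comparable with their `Value` enumerators, as the comparison operators in the header suggest.

```
#include "trtorch/trtorch.h"

// Hypothetical helper (not in this commit): fall back to FP32 only when the
// dtype was left at the new kUnknown sentinel, i.e. never set explicitly.
trtorch::CompileSpec::DataType resolve_dtype(trtorch::CompileSpec::DataType requested) {
  if (requested == trtorch::CompileSpec::DataType::kUnknown) {
    return trtorch::CompileSpec::DataType::kFloat;  // assumed default, mirroring PyTorch's float32
  }
  return requested;  // the user picked a dtype explicitly; keep it
}
```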
cpp/api/src/compile_spec.cpp (+64, -4)

@@ -9,13 +9,74 @@

 namespace trtorch {

+std::ostream& operator<<(std::ostream& os, const CompileSpec::DataType& dtype) {
+  switch (dtype) {
+    case CompileSpec::DataType::kChar:
+      os << "char";
+      break;
+    case CompileSpec::DataType::kHalf:
+      os << "half";
+      break;
+    case CompileSpec::DataType::kInt:
+      os << "int";
+      break;
+    case CompileSpec::DataType::kBool:
+      os << "bool";
+      break;
+    case CompileSpec::DataType::kFloat:
+      os << "float";
+      break;
+    case CompileSpec::DataType::kUnknown:
+    default:
+      os << "unknown";
+      break;
+  }
+  return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const CompileSpec::TensorFormat& format) {
+  switch (format) {
+    case CompileSpec::TensorFormat::kChannelsLast:
+      os << "channels last";
+      break;
+    case CompileSpec::TensorFormat::kContiguous:
+      os << "contiguous";
+      break;
+    case CompileSpec::TensorFormat::kUnknown:
+    default:
+      os << "unknown";
+      break;
+  }
+  return os;
+}
+
+std::ostream& operator<<(std::ostream& os, const CompileSpec::Input& input) {
+  auto vec_to_str = [](std::vector<int64_t> shape) -> std::string {
+    std::stringstream ss;
+    ss << '[';
+    for (auto i : shape) {
+      ss << i << ',';
+    }
+    ss << ']';
+    return ss.str();
+  };
+
+  if (!input.input_is_dynamic) {
+    os << "Input(shape: " << vec_to_str(input.shape) << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+  } else {
+    os << "Input(shape: " << vec_to_str(input.shape) << ", min: " << vec_to_str(input.min_shape) << ", opt: " << vec_to_str(input.opt_shape) << ", max: " << vec_to_str(input.max_shape) << ", dtype: " << input.dtype << ", format: " << input.format << ')';
+  }
+  return os;
+}
+
+
 nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) {
   switch (value) {
     case CompileSpec::DataType::kChar:
       return nvinfer1::DataType::kINT8;
     case CompileSpec::DataType::kHalf:
       return nvinfer1::DataType::kHALF;
-    case CompileSpec::DataType::kInt32:
+    case CompileSpec::DataType::kInt:
       return nvinfer1::DataType::kINT32;
     case CompileSpec::DataType::kBool:
       return nvinfer1::DataType::kBOOL;
@@ -47,7 +108,7 @@ CompileSpec::DataType::DataType(c10::ScalarType t) {
       value = DataType::kChar;
       break;
     case at::kInt:
-      value = DataType::kInt32;
+      value = DataType::kInt;
       break;
     case at::kBool:
       value = DataType::kBool;
@@ -250,7 +311,6 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   /* We want default behavior for types to match PyTorch, so in the case the user did not explicitly set the dtype for
     inputs they will follow PyTorch convetions */
   for (size_t i = 0; i < external.inputs.size(); i++) {
-    std::cout << "EXPLICIT " << external.inputs[i].get_explicit_set_dtype() << std::endl;
     if (!external.inputs[i].get_explicit_set_dtype()) {
       auto& precisions = internal.convert_info.engine_settings.enabled_precisions;
       auto& internal_ins = internal.convert_info.inputs;
@@ -261,9 +321,9 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
       } else {
         internal_ins[i].dtype = nvinfer1::DataType::kFLOAT;
       }
-      std::cout << "internal type: " << internal_ins[i].dtype;
     }
   }
+
   internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
   internal.convert_info.engine_settings.refit = external.refit;
   internal.convert_info.engine_settings.debug = external.debug;

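A minimal sketch of how the stream operators added in this file could be exercised. It is illustrative only, and assumes the enum wrappers are implicitly constructible from their `Value` enumerators; the printed strings follow the switch statements in the diff above.

```
#include <iostream>

#include "trtorch/trtorch.h"

int main() {
  // Exercise the new operator<< overloads for the enum wrappers.
  trtorch::CompileSpec::DataType dtype = trtorch::CompileSpec::DataType::kHalf;
  trtorch::CompileSpec::TensorFormat format = trtorch::CompileSpec::TensorFormat::kChannelsLast;

  // Per the switch statements above, this should print:
  // "dtype: half, format: channels last"
  std::cout << "dtype: " << dtype << ", format: " << format << std::endl;
  return 0;
}
```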
cpp/trtorchc/README.md (+21, -11)

@@ -14,7 +14,7 @@ to standard TorchScript. Load with `torch.jit.load()` and run like you would run

 ```
 trtorchc [input_file_path] [output_file_path]
-    [input_shapes...] {OPTIONS}
+    [input_specs...] {OPTIONS}

 TRTorch is a compiler for TorchScript, it will compile and optimize
 TorchScript programs to run on NVIDIA GPUs using TensorRT
@@ -28,24 +28,29 @@ trtorchc [input_file_path] [output_file_path]
       -w, --warnings                    Disables warnings generated during
                                         compilation onto the console (warnings
                                         are on by default)
-      --info                            Dumps info messages generated during
+      --i, --info                       Dumps info messages generated during
                                         compilation onto the console
       --build-debuggable-engine         Creates a debuggable engine
       --use-strict-types                Restrict operating type to only use set
-                                        default operation precision
-                                        (op_precision)
+                                        operation precision
       --allow-gpu-fallback              (Only used when targeting DLA
                                         (device-type)) Lets engine run layers on
                                         GPU if they are not supported on DLA
-      -p[precision],
-      --default-op-precision=[precision]
-                                        Default operating precision for the
-                                        engine (Int8 requires a
+      --disable-tf32                    Prevent Float32 layers from using the
+                                        TF32 data format
+      -p[precision...],
+      --enabled-precison=[precision...] (Repeatable) Enabling an operating
+                                        precision for kernels to use when
+                                        building the engine (Int8 requires a
                                         calibration-cache argument) [ float |
                                         float32 | f32 | half | float16 | f16 |
                                         int8 | i8 ] (default: float)
       -d[type], --device-type=[type]    The type of device the engine should be
                                         built for [ gpu | dla ] (default: gpu)
+      --gpu-id=[gpu_id]                 GPU id if running on multi-GPU platform
+                                        (defaults to 0)
+      --dla-core=[dla_core]             DLACore id if running on available DLA
+                                        (defaults to 0)
       --engine-capability=[capability]  The type of device the engine should be
                                         built for [ default | safe_gpu |
                                         safe_dla ]
@@ -72,16 +77,21 @@ trtorchc [input_file_path] [output_file_path]
       input_file_path                   Path to input TorchScript file
       output_file_path                  Path for compiled TorchScript (or
                                         TensorRT engine) file
-      input_shapes...                   Sizes for inputs to engine, can either
+      input_specs...                    Specs for inputs to engine, can either
                                         be a single size or a range defined by
                                         Min, Optimal, Max sizes, e.g.
                                         "(N,..,C,H,W)"
-                                        "[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]"
+                                        "[(MIN_N,..,MIN_C,MIN_H,MIN_W);(OPT_N,..,OPT_C,OPT_H,OPT_W);(MAX_N,..,MAX_C,MAX_H,MAX_W)]".
+                                        Data Type and format can be specified by
+                                        adding an "@" followed by dtype and "%"
+                                        followed by format to the end of the
+                                        shape spec. e.g. "(3, 3, 32,
+                                        32)@f16%NHWC"
       "--" can be used to terminate flag options and force all following
       arguments to be treated as positional options
 ```

 e.g.
 ```
-trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]" -p f16
+trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
 ```

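For reference, the new `@dtype%format` suffix described in the help text above also composes with a single static shape spec. The command below is an illustrative invocation only (the module and output paths are hypothetical), built from the syntax documented in the README diff:

```
trtorchc model.jit.pt model_trt.ts "(1,3,32,32)@f16%NHWC" -p f16
```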