Skip to content

Commit cd24f26

Browse files
committed
fix: Address issues in PR
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 46bb485 commit cd24f26

File tree

4 files changed

+105
-32
lines changed

4 files changed

+105
-32
lines changed

Diff for: core/conversion/conversion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
235235
if (!OpSupported(n)) {
236236
auto schema = n->maybeSchema();
237237
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
238-
<< " (conversion.VerifyCoverterSupportForBloxk");
238+
<< " (conversion.VerifyCoverterSupportForBlock");
239239
std::stringstream ss;
240240
ss << *schema;
241241
unsupported_ops.insert(ss.str());

Diff for: cpp/api/README.md

+96-27
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# C++ API
22

3-
Targets in module create the user facing C++ library for the TRTorch core.
3+
Targets in module create the user facing C++ library for the TRTorch core.
44

55
## Building libtrtorch.so
66

7-
### Debug build
7+
### Debug build
88
``` shell
99
bazel build //cpp/api:libtrtorch.so --compilation_mode=dbg
1010
```
@@ -26,12 +26,19 @@ bazel build //cpp/api:libtrtorch.so --cxxopt="-DNDEBUG"
2626
> Temporary, will get real documentation soon
2727
2828
```c++
29+
namespace trtorch {
2930
/**
3031
* Settings data structure for TRTorch compilation
3132
*
3233
*/
3334
struct TRTORCH_API ExtraInfo {
34-
//struct TRTORCH_API InputRangesArray {
35+
/**
36+
* @brief A struct to hold an input range (used by TensorRT Optimization profile)
37+
*
38+
* This struct can either hold a single vector representing an input shape, signifying a
39+
* static input shape or a set of three input shapes representing the min, optiminal and max
40+
* input shapes allowed for the engine.
41+
*/
3542
struct TRTORCH_API InputRange {
3643
std::vector<int64_t> min;
3744
std::vector<int64_t> opt;
@@ -46,7 +53,7 @@ struct TRTORCH_API ExtraInfo {
4653
* Supported Data Types that can be used with TensorRT engines
4754
*
4855
* This class is compatable with c10::DataTypes (but will check for TRT support)
49-
* so there should not be a reason that you need to use this type explictly.
56+
* so there should not be a reason that you need to use this type explictly.
5057
*/
5158
class DataType {
5259
public:
@@ -59,14 +66,14 @@ struct TRTORCH_API ExtraInfo {
5966
* ex. trtorch::DataType type = DataType::kFloat;
6067
*/
6168
enum Value : int8_t {
62-
/// FP32
69+
/// FP32
6370
kFloat,
6471
/// FP16
6572
kHalf,
6673
/// INT8
67-
/*kChar, char or int8? */
74+
kChar,
6875
};
69-
76+
7077
DataType() = default;
7178
constexpr DataType(Value t) : value(t) {}
7279
DataType(c10::ScalarType t);
@@ -83,7 +90,7 @@ struct TRTORCH_API ExtraInfo {
8390
*
8491
* This class is compatable with c10::DeviceTypes (but will check for TRT support)
8592
* but the only applicable value is at::kCUDA, which maps to DeviceType::kGPU
86-
*
93+
*
8794
* To use the DataType class itself, interface using the enum vs. normal instatination
8895
*
8996
* ex. trtorch::DeviceType type = DeviceType::kGPU;
@@ -117,7 +124,7 @@ struct TRTORCH_API ExtraInfo {
117124
};
118125

119126
/**
120-
* Emum for selecting engine capability
127+
* Emum for selecting engine capability
121128
*/
122129
enum class EngineCapability : int8_t {
123130
kDEFAULT,
@@ -129,24 +136,24 @@ struct TRTORCH_API ExtraInfo {
129136
: input_ranges(std::move(input_ranges)) {}
130137
ExtraInfo(std::vector<std::vector<int64_t>> fixed_sizes);
131138
ExtraInfo(std::vector<c10::ArrayRef<int64_t>> fixed_sizes);
132-
139+
133140
// Defaults should reflect TensorRT defaults for BuilderConfig
134141

135-
/**
142+
/**
136143
* Sizes for inputs to engine, can either be a single size or a range
137-
* defined by Min, Optimal, Max sizes
138-
*
139-
* Order is should match call order
144+
* defined by Min, Optimal, Max sizes
145+
*
146+
* Order is should match call order
140147
*/
141148
std::vector<InputRange> input_ranges;
142149

143150
/**
144-
* Default operating precision for the engine
151+
* Default operating precision for the engine
145152
*/
146153
DataType op_precision = DataType::kFloat;
147-
154+
148155
/**
149-
* Build a refitable engine
156+
* Build a refitable engine
150157
*/
151158
bool refit = false;
152159

@@ -158,10 +165,10 @@ struct TRTORCH_API ExtraInfo {
158165
/**
159166
* Restrict operating type to only set default operation precision (op_precision)
160167
*/
161-
bool strict_type = false;
168+
bool strict_types = false;
162169

163170
/**
164-
* (Only used when targeting DLA (device))
171+
* (Only used when targeting DLA (device))
165172
* Lets engine run layers on GPU if they are not supported on DLA
166173
*/
167174
bool allow_gpu_fallback = true;
@@ -189,6 +196,16 @@ struct TRTORCH_API ExtraInfo {
189196
* Maximum size of workspace given to TensorRT
190197
*/
191198
uint64_t workspace_size = 0;
199+
200+
/**
201+
* Maximum batch size (must be =< 1 to be set, 0 means not set)
202+
*/
203+
uint64_t max_batch_size = 0;
204+
205+
/**
206+
* Calibration dataloaders for each input for post training quantizatiom
207+
*/
208+
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
192209
};
193210

194211
/**
@@ -198,37 +215,89 @@ TRTORCH_API std::string get_build_info();
198215

199216
/**
200217
* Dump the version information for TRTorch including base libtorch and TensorRT versions
201-
* to stdout
218+
* to stdout
202219
*/
203220
TRTORCH_API void dump_build_info();
204221

222+
/**
223+
* @brief Check to see if a module is fully supported by the compiler
224+
*
225+
* @param module: torch::jit::script::Module - Existing TorchScript module
226+
* @param method_name: std::string - Name of method to compile
227+
*
228+
* Takes a module and a method name and checks if the method graph contains purely
229+
* convertable operators
230+
*
231+
* Will print out a list of unsupported operators if the graph is unsupported
232+
*/
233+
TRTORCH_API bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name);
234+
205235
/**
206236
* @brief Compile a TorchScript module for NVIDIA GPUs using TensorRT
207237
*
208-
* @param module: torch::jit::script::Module - Existing TorchScript module
209-
* @param info: trtorch::ExtraInfo - Compilation settings
238+
* @param module: torch::jit::script::Module - Existing TorchScript module
239+
* @param info: trtorch::ExtraInfo - Compilation settings
210240
*
211241
* Takes a existing TorchScript module and a set of settings to configure the compiler
212242
* and will convert methods to JIT Graphs which call equivalent TensorRT engines
213243
*
214-
* Converts specifically the forward method of a TorchScript Module
215-
*/
244+
* Converts specifically the forward method of a TorchScript Module
245+
*/
216246
TRTORCH_API torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo info);
217247

218248
/**
219249
* @brief Compile a TorchScript method for NVIDIA GPUs using TensorRT
220250
*
221-
* @param module: torch::jit::script::Module - Existing TorchScript module
251+
* @param module: torch::jit::script::Module - Existing TorchScript module
222252
* @param method_name: std::string - Name of method to compile
223-
* @param info: trtorch::ExtraInfo - Compilation settings
253+
* @param info: trtorch::ExtraInfo - Compilation settings
224254
*
225255
* Takes a existing TorchScript module and a set of settings to configure the compiler
226256
* and will convert selected method to a serialized TensorRT engine which can be run with
227257
* TensorRT
228258
*/
229-
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, ExtraInfo info);
259+
TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& module, std::string method_name, ExtraInfo info);
260+
261+
namespace ptq {
262+
/**
263+
* @brief A factory to build a post training quantization calibrator from a torch dataloader
264+
*
265+
* Creates a calibrator to use for post training quantization
266+
* If there are multiple inputs, the dataset should produce a example which is a vector (or similar container) of tensors vs a single tensor
267+
*
268+
* By default the returned calibrator uses TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed forward networks
269+
* You can override the algorithm selection (such as to use the MinMax Calibrator recomended for NLP tasks) by calling make_int8_calibrator with
270+
* the calibrator class as a template parameter.
271+
*
272+
* e.g. trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, use_cache);
273+
*/
274+
template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2, typename DataLoader>
275+
TRTORCH_API inline Int8Calibrator<Algorithm, DataLoader> make_int8_calibrator(DataLoader dataloader, const std::string& cache_file_path, bool use_cache) {
276+
return Int8Calibrator<Algorithm, DataLoader>(std::move(dataloader), cache_file_path, use_cache);
277+
}
278+
279+
/**
280+
* @brief A factory to build a post training quantization calibrator from a torch dataloader that only uses the calibration cache
281+
*
282+
* Creates a calibrator to use for post training quantization which reads from a previously created calibration cache, therefore
283+
* you can have a calibration cache generating program that requires a dataloader and a dataset, then save the cache to use later
284+
* in a different program that needs to calibrate from scratch and not have the dataset dependency. However, the network should also
285+
* be recalibrated if its structure changes, or the input data set changes, and it is the responsibility of the application to ensure this.
286+
*
287+
* By default the returned calibrator uses TensorRT Entropy v2 algorithm to perform calibration. This is recommended for feed forward networks
288+
* You can override the algorithm selection (such as to use the MinMax Calibrator recomended for NLP tasks) by calling make_int8_calibrator with
289+
* the calibrator class as a template parameter.
290+
*
291+
* e.g. trtorch::ptq::make_int8_cache_calibrator<nvinfer1::IInt8MinMaxCalibrator>(calibration_cache_file);
292+
*/
293+
template<typename Algorithm = nvinfer1::IInt8EntropyCalibrator2>
294+
TRTORCH_API inline Int8CacheCalibrator<Algorithm> make_int8_cache_calibrator(const std::string& cache_file_path) {
295+
return Int8CacheCalibrator<Algorithm>(cache_file_path);
296+
}
297+
} // namespace ptq
230298
} // namespace trtorch
231299

300+
232301
```
233302
234303

Diff for: cpp/ptq/training/vgg16/README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pip3 install -r requirements.txt --user
1212

1313
The following recipe should get somewhere between 89-92% accuracy on the CIFAR10 testset
1414
```
15-
python3 main.py --lr 0.01 --batch-size 256 --drop-ratio 0.15 --ckpt-dir $(pwd)/vgg16_ckpts --epochs 100
15+
python3 main.py --lr 0.01 --batch-size 128 --drop-ratio 0.15 --ckpt-dir $(pwd)/vgg16_ckpts --epochs 100
1616
```
1717

1818
> 545 was the seed used in testing

Diff for: cpp/ptq/training/vgg16/export_ckpt.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def test(model, dataloader, crit):
1919
class_probs = []
2020
class_preds = []
2121
model.eval()
22+
2223
with torch.no_grad():
2324
for data, labels in dataloader:
2425
data, labels = data.cuda(), labels.cuda(async=True)
@@ -53,21 +54,24 @@ def test(model, dataloader, crit):
5354
weights = new_state_dict
5455

5556
model.load_state_dict(weights)
57+
model.eval()
5658

5759
jit_model = torch.jit.trace(model, torch.rand([32, 3, 32, 32]).to("cuda"))
5860

5961
testing_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
6062
transform=transforms.Compose([
6163
transforms.ToTensor(),
62-
transforms.Normalize((0.4914, 0.4822, 0.4465),
63-
(0.2023, 0.1994, 0.2010))]))
64+
transforms.Normalize((0.4914, 0.4822, 0.4465),
65+
(0.2023, 0.1994, 0.2010))]))
6466

6567
testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=32,
6668
shuffle=False, num_workers=2)
6769

6870
crit = torch.nn.CrossEntropyLoss()
6971
test_loss, test_acc = test(model, testing_dataloader, crit)
7072
print("[PTH] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
71-
print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
7273

7374
torch.jit.save(jit_model, "trained_vgg16.jit.pt")
75+
jit_model = torch.jit.load("trained_vgg16.jit.pt")
76+
test_loss, test_acc = test(jit_model, testing_dataloader, crit)
77+
print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

0 commit comments

Comments
 (0)