
Commit 521a0cb

fix: Final working version of QAT in TRTorch
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 715120f commit 521a0cb

File tree

16 files changed: +92 -117 lines changed

Diff for: core/compiler.cpp

+3-3
@@ -119,7 +119,7 @@ void AddEngineToGraph(
 
 bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
   // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name);
+  auto graph_and_parameters = lowering::Lower(mod, method_name, false);
 
   auto g = graph_and_parameters.first;
   LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
@@ -129,7 +129,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
   // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name);
+  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.convert_info.engine_settings.unfreeze_module);
 
   auto convert_cfg = std::move(cfg.convert_info);
   auto g = graph_and_parameters.first;
@@ -187,7 +187,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
   // Compile only forward methods. forward method contains the entire graph.
   if (method.name().compare("forward") == 0) {
     auto new_g = std::make_shared<torch::jit::Graph>();
-    auto graph_and_parameters = lowering::Lower(mod, method.name());
+    auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.convert_info.engine_settings.unfreeze_module);
 
     auto g = graph_and_parameters.first;
     auto params = graph_and_parameters.second;

Diff for: core/conversion/conversionctx/ConversionCtx.cpp

+1-1
@@ -72,7 +72,7 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
       if (!settings.calibrator) {
         LOG_WARNING(
             "Int8 precision has been enabled but no calibrator provided. This assumes the network has Q/DQ nodes obtained from Quantization aware training. For more details, refer to https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks");
-      } else{
+      } else {
         cfg->setInt8Calibrator(settings.calibrator);
       }
       break;

Diff for: core/conversion/conversionctx/ConversionCtx.h

+2
@@ -27,6 +27,8 @@ struct BuilderSettings {
   bool sparse_weights = false;
   std::set<nvinfer1::DataType> enabled_precisions = {nvinfer1::DataType::kFLOAT};
   bool disable_tf32 = false;
+  // Internal flag to ensure the torch.jit.Module does not get frozen in lowering.cpp. This is required for QAT models.
+  bool unfreeze_module = false;
   bool refit = false;
   bool debug = false;
   bool strict_types = false;

Diff for: core/conversion/converters/impl/matrix_multiply.cpp

+1-45
@@ -26,6 +26,7 @@ auto mm_registrations TRTORCH_UNUSED =
 
                auto mm_layer = ctx->net->addMatrixMultiply(
                    *self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
+
                TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
                mm_layer->setName(util::node_info(n).c_str());
                auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
@@ -73,51 +74,6 @@ auto mm_registrations TRTORCH_UNUSED =
 
                LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
                return true;
-             }})
-        .pattern(
-            {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
-             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               auto self = args[0].ITensorOrFreeze(ctx);
-               auto mat1 = args[1].ITensorOrFreeze(ctx);
-               auto mat2 = args[2].ITensorOrFreeze(ctx);
-               auto beta = args[3].unwrapToScalar().to<float>();
-               auto betaTensor = tensor_to_const(ctx, torch::tensor({beta}));
-               auto alpha = args[4].unwrapToScalar().to<float>();
-               auto alphaTensor = tensor_to_const(ctx, torch::tensor({alpha}));
-
-               // Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if
-               // necessary.
-               if (mat1->getDimensions().nbDims < mat2->getDimensions().nbDims) {
-                 mat1 = addPadding(ctx, n, mat1, mat2->getDimensions().nbDims, false, false);
-               } else {
-                 mat2 = addPadding(ctx, n, mat2, mat1->getDimensions().nbDims, false, false);
-               }
-
-               auto mm_layer = ctx->net->addMatrixMultiply(
-                   *mat1, nvinfer1::MatrixOperation::kNONE, *mat2, nvinfer1::MatrixOperation::kNONE);
-               TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication layer in node: " << *n);
-               auto mm_scale_layer = add_elementwise(
-                   ctx,
-                   nvinfer1::ElementWiseOperation::kPROD,
-                   mm_layer->getOutput(0),
-                   alphaTensor,
-                   util::node_info(n) + "_alphaScale");
-               TRTORCH_CHECK(mm_scale_layer, "Unable to create alpha scaling layer in node: " << *n);
-               auto beta_scale_layer = add_elementwise(
-                   ctx, nvinfer1::ElementWiseOperation::kPROD, self, betaTensor, util::node_info(n) + "_betaScale");
-               TRTORCH_CHECK(beta_scale_layer, "Unable to create beta scaling layer in node: " << *n);
-               auto add_mm_layer = add_elementwise(
-                   ctx,
-                   nvinfer1::ElementWiseOperation::kSUM,
-                   beta_scale_layer->getOutput(0),
-                   mm_scale_layer->getOutput(0),
-                   util::node_info(n));
-               TRTORCH_CHECK(add_mm_layer, "Unable to create addmm layer in node: " << *n);
-
-               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
-
-               LOG_DEBUG("[AddMM layer] Output tensor shape: " << out_tensor->getDimensions());
-               return true;
             }});
 } // namespace
 } // namespace impl
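The dedicated aten::addmm converter removed here is redundant because lowering already decomposes addmm into its primitive form (see passes::UnpackAddMM in the lowering.cpp diff below), which the remaining mm/add converters handle. A minimal standalone sketch of that identity against the public libtorch API; the main() wrapper is illustrative only:

#include <iostream>
#include <torch/torch.h>

// addmm(self, mat1, mat2, beta, alpha) == beta * self + alpha * (mat1 @ mat2),
// which is the primitive form the existing mm/add converters already support.
int main() {
  auto self = torch::randn({3, 4});
  auto mat1 = torch::randn({3, 5});
  auto mat2 = torch::randn({5, 4});
  auto direct = torch::addmm(self, mat1, mat2, /*beta=*/1.0, /*alpha=*/1.0);
  auto unpacked = self + torch::mm(mat1, mat2); // beta = alpha = 1 case
  std::cout << direct.allclose(unpacked) << std::endl; // expect: 1
}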

Diff for: core/conversion/converters/impl/shuffle.cpp

+13-6
@@ -131,15 +131,22 @@ static auto shuffle_registrations TRTORCH_UNUSED =
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto in = args[0].ITensorOrFreeze(ctx);
                auto input_dims = in->getDimensions();
-               nvinfer1::Dims transposed_input_dims;
-               transposed_input_dims.nbDims = input_dims.nbDims;
-               for (int i = input_dims.nbDims - 1; i >= 0; i--) {
-                 transposed_input_dims.d[i] = input_dims.d[input_dims.nbDims - 1 - i];
+               // For input tensors < 2D, return them as is
+               // For a 2D input tensor, return transpose(input, 0, 1) which is a general 2d matrix transpose.
+               if (input_dims.nbDims < 2) {
+                 auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+                 LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+                 return true;
                }
+
                auto shuffle_layer = ctx->net->addShuffle(*in);
                TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
-               shuffle_layer->setReshapeDimensions(transposed_input_dims);
-               shuffle_layer->setZeroIsPlaceholder(true);
+               nvinfer1::Permutation firstPerm;
+               firstPerm.order[0] = 1;
+               firstPerm.order[1] = 0;
+
+               shuffle_layer->setFirstTranspose(firstPerm);
+               shuffle_layer->setZeroIsPlaceholder(false);
                shuffle_layer->setName(util::node_info(n).c_str());
 
                auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
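The fix matters because setReshapeDimensions only reinterprets the layout; it does not move data, so the old reversed-dims reshape never actually transposed anything. In TensorRT terms the rewritten converter amounts to the sketch below (the function name and signature are illustrative; the layer calls match the diff above):

#include "NvInfer.h"

// Sketch of the new aten::t lowering: swap axes 0 and 1 of a 2D tensor via a
// shuffle layer's first transpose. Tensors with fewer than 2 dims pass
// through, since aten::t is the identity for 0-D and 1-D inputs.
nvinfer1::ITensor* transpose2d(nvinfer1::INetworkDefinition& net, nvinfer1::ITensor& in) {
  if (in.getDimensions().nbDims < 2) {
    return &in;
  }
  auto* shuffle = net.addShuffle(in);
  nvinfer1::Permutation perm;
  perm.order[0] = 1; // output axis 0 reads input axis 1
  perm.order[1] = 0; // output axis 1 reads input axis 0
  shuffle->setFirstTranspose(perm);
  return shuffle->getOutput(0);
}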

Diff for: core/conversion/evaluators/aten.cpp

-30
@@ -427,36 +427,6 @@ auto aten_registrations TRTORCH_UNUSED =
                     EvalOptions().validSchemas({
                         "aten::numel(Tensor self) -> int",
                     })})
-        // .evaluator({c10::Symbol::fromQualString("aten::t"),
-        //             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-        //               auto tensor_var = args.at(n->input(0));
-        //               if (tensor_var.isIValue() && tensor_var.IValue()->isTensor()) {
-        //                 auto tensor = tensor_var.unwrapToTensor();
-        //                 return tensor.t();
-        //               } else if (tensor_var.isITensor()) {
-        //                 auto input_tensor = tensor_var.ITensor();
-        //                 auto input_dims = input_tensor->getDimensions();
-        //                 LOG_DEBUG("[aten::t] INPUT TENSOR DIMS: " << input_dims);
-        //                 // nvinfer1::Dims transposed_input_dims;
-        //                 // for (int i = input_dims.nbDims - 1; i >= 0; i--) {
-        //                 //   transposed_input_dims.d[i] = input_dims.d[input_dims.nbDims - 1 - i];
-        //                 // }
-        //                 // auto shuffle_layer = ctx->net->addShuffle(*input_tensor);
-        //                 // shuffle_layer->setReshapeDimensions(transposed_input_dims);
-        //                 // shuffle_layer->setZeroIsPlaceholder(true);
-        //                 // auto output_tensor = shuffle_layer->getOutput(0);
-        //                 auto tensor_holder = TensorContainer();
-        //                 tensor_holder.hold_tensor(input_tensor);
-        //                 auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
-        //                 return ival;
-        //               } else {
-        //                 TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
-        //                 return {};
-        //               }
-        //             },
-        //             EvalOptions().validSchemas({
-        //                 "aten::t(Tensor self) -> Tensor",
-        //             })})
         .evaluator({c10::Symbol::fromQualString("aten::dim"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto tensor_var = args.at(n->input(0));

Diff for: core/lowering/lowering.cpp

+20-9
@@ -24,7 +24,7 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }
 
-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse) {
   passes::UnpackHardSwish(g);
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
@@ -42,9 +42,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);
   passes::RemoveBNDimCheck(g);
-  LOG_INFO("====PRE CSE =====" << *g);
-  // torch::jit::EliminateCommonSubexpression(g);
-  LOG_INFO("====POST CSE =====" << *g);
+  if (!disable_cse) {
+    torch::jit::EliminateCommonSubexpression(g);
+  }
   // torch::jit::UnrollLoops(g);
   passes::UnpackAddMM(g);
   // passes::UnpackBatchNorm(g);
@@ -57,25 +57,36 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
 }
 
 torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
+  LOG_DEBUG("Input module is being frozen by torch::jit::freeze_module");
   auto mod_ = torch::jit::freeze_module(mod);
   return mod_;
 }
 
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
     const torch::jit::script::Module& mod,
-    std::string method_name) {
-  auto lowered_mod = mod; // LowerModule(mod);
+    std::string method_name,
+    bool unfreeze_module = false) {
+  auto lowered_mod = unfreeze_module ? mod : LowerModule(mod);
   auto g = lowered_mod.get_method(method_name).graph();
   LOG_GRAPH(*g);
 
   // Go through TRTorch Lowering to reformat graph to be conversion friendly
   // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
-  LOG_GRAPH("TRTorch Graph Lowering");
-  // lowering::LowerGraph(g);
+  // unfreeze_module is used to not perform constant folding on weights in the network.
+  // In quantization aware trained (QAT) models, weights are passed through quantize and
+  // dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
+  if (!unfreeze_module) {
+    LOG_GRAPH("TRTorch Graph Lowering");
+    lowering::LowerGraph(g, false);
+  }
 
   LOG_GRAPH("LibTorch Lowering");
   auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
-  lowering::LowerGraph(graph_and_ivalues.first);
+
+  if (unfreeze_module) {
+    LOG_GRAPH("TRTorch Graph Lowering");
+    lowering::LowerGraph(graph_and_ivalues.first, true);
+  }
   // Is this necessary?
   lowering::LowerBlock(g->block());
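To make the new comment concrete: in a QAT graph the weights reach their consumers only through fake-quantize (or quantize/dequantize) nodes, and freezing plus constant folding would collapse that chain into a single folded weight. A sketch of the IR shape involved, written as a raw string in the style of the converter tests; the op choice and constant values are illustrative, not taken from this commit:

// Assumed IR shape for a QAT linear op. If freeze_module folded %weight and
// the fake-quantize chain into one constant, TensorRT would never see the
// Q/DQ structure it needs for INT8 scales, hence unfreeze_module = true.
const auto qat_like_graph = R"IR(
  graph(%input : Tensor, %weight : Tensor):
    %scale : float = prim::Constant[value=0.1]()
    %zero_point : int = prim::Constant[value=0]()
    %quant_min : int = prim::Constant[value=-128]()
    %quant_max : int = prim::Constant[value=127]()
    %wq : Tensor = aten::fake_quantize_per_tensor_affine(%weight, %scale, %zero_point, %quant_min, %quant_max)
    %out : Tensor = aten::matmul(%input, %wq)
    return (%out))IR";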

Diff for: core/lowering/lowering.h

+3-2
@@ -7,11 +7,12 @@ namespace core {
 namespace lowering {
 
 void LowerBlock(torch::jit::Block* b);
-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g);
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, bool disable_cse /*=false*/);
 torch::jit::Module LowerModule(const torch::jit::script::Module& mod);
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
     const torch::jit::script::Module& mod,
-    std::string method_name);
+    std::string method_name,
+    bool unfreeze_module /*=false*/);
 
 } // namespace lowering
 } // namespace core

Diff for: cpp/api/include/trtorch/trtorch.h

+6-6
@@ -262,9 +262,9 @@
    * Enum for selecting engine capability
    */
   enum class EngineCapability : int8_t {
-    kDEFAULT,
-    kSAFE_GPU,
-    kSAFE_DLA,
+    kSTANDARD,
+    kSAFETY,
+    kDLA_STANDALONE,
   };
 
   class TRTORCH_API TensorFormat {
@@ -686,12 +686,12 @@
    * This is the behavior of FP32 layers by default.
    */
   bool disable_tf32 = false;
-
-  /**
+
+  /**
    * Enable sparsity for weights of conv and FC layers
    */
   bool sparse_weights = false;
-
+
   /**
    * Build a refitable engine
    */

Diff for: cpp/api/src/compile_spec.cpp

+7-1
@@ -405,7 +405,13 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
 
   if (internal.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
       internal.convert_info.engine_settings.enabled_precisions.end()) {
-    internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
+    if (external.ptq_calibrator) {
+      internal.convert_info.engine_settings.calibrator = external.ptq_calibrator;
+    } else {
+      ;
+      internal.convert_info.engine_settings.unfreeze_module = true;
+      internal.convert_info.engine_settings.calibrator = nullptr;
+    }
   } else {
     internal.convert_info.engine_settings.calibrator = nullptr;
  }
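From the user's side, this branch means INT8 can be requested without a calibrator and the QAT path is taken automatically. A minimal sketch of that usage against the C++ API, assuming the trtorch::CompileGraph entry point and the CompileSpec fields touched in this change set; the model path, input shape, and the kChar spelling of the INT8 enum are assumptions:

#include <torch/script.h>
#include "trtorch/trtorch.h"

int main() {
  // Hypothetical TorchScript module exported after quantization aware training
  auto mod = torch::jit::load("qat_model.ts");

  trtorch::CompileSpec spec({{1, 3, 224, 224}});
  // Enable INT8 but leave spec.ptq_calibrator unset: to_internal_compile_spec
  // then sets unfreeze_module = true, and scale information comes from the
  // Q/DQ nodes in the graph instead of calibration.
  spec.enabled_precisions.insert(trtorch::CompileSpec::DataType::kChar); // assumed INT8 enum

  auto trt_mod = trtorch::CompileGraph(mod, spec);
  trt_mod.save("qat_model_trt.ts");
}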

Diff for: py/trtorch/csrc/register_tensorrt_classes.cpp

+1-1
@@ -47,7 +47,7 @@ void RegisterTRTCompileSpec() {
       .def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
       .def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
       .def("__str__", &trtorch::pyapi::CompileSpec::stringify);
-
+
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, sparse_weights);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, disable_tf32);
   ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, refit);

Diff for: py/trtorch/csrc/tensorrt_backend.cpp

+1-1
@@ -32,7 +32,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
     const auto& method_name = it->key();
     auto method = mod.get_method(method_name);
     auto graph = method.graph();
-    core::lowering::LowerGraph(graph);
+    core::lowering::LowerGraph(graph, false);
   }
 
   auto handles = c10::impl::GenericDict(

Diff for: py/trtorch/csrc/tensorrt_classes.cpp

+9-2
@@ -181,8 +181,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   for (auto p : enabled_precisions) {
     info.convert_info.engine_settings.enabled_precisions.insert(toTRTDataType(p));
   }
-
-  info.convert_info.engine_settings.calibrator = ptq_calibrator;
+  if (ptq_calibrator) {
+    info.convert_info.engine_settings.calibrator = ptq_calibrator;
+  } else {
+    if (info.convert_info.engine_settings.enabled_precisions.find(nvinfer1::DataType::kINT8) !=
+        info.convert_info.engine_settings.enabled_precisions.end()) {
+      std::cout << "===INTERNAL UNFREEZE MODULE TRUE===" << std::endl;
+      info.convert_info.engine_settings.unfreeze_module = true;
+    }
+  }
   info.convert_info.engine_settings.sparse_weights = sparse_weights;
   info.convert_info.engine_settings.disable_tf32 = disable_tf32;
   info.convert_info.engine_settings.refit = refit;

Diff for: tests/core/conversion/converters/test_shuffle.cpp

+24-1
@@ -241,6 +241,30 @@ TEST(Converters, ATenTransposeConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
 
+TEST(Converters, ATenTConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %out : Tensor = aten::t(%x.1)
+        return (%out))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(0, 5, {3, 4}, {at::kCUDA});
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+
+  std::cout << "Running JIT" << std::endl;
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  std::cout << "Running TRT" << std::endl;
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+  auto trt = trt_results[0].reshape_as(jit_results[0]);
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
+
 TEST(Converters, ATenTransposeNegativeConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
@@ -312,7 +336,6 @@ TEST(Converters, ATenPixelShuffle3DConvertsCorrectly) {
   in = at::clone(in);
   params = trtorch::core::conversion::get_named_params(g->inputs(), {});
   auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
-  // auto trt = trt_results[0].reshape_as(jit_results[0]);
 
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }

Diff for: tests/modules/hub.py

-8
@@ -54,18 +54,10 @@
         "model": torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True),
         "path": "both"
     },
-    "fcn_resnet101": {
-        "model": torch.hub.load('pytorch/vision:v0.9.0', 'fcn_resnet101', pretrained=True),
-        "path": "script"
-    },
     "ssd": {
         "model": torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math="fp32"),
         "path": "trace"
     },
-    "faster_rcnn": {
-        "model": models.detection.fasterrcnn_resnet50_fpn(pretrained=True),
-        "path": "script"
-    },
     "efficientnet_b0": {
         "model": timm.create_model('efficientnet_b0', pretrained=True),
         "path": "script"

Diff for: tests/util/run_graph_engine.cpp

+1-1
@@ -69,7 +69,7 @@ std::vector<at::Tensor> RunGraphEngine(
   auto in = toInputs(inputs);
   auto info = core::conversion::ConversionInfo(in);
   info.engine_settings.workspace_size = 1 << 20;
-  info.engine_settings.op_precision = op_precision;
+  info.engine_settings.enabled_precisions.insert(op_precision);
   std::string eng = core::conversion::ConvertBlockToEngine(g->block(), info, named_params);
   return RunEngine(eng, inputs);
 }
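This last change tracks the migration from the single op_precision field to the enabled_precisions set, which lets a caller request several precisions side by side. A small sketch, assuming only the BuilderSettings field visible in this commit:

#include "NvInfer.h"
#include <set>

// Multiple precisions can now coexist; a lone FP16 request would previously
// have replaced op_precision outright.
void enable_mixed_precision(std::set<nvinfer1::DataType>& enabled_precisions) {
  enabled_precisions.insert(nvinfer1::DataType::kFLOAT);
  enabled_precisions.insert(nvinfer1::DataType::kHALF);
}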
