
Commit c76a28a

feat: Enable TRT 8.0 QAT functionality in TRTorch
Signed-off-by: Dheeraj Peri <[email protected]>
Parent: 5708634

7 files changed (+81 −20 lines)


Diff for: core/conversion/conversionctx/ConversionCtx.cpp (+2 −1)

@@ -72,7 +72,8 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
       input_type = nvinfer1::DataType::kFLOAT;
       // TRTORCH_CHECK(
       //     settings.calibrator != nullptr,
-      //     "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
+      //     "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec
+      //     struct with your calibrator");
       // cfg->setInt8Calibrator(settings.calibrator);
       break;
     case nvinfer1::DataType::kFLOAT:
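
The calibrator check and the setInt8Calibrator call stay commented out because a TRT 8.0 QAT network carries its own quantization scales as explicit Quantize/Dequantize ops, so INT8 no longer requires a PTQ calibrator. A minimal sketch of the builder-side difference (standard TensorRT 8 API; not code from this commit):

// Enabling INT8 for a QAT network: scales come from IQuantizeLayer /
// IDequantizeLayer nodes already present in the graph, so no calibrator
// is attached.
nvinfer1::IBuilderConfig* cfg = builder->createBuilderConfig();
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
// The PTQ path would instead call cfg->setInt8Calibrator(calibrator) here.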

Diff for: core/conversion/converters/impl/conv_deconv.cpp (+11 −9)

@@ -45,21 +45,23 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
   if (args[2].IValue()->isTensor()) {
     bias = Weights(ctx, args[2].unwrapToTensor());
   } else {
-    bias = Weights(); //nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
+    bias = Weights(); // nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
   }

   // Handle case when weights of conv/deconv is an ITensor. This case happens for QAT networks where
   // conv_weights -> Quantize -> Dequantize -> new_conv_weights -> conv <- input
   // new_conv_weights will be an ITensor because it is an output of Dequantize layer defined in impl/quantization.cpp
-  if (args[1].isITensor()){
+  if (args[1].isITensor()) {
     // Get the kernel tensor
     auto kernel = args[1].ITensor();
     auto kernel_dims = kernel->getDimensions();

     // Make a new Dims with only the spatial dimensions.
     nvinfer1::Dims filter_dim;
     int64_t nbSpatialDims = in->getDimensions().nbDims - 2;
-    TRTORCH_CHECK(nbSpatialDims = kernel_dims.nbDims - 2, "Number of input spatial dimensions should match the kernel spatial dimensions");
+    TRTORCH_CHECK(
+        nbSpatialDims == kernel_dims.nbDims - 2,
+        "Number of input spatial dimensions should match the kernel spatial dimensions");
     filter_dim.nbDims = nbSpatialDims;
     filter_dim.d[0] = kernel_dims.d[2];
     filter_dim.d[1] = kernel_dims.d[3];
@@ -68,9 +70,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};

     nvinfer1::ILayer* layer = nullptr;
-    if (transposed){
-      nvinfer1::IDeconvolutionLayer* deconvLayer
-        = ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+    if (transposed) {
+      nvinfer1::IDeconvolutionLayer* deconvLayer =
+          ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
       deconvLayer->setStrideNd(stride);
       deconvLayer->setDilationNd(dilation);
       deconvLayer->setNbGroups(groups);
@@ -79,9 +81,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
       deconvLayer->setInput(1, *kernel);
       TRTORCH_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
       layer = deconvLayer;
-    } else{
-      nvinfer1::IConvolutionLayer* convLayer
-        = ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+    } else {
+      nvinfer1::IConvolutionLayer* convLayer =
+          ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
       convLayer->setStrideNd(stride);
       convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
       convLayer->setPaddingNd(padding);
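
As the in-diff comment describes, in a QAT graph the conv kernel reaches the converter as an ITensor (the output of a Dequantize layer) rather than as constant weights. TensorRT's addConvolutionNd still requires a Weights argument, hence the empty dummy kernel_weights followed by setInput(1, ...) to wire in the real kernel. A self-contained sketch of that idiom (assumed TensorRT 8 API usage with illustrative names; not code from this commit):

#include <NvInfer.h>

// Build a conv whose kernel comes from a Q/DQ pair, mirroring the QAT pattern
// conv_weights -> Quantize -> Dequantize -> conv.
nvinfer1::ILayer* add_qat_conv(
    nvinfer1::INetworkDefinition* net,
    nvinfer1::ITensor& input,
    nvinfer1::ITensor& fp32_kernel, // weight constant, still FP32
    nvinfer1::ITensor& scale, // per-tensor quantization scale
    int nbOutputMaps,
    nvinfer1::Dims filter_dim) {
  auto q = net->addQuantize(fp32_kernel, scale); // FP32 -> INT8
  auto dq = net->addDequantize(*q->getOutput(0), scale); // INT8 -> FP32 (fake-quantized)
  auto empty = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
  auto conv = net->addConvolutionNd(input, nbOutputMaps, filter_dim, empty, empty);
  conv->setInput(1, *dq->getOutput(0)); // kernel supplied as a tensor, not Weights
  return conv;
}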

Diff for: core/conversion/converters/impl/linear.cpp (+5 −4)

@@ -42,16 +42,17 @@ auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patt

     // Get the bias
     Weights bias;
-    if(!args[2].IValue()->isNone()){
+    if (!args[2].IValue()->isNone()) {
       bias = Weights(ctx, args[2].IValue()->toTensor());
-    }else {
+    } else {
       bias = Weights();
     }

     // Handle case when weights of conv/deconv is an ITensor. This case happens for QAT networks where
     // conv_weights -> Quantize -> Dequantize -> new_conv_weights -> conv <- input
-    // new_conv_weights will be an ITensor because it is an output of Dequantize layer defined in impl/quantization.cpp
-    if(args[1].isITensor()){
+    // new_conv_weights will be an ITensor because it is an output of Dequantize layer defined in
+    // impl/quantization.cpp
+    if (args[1].isITensor()) {
       auto kernel_tensor = args[1].ITensor();
       auto kernel_dims = args[1].ITensor()->getDimensions();
       // Initialize a dummy constant kernel to pass it to INetwork->addConvolutionNd/addDeconvolutionNd API.

Diff for: core/conversion/converters/impl/matrix_multiply.cpp (+55)

@@ -1,3 +1,4 @@
+#include <torch/torch.h>
 #include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"
@@ -72,6 +73,60 @@ auto mm_registrations TRTORCH_UNUSED =

               LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
               return true;
+            }})
+        .pattern(
+            {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto mat1 = args[1].ITensorOrFreeze(ctx);
+               auto mat2 = args[2].ITensorOrFreeze(ctx);
+               auto beta = args[3].unwrapToScalar().to<float>();
+               auto betaTensor = tensor_to_const(ctx, torch::tensor({beta}));
+               auto alpha = args[4].unwrapToScalar().to<float>();
+               auto alphaTensor = tensor_to_const(ctx, torch::tensor({alpha}));
+
+               // Ensure mat1 and mat2 have the same nbDims by expanding the dimensions (from 0 axis) if
+               // necessary.
+               if (mat1->getDimensions().nbDims < mat2->getDimensions().nbDims) {
+                 mat1 = addPadding(ctx, n, mat1, mat2->getDimensions().nbDims, false, false);
+               } else {
+                 mat2 = addPadding(ctx, n, mat2, mat1->getDimensions().nbDims, false, false);
+               }
+
+               // Transpose mat2 by reversing its axes (the aten::t evaluator forwards ITensors untransposed).
+               auto mat2_dims = mat2->getDimensions();
+               nvinfer1::Permutation transpose_order;
+               for (int i = 0; i < mat2_dims.nbDims; i++) {
+                 transpose_order.order[i] = mat2_dims.nbDims - 1 - i;
+               }
+               auto shuffle_layer = ctx->net->addShuffle(*mat2);
+               shuffle_layer->setFirstTranspose(transpose_order);
+               mat2 = shuffle_layer->getOutput(0);
+
+               auto mm_layer = ctx->net->addMatrixMultiply(
+                   *mat1, nvinfer1::MatrixOperation::kNONE, *mat2, nvinfer1::MatrixOperation::kNONE);
+               TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication layer in node: " << *n);
+               auto mm_scale_layer = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kPROD,
+                   mm_layer->getOutput(0),
+                   alphaTensor,
+                   util::node_info(n) + "_alphaScale");
+               TRTORCH_CHECK(mm_scale_layer, "Unable to create alpha scaling layer in node: " << *n);
+               auto beta_scale_layer = add_elementwise(
+                   ctx, nvinfer1::ElementWiseOperation::kPROD, self, betaTensor, util::node_info(n) + "_betaScale");
+               TRTORCH_CHECK(beta_scale_layer, "Unable to create beta scaling layer in node: " << *n);
+               auto add_mm_layer = add_elementwise(
+                   ctx,
+                   nvinfer1::ElementWiseOperation::kSUM,
+                   beta_scale_layer->getOutput(0),
+                   mm_scale_layer->getOutput(0),
+                   util::node_info(n));
+               TRTORCH_CHECK(add_mm_layer, "Unable to create addmm layer in node: " << *n);
+
+               auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], add_mm_layer->getOutput(0));
+
+               LOG_DEBUG("[AddMM layer] Output tensor shape: " << out_tensor->getDimensions());
+               return true;
             }});
 } // namespace
 } // namespace impl
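
The new converter decomposes addmm as out = β·self + α·(mat1 × mat2ᵀ). The mat2 transpose is not part of aten::addmm itself; it appears to compensate for the aten::t evaluator change below, which forwards ITensor weights without transposing them. A hypothetical libtorch check of the underlying identity (not part of the commit):

#include <iostream>
#include <torch/torch.h>

// Verify that beta * self + alpha * (mat1 @ mat2) matches aten::addmm.
// Here `w` plays the role of a linear weight that an upstream aten::t would
// normally transpose; the converter performs that transpose itself.
int main() {
  auto self = torch::randn({2, 4});
  auto mat1 = torch::randn({2, 3});
  auto w = torch::randn({4, 3});
  float beta = 0.5f, alpha = 2.0f;
  auto expected = torch::addmm(self, mat1, w.t(), beta, alpha);
  auto decomposed = beta * self + alpha * torch::matmul(mat1, w.t());
  std::cout << torch::allclose(expected, decomposed) << std::endl; // prints 1
}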

Diff for: core/conversion/evaluators/aten.cpp (+6 −1)

@@ -430,9 +430,14 @@ auto aten_registrations TRTORCH_UNUSED =
         .evaluator({c10::Symbol::fromQualString("aten::t"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       auto tensor_var = args.at(n->input(0));
-                      if (tensor_var.IValue()->isTensor()) {
+                      if (tensor_var.isIValue() && tensor_var.IValue()->isTensor()) {
                         auto tensor = tensor_var.unwrapToTensor();
                         return tensor.t();
+                      } else if (tensor_var.isITensor()) {
+                        auto tensor_holder = TensorContainer();
+                        tensor_holder.hold_tensor(tensor_var.ITensor());
+                        auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                        return ival;
                       } else {
                         TRTORCH_THROW_ERROR("Unimplemented data type for aten::t evaluator: ITensor");
                         return {};
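
For an ITensor input, the evaluator cannot eagerly transpose, so it wraps the ITensor in a TensorContainer and returns it as-is inside an IValue; the actual transpose is left to downstream converters (the new aten::addmm converter above transposes mat2 for exactly this reason). A sketch of the consuming side (assumed, mirroring how TRTorch unwraps TensorContainer elsewhere in the codebase):

// Recover the ITensor held inside an IValue produced by the evaluator.
nvinfer1::ITensor* unwrap_itensor(const torch::jit::IValue& ival) {
  return ival.toCustomClass<TensorContainer>()->tensor();
}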

Diff for: core/lowering/lowering.cpp (+2 −2)

@@ -63,9 +63,9 @@ torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
     const torch::jit::script::Module& mod,
     std::string method_name) {
-  auto lowered_mod = LowerModule(mod);
+  auto lowered_mod = mod; // LowerModule(mod);
   auto g = lowered_mod.get_method(method_name).graph();
-  LOG_GRAPH(*g);
+  LOG_INFO(*g);

   // Go through TRTorch Lowering to reformat graph to be conversion friendly
   // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
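
Note: LowerModule (which, among other things, freezes the module) is bypassed here, presumably so that freezing does not constant-fold away the Quantize/Dequantize nodes the QAT converters depend on, and the graph dump is promoted from LOG_GRAPH to LOG_INFO for easier inspection while this path is brought up.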

Diff for: core/plugins/impl/interpolate_plugin.cpp (−3)

@@ -105,7 +105,6 @@ std::vector<int64_t> InterpolatePlugin::getOutputSize() {
   return size_;
 }

-
 int InterpolatePlugin::getNbOutputs() const noexcept {
   if (mode_ == "adaptive_max_pool2d") {
     return 2;
@@ -170,7 +169,6 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
   return nvinfer1::DataType::kFLOAT;
 }

-
 int InterpolatePlugin::initialize() noexcept {
   return 0;
 }
@@ -208,7 +206,6 @@ bool InterpolatePlugin::supportsFormatCombination(
     const nvinfer1::PluginTensorDesc* inOut,
     int nbInputs,
     int nbOutputs) noexcept {
-
   TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin");

   if (mode_ == "adaptive_max_pool2d") {
