Skip to content

Commit 90af26e

Browse files
committed
feat: Adding automatic casting to compare layers
BERT converts but produces NaNs

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent ee2455e commit 90af26e

File tree

6 files changed

+119
-15
lines changed

6 files changed

+119
-15
lines changed

Diff for: core/conversion/converters/Weights.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ Weights::Weights(ConversionCtx* ctx, at::Tensor t) {
114114
// clang-format off
115115
std::ostream& operator<<(std::ostream& os, const Weights& w) {
116116
os << "Weights: " << w.shape
117+
<< "\n Data Type: " << w.data.type
117118
<< "\n Number of input maps: " << w.num_input_maps
118119
<< "\n Number of output maps: " << w.num_output_maps
119120
<< "\n Element shape: [";

Diff for: core/conversion/converters/impl/element_wise.cpp

+49-8
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,27 @@ nvinfer1::ITensor* clamp_util(
2525
return clamp_layer_out;
2626
}
2727

28+
nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
29+
nvinfer1::ITensor* out;
30+
if (s.isIntegral(false)) {
31+
auto s_int = s.to<int64_t>();
32+
auto s_t = torch::tensor({s_int}).to(at::kInt);
33+
out = tensor_to_const(ctx, s_t);
34+
} else if (s.isBoolean()) {
35+
auto s_bool = s.to<bool>();
36+
auto s_t = torch::tensor({s_bool}).to(at::kBool);
37+
out = tensor_to_const(ctx, s_t);
38+
} else if (s.isFloatingPoint()) {
39+
auto other_float = s.to<float>();
40+
auto s_t = torch::tensor({other_float});
41+
out = tensor_to_const(ctx, s_t);
42+
} else {
43+
out = nullptr;
44+
TRTORCH_THROW_ERROR("Unsupported data type for scalar. Found: (" << s.type() << ")");
45+
}
46+
return out;
47+
}
48+
2849
auto element_wise_registrations TRTORCH_UNUSED =
2950
RegisterNodeConversionPatterns()
3051
.pattern({"aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> "
@@ -557,8 +578,10 @@ auto element_wise_registrations TRTORCH_UNUSED =
557578
.pattern({"aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)",
558579
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
559580
auto self = args[0].ITensorOrFreeze(ctx);
560-
auto otherScalar = args[1].unwrapToScalar().to<float>();
561-
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
581+
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
582+
if (self->getType() != other->getType()) {
583+
other = castITensor(ctx, other, self->getType());
584+
}
562585
auto gt =
563586
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n));
564587
TRTORCH_CHECK(gt, "Unable to create greater layer from node: " << *n);
@@ -584,8 +607,10 @@ auto element_wise_registrations TRTORCH_UNUSED =
584607
.pattern({"aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)",
585608
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
586609
auto self = args[0].ITensorOrFreeze(ctx);
587-
auto otherScalar = args[1].unwrapToScalar().to<float>();
588-
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
610+
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
611+
if (self->getType() != other->getType()) {
612+
other = castITensor(ctx, other, self->getType());
613+
}
589614
auto lt =
590615
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n));
591616
TRTORCH_CHECK(lt, "Unable to create less layer from node: " << *n);
@@ -613,6 +638,18 @@ auto element_wise_registrations TRTORCH_UNUSED =
613638
auto self = args[0].ITensorOrFreeze(ctx);
614639
auto otherScalar = args[1].unwrapToScalar().to<float>();
615640
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
641+
if (self->getType() == nvinfer1::DataType::kBOOL) {
642+
if (otherScalar == 0 || otherScalar == 1) {
643+
LOG_DEBUG("Since input tensor is type bool, casting input tensor and scalar to int32");
644+
other = castITensor(ctx, other, nvinfer1::DataType::kINT32);
645+
self = castITensor(ctx, self, nvinfer1::DataType::kINT32);
646+
} else {
647+
LOG_WARNING("Input Tensor has type bool, but scalar is not 0 or 1. Found: " << otherScalar);
648+
}
649+
}
650+
if (self->getType() != other->getType()) {
651+
other = castITensor(ctx, other, self->getType());
652+
}
616653
auto eq =
617654
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kEQUAL, self, other, util::node_info(n));
618655
TRTORCH_CHECK(eq, "Unable to create equal layer from node: " << *n);
@@ -648,8 +685,10 @@ auto element_wise_registrations TRTORCH_UNUSED =
648685
.pattern({"aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)",
649686
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
650687
auto self = args[0].ITensorOrFreeze(ctx);
651-
auto otherScalar = args[1].unwrapToScalar().to<float>();
652-
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
688+
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
689+
if (self->getType() != other->getType()) {
690+
other = castITensor(ctx, other, self->getType());
691+
}
653692

654693
auto greater = add_elementwise(
655694
ctx, nvinfer1::ElementWiseOperation::kGREATER, self, other, util::node_info(n) + "_greater");
@@ -695,8 +734,10 @@ auto element_wise_registrations TRTORCH_UNUSED =
695734
.pattern({"aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)",
696735
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
697736
auto self = args[0].ITensorOrFreeze(ctx);
698-
auto otherScalar = args[1].unwrapToScalar().to<float>();
699-
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
737+
auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
738+
if (self->getType() != other->getType()) {
739+
other = castITensor(ctx, other, self->getType());
740+
}
700741

701742
auto less = add_elementwise(
702743
ctx, nvinfer1::ElementWiseOperation::kLESS, self, other, util::node_info(n) + "_less");

Diff for: core/conversion/converters/impl/reduce.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,11 @@ auto reduce_registrations TRTORCH_UNUSED =
144144
.pattern({"aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
145145
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
146146
auto in_tensor = args[0].ITensorOrFreeze(ctx);
147+
auto in_dims = in_tensor->getDimensions();
147148
auto dim = args[1].unwrapToInt();
148-
LOG_DEBUG("Dim to reduce: " << dim); // Some abuse of toDim but just for debug info
149+
LOG_DEBUG("Dim to reduce (original): " << dim);
150+
dim = dim < 0 ? (in_dims.nbDims + dim) : dim;
151+
LOG_DEBUG("Dim to reduce (converted): " << dim);
149152

150153
uint32_t axis_mask = 1 << dim;
151154
LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));

Diff for: core/conversion/evaluators/aten.cpp

+17-2
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,14 @@ auto aten_registrations TRTORCH_UNUSED =
187187
if (tensor_var.isITensor()) {
188188
auto tensor = tensor_var.ITensor();
189189
return util::toVec(tensor->getDimensions());
190-
} else {
190+
} else if (tensor_var.IValue()->isTensor()) {
191191
auto tensor = tensor_var.unwrapToTensor();
192192
return tensor.sizes();
193+
} else if (tensor_var.IValue()->isCustomClass()) {
194+
auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
195+
return util::toVec(tensor->getDimensions());
196+
} else {
197+
TRTORCH_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
193198
}
194199
} else {
195200
auto dim = args.at(n->input(1)).unwrapToInt();
@@ -201,13 +206,23 @@ auto aten_registrations TRTORCH_UNUSED =
201206
dim += nbDims;
202207
}
203208
return dims[dim];
204-
} else {
209+
} else if (tensor_var.IValue()->isTensor()) {
205210
auto tensor = tensor_var.unwrapToTensor();
206211
auto nbDims = tensor.sizes().size();
207212
if (dim < 0) {
208213
dim += nbDims;
209214
}
210215
return tensor.sizes()[dim];
216+
} else if (tensor_var.IValue()->isCustomClass()) {
217+
auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
218+
auto dims = util::toVec(tensor->getDimensions());
219+
auto nbDims = tensor->getDimensions().nbDims;
220+
if (dim < 0) {
221+
dim += nbDims;
222+
}
223+
return dims[dim];
224+
} else {
225+
TRTORCH_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
211226
}
212227
}
213228
},

Diff for: core/conversion/var/Var.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
132132
out = ptr_.tensor;
133133
}
134134

135-
LOG_DEBUG("Frozen tensor shape: " << out->getDimensions());
135+
LOG_DEBUG("ITensor shape: " << out->getDimensions());
136+
LOG_DEBUG("ITensor type: " << out->getType());
136137
return out;
137138
}
138139

Diff for: cpp/trtorchc/main.cpp

+46-3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,34 @@
2525
#include "trtorch/ptq.h"
2626
#include "trtorch/trtorch.h"
2727

28+
29+
at::ScalarType to_torch_dtype(trtorch::CompileSpec::DataType dtype) {
30+
switch (dtype) {
31+
case trtorch::CompileSpec::DataType::kHalf:
32+
return at::kHalf;
33+
case trtorch::CompileSpec::DataType::kChar:
34+
return at::kChar;
35+
case trtorch::CompileSpec::DataType::kInt:
36+
return at::kInt;
37+
case trtorch::CompileSpec::DataType::kBool:
38+
return at::kBool;
39+
case trtorch::CompileSpec::DataType::kFloat:
40+
default:
41+
return at::kFloat;
42+
}
43+
}
44+
45+
const std::unordered_map<nvinfer1::DataType, at::ScalarType>& get_trt_at_type_map() {
46+
static const std::unordered_map<nvinfer1::DataType, at::ScalarType> trt_at_type_map = {
47+
{nvinfer1::DataType::kFLOAT, at::kFloat},
48+
{nvinfer1::DataType::kHALF, at::kHalf},
49+
{nvinfer1::DataType::kINT32, at::kInt},
50+
{nvinfer1::DataType::kINT8, at::kChar},
51+
{nvinfer1::DataType::kBOOL, at::kBool},
52+
};
53+
return trt_at_type_map;
54+
}
55+
2856
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold) {
2957
double maxValue = 0.0;
3058
for (auto& tensor : inputs) {
@@ -238,6 +266,9 @@ int main(int argc, char** argv) {
238266
"Maximum acceptable numerical deviation from standard torchscript output (default 2e-5)",
239267
{'t', "threshold"});
240268

269+
args::Flag no_threshold_check(parser, "no-threshold-check", "Skip checking threshold compliance", {"no-threshold-check", "no-threshold-check"});
270+
args::Flag truncate_long_and_double(parser, "truncate-long-double", "Truncate weights that are provided in 64bit to 32bit (Long, Double to Int, Float)", {"truncate", "truncate-long-double", "truncate-64bit"});
271+
241272
args::Flag save_engine(
242273
parser,
243274
"save_engine",
@@ -481,6 +512,10 @@ int main(int argc, char** argv) {
481512
compile_settings.max_batch_size = args::get(max_batch_size);
482513
}
483514

515+
if (truncate_long_and_double) {
516+
compile_settings.truncate_long_and_double = true;
517+
}
518+
484519
auto real_input_path = resolve_path(args::get(input_path));
485520
auto real_output_path = resolve_path(args::get(output_path));
486521

@@ -507,9 +542,9 @@ int main(int argc, char** argv) {
507542
} else {
508543
auto trt_mod = trtorch::CompileGraph(mod, compile_settings);
509544

510-
if (compile_settings.enabled_precisions.size() == 1 &&
545+
if (!no_threshold_check && (compile_settings.enabled_precisions.size() == 1 &&
511546
compile_settings.enabled_precisions.find(trtorch::CompileSpec::DataType::kFloat) !=
512-
compile_settings.enabled_precisions.end()) {
547+
compile_settings.enabled_precisions.end())) {
513548
double threshold_val = 2e-5;
514549
if (threshold) {
515550
threshold_val = args::get(threshold);
@@ -520,10 +555,12 @@ int main(int argc, char** argv) {
520555

521556
for (auto i : ranges) {
522557
auto in = at::randn(i.opt_shape, {at::kCUDA});
558+
in = in.to(to_torch_dtype(i.dtype));
523559
jit_inputs_ivalues.push_back(in.clone());
524560
trt_inputs_ivalues.push_back(in.clone());
525561
}
526562

563+
mod.to({at::kCUDA});
527564
torch::jit::IValue jit_results_ivalues = mod.forward(jit_inputs_ivalues);
528565
std::vector<at::Tensor> jit_results;
529566
if (jit_results_ivalues.isTensor()) {
@@ -557,9 +594,15 @@ int main(int argc, char** argv) {
557594
}
558595
}
559596
} else {
560-
trtorch::logging::log(
597+
if (no_threshold_check) {
598+
trtorch::logging::log(
599+
trtorch::logging::Level::kWARNING,
600+
"Threshold check skipped, numerical precision is not checked");
601+
} else {
602+
trtorch::logging::log(
561603
trtorch::logging::Level::kWARNING,
562604
"Due to change in operating data type, numerical precision is not checked");
605+
}
563606
}
564607

565608
trt_mod.save(real_output_path);

0 commit comments

Comments (0)