Skip to content

Commit 664f117

Browse files
andi4191 and narendasan
authored and committed
feat!(//cpp/bin/torchtrtc): torchtrtc using atol and rtol for tolerance test
BREAKING CHANGE: The flag `--threshold` has been removed in favor of two flags `--atol` and `--rtol` which control the maximum absolute and relative tolerances for numerical deviation Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Anurag Dixit <[email protected]> fix: Fix for rtol and atol tolerance limit in torchtrtc Signed-off-by: Anurag Dixit <[email protected]> feat(//cpp)!: Using logger instead of std::cout Signed-off-by: Anurag Dixit <[email protected]> chore!: Applying C++ lint Signed-off-by: Anurag Dixit <[email protected]> chore!: Updated the tensor names as per review comments Signed-off-by: Anurag Dixit <[email protected]>
1 parent 4ee9dbc commit 664f117

File tree

5 files changed

+57
-25
lines changed

5 files changed

+57
-25
lines changed

cpp/bin/torchtrtc/README.md

+6-4
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,12 @@ torchtrtc [input_file_path] [output_file_path]
8989
used to select kernels
9090
--workspace-size=[workspace_size] Maximum size of workspace given to
9191
TensorRT
92-
-t[threshold],
93-
--threshold=[threshold] Maximum acceptable numerical deviation
94-
from standard torchscript output
95-
(default 2e-5)
92+
--atol=[atol] Absolute tolerance threshold for acceptable
93+
numerical deviation from standard torchscript
94+
output (default 1e-8)
95+
--rtol=[rtol] Relative tolerance threshold for acceptable
96+
numerical deviation from standard torchscript
97+
output (default 1e-5)
9698
--no-threshold-check Skip checking threshold compliance
9799
--truncate-long-double,
98100
--truncate, --truncate-64bit Truncate weights that are provided in

cpp/bin/torchtrtc/accuracy.cpp

+18-3
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,24 @@ bool check_rtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, fl
1919
return diff.abs().max().item<float>() <= threshold * maxValue;
2020
}
2121

22-
bool almost_equal(const at::Tensor& a, const at::Tensor& b, float threshold) {
23-
return check_rtol(a - b, {a, b}, threshold);
22+
bool almost_equal(
23+
const at::Tensor& computed_tensor,
24+
const at::Tensor& gt_tensor, // gt_tensor : Ground Truth Tensor
25+
float atol,
26+
float rtol) {
27+
auto computed_tensor_float = computed_tensor.toType(at::kFloat);
28+
auto gt_tensor_float = gt_tensor.toType(at::kFloat);
29+
30+
auto diff = computed_tensor_float - gt_tensor_float;
31+
auto result = diff.abs().max().item<float>();
32+
auto threshold = atol + (rtol * gt_tensor.abs().max().item<float>());
33+
34+
torchtrt::logging::log(torchtrt::logging::Level::kDEBUG, std::string("Max Difference: ") + std::to_string(result));
35+
torchtrt::logging::log(
36+
torchtrt::logging::Level::kDEBUG, std::string("Acceptable Threshold: ") + std::to_string(threshold));
37+
38+
return result <= threshold;
2439
}
2540

2641
} // namespace accuracy
27-
} // namespace torchtrtc
42+
} // namespace torchtrtc

cpp/bin/torchtrtc/accuracy.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace torchtrtc {
1212
namespace accuracy {
1313

1414
bool check_rtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs, float threshold);
15-
bool almost_equal(const at::Tensor& a, const at::Tensor& b, float threshold);
15+
bool almost_equal(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5);
1616

1717
} // namespace accuracy
18-
} // namespace torchtrtc
18+
} // namespace torchtrtc

cpp/bin/torchtrtc/main.cpp

+25-12
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,16 @@ int main(int argc, char** argv) {
119119
parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"});
120120
args::ValueFlag<uint64_t> workspace_size(
121121
parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"});
122-
args::ValueFlag<double> threshold(
122+
args::ValueFlag<double> atol(
123123
parser,
124-
"threshold",
125-
"Maximum acceptable numerical deviation from standard torchscript output (default 2e-5)",
126-
{'t', "threshold"});
124+
"atol",
125+
"Absolute tolerance threshold for acceptable numerical deviation from standard torchscript output (default 1e-8)",
126+
{"atol"});
127+
args::ValueFlag<double> rtol(
128+
parser,
129+
"rtol",
130+
"Relative tolerance threshold for acceptable numerical deviation from standard torchscript output (default 1e-5)",
131+
{"rtol"});
127132

128133
args::Flag no_threshold_check(
129134
parser, "no-threshold-check", "Skip checking threshold compliance", {"no-threshold-check", "no-threshold-check"});
@@ -392,9 +397,13 @@ int main(int argc, char** argv) {
392397
(compile_settings.enabled_precisions.size() == 1 &&
393398
compile_settings.enabled_precisions.find(torchtrt::DataType::kFloat) !=
394399
compile_settings.enabled_precisions.end())) {
395-
double threshold_val = 2e-5;
396-
if (threshold) {
397-
threshold_val = args::get(threshold);
400+
double atol_val = 1e-8;
401+
double rtol_val = 1e-5;
402+
if (atol) {
403+
atol_val = args::get(atol);
404+
}
405+
if (rtol) {
406+
rtol_val = args::get(rtol);
398407
}
399408

400409
std::vector<torch::jit::IValue> jit_inputs_ivalues;
@@ -431,14 +440,18 @@ int main(int argc, char** argv) {
431440
}
432441

433442
for (size_t i = 0; i < trt_results.size(); i++) {
443+
std::ostringstream threshold_ss;
444+
threshold_ss << "atol: " << atol_val << " rtol: " << rtol_val;
434445
if (!torchtrtc::accuracy::almost_equal(
435-
jit_results[i], trt_results[i].reshape_as(jit_results[i]), threshold_val)) {
436-
std::ostringstream threshold_ss;
437-
threshold_ss << threshold_val;
446+
jit_results[i], trt_results[i].reshape_as(jit_results[i]), atol_val, rtol_val)) {
438447
torchtrt::logging::log(
439448
torchtrt::logging::Level::kWARNING,
440-
std::string("Maximum numerical deviation for output exceeds set threshold (") + threshold_ss.str() +
441-
std::string(")"));
449+
std::string("Maximum numerical deviation for output exceeds tolerance thresholds (") +
450+
threshold_ss.str() + std::string(")"));
451+
} else {
452+
torchtrt::logging::log(
453+
torchtrt::logging::Level::kDEBUG,
454+
std::string("Maximum numerical deviation within threshold limits ") + threshold_ss.str());
442455
}
443456
}
444457
} else {

docsrc/tutorials/torchtrtc.rst

+6-4
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,12 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
9292
used to select kernels
9393
--workspace-size=[workspace_size] Maximum size of workspace given to
9494
TensorRT
95-
-t[threshold],
96-
--threshold=[threshold] Maximum acceptable numerical deviation
97-
from standard torchscript output
98-
(default 2e-5)
95+
--atol=[atol] Absolute tolerance threshold for acceptable
96+
numerical deviation from standard torchscript
97+
output (default 1e-8)
98+
--rtol=[rtol] Relative tolerance threshold for acceptable
99+
numerical deviation from standard torchscript
100+
output (default 1e-5)
99101
--no-threshold-check Skip checking threshold compliance
100102
--truncate-long-double,
101103
--truncate, --truncate-64bit Truncate weights that are provided in

0 commit comments

Comments
 (0)