Commit 3ee60b7

fix: Error on aten::div with truncation (#1442)
- `aten::div` with truncation on integer tensor inputs currently throws an error when both inputs are of integer type, since the TRT unary operations for absolute value and floor do not apply to Int32 or Bool types
- For absolute value, this is a legitimate bug, as `aten::abs` is functional for integer types
- For the floor operation, `aten::floor` does not explicitly support integer inputs, and `torch.floor()` does not work with Int32 inputs by default. However, `torch.div(..., rounding_mode="trunc")` with integer tensors does return an integer value, so the corresponding Torch-TRT converter should behave the same way
- Moved the `aten::abs` converter logic into a utility, as it is used in multiple locations
- Added a regression test to ensure truncation division with two integer tensors is functional
- Addressed comments on PR:
  - Renamed the utility to add_abs for conciseness
  - Refactored the absolute value utility to return ITensor*
  - Updated the logging level for certain debug messages
1 parent db5d290 commit 3ee60b7
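For context, the PyTorch behavior the converter must match can be checked directly. A minimal libtorch sketch, not part of the commit, assuming a PyTorch version where `torch::div` accepts a `rounding_mode` argument:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Int32 inputs; trunc-rounded division returns an integer tensor,
  // unlike the default rounding_mode, which promotes to float.
  auto a = torch::tensor({-7, 7, -9}, torch::kInt32);
  auto b = torch::tensor({2, 2, 4}, torch::kInt32);
  auto out = torch::div(a, b, "trunc");
  std::cout << out << std::endl;  // -3, 3, -2 (rounded toward zero)
}
```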

File tree

5 files changed: +87 -41 lines changed


core/conversion/converters/converter_util.cpp

+32
@@ -156,6 +156,38 @@ nvinfer1::ILayer* add_elementwise(
   return ele;
 }
 
+nvinfer1::ITensor* add_abs(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* self,
+    const std::string& name) {
+  nvinfer1::ILayer* absolute_value_layer;
+
+  // Check if TRT Unary ops support the input type
+  bool unary_supported_input = (self->getType() == nvinfer1::DataType::kFLOAT) ||
+      (self->getType() == nvinfer1::DataType::kHALF) || (self->getType() == nvinfer1::DataType::kINT8);
+  if (unary_supported_input) {
+    absolute_value_layer = ctx->net->addUnary(*self, nvinfer1::UnaryOperation::kABS);
+    TORCHTRT_CHECK(absolute_value_layer, "Unable to create abs layer from node: " << *n);
+    absolute_value_layer->setName(name.c_str());
+  } else {
+    LOG_GRAPH(
+        "Tensor is of unsupported type "
+        << self->getType() << " for IUnaryLayer::kABS. Using backup implementation via IElementWise (max(x, -x))");
+    // For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
+    at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(self->getType()));
+    auto neg_one_const = tensor_to_const(ctx, neg_one);
+    auto neg_layer = add_elementwise(
+        ctx, nvinfer1::ElementWiseOperation::kPROD, self, neg_one_const, util::node_info(n) + std::string("_Negation"));
+    TORCHTRT_CHECK(neg_layer, "Unable to create prod layer from node: " << *n);
+    absolute_value_layer =
+        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, self, neg_layer->getOutput(0), name);
+    TORCHTRT_CHECK(absolute_value_layer, "Unable to create max layer from node: " << *n);
+  }
+
+  return absolute_value_layer->getOutput(0);
+}
+
 nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& tensor_name) {
   auto id_layer = ctx->net->addIdentity(*tensor);
   auto id_out_tensor = id_layer->getOutput(0);
core/conversion/converters/converter_util.h

+8
@@ -35,13 +35,21 @@ nvinfer1::ITensor* addUnpadding(
     bool trailing = true,
     bool use_zeros = true);
 
+// TODO: Change add_elementwise schema to output nvinfer1::ITensor* instead of nvinfer1::ILayer*,
+// for consistency with other utils. Need to change schema and usage in all calling contexts
 nvinfer1::ILayer* add_elementwise(
     ConversionCtx* ctx,
     nvinfer1::ElementWiseOperation op,
     nvinfer1::ITensor* self,
     nvinfer1::ITensor* other,
     const std::string& name);
 
+nvinfer1::ITensor* add_abs(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* self,
+    const std::string& name);
+
 // Apply an identity operation on a tensor. Used in the case where an input is an output to a network.
 nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor, const std::string& name);
4755

core/conversion/converters/impl/element_wise.cpp

+19 -7
@@ -326,15 +326,27 @@ auto element_wise_registrations TORCHTRT_UNUSED =
            } else if (rounding_mode == "trunc") {
              // trunc = floor(abs(div)) * sign(div)
              auto tmp_div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, "tmp_div");
-            auto abs = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kABS);
-            auto floor = ctx->net->addUnary(*abs->getOutput(0), nvinfer1::UnaryOperation::kFLOOR);
+            auto abs = add_abs(ctx, n, tmp_div->getOutput(0), util::node_info(n) + "_absolute_val");
+
+            // In this case, we allow the floor unary on non-TRT Unary types, as it is needed for this
+            // specific function. Floor applied to non-float types equates to identity
+            nvinfer1::ITensor* floor;
+
+            if ((abs->getType() == nvinfer1::DataType::kINT32) || (abs->getType() == nvinfer1::DataType::kBOOL)) {
+              LOG_DEBUG(
+                  "Tensor is of unsupported type " << abs->getType()
+                                                   << " for IUnaryLayer::kFLOOR. Using identity instead.");
+              floor = abs;
+            } else {
+              auto floor_layer = ctx->net->addUnary(*abs, nvinfer1::UnaryOperation::kFLOOR);
+              TORCHTRT_CHECK(floor_layer, "Unable to create floor layer from node: " << *n);
+              floor_layer->setName((util::node_info(n) + "_floor").c_str());
+              floor = floor_layer->getOutput(0);
+            }
+
             auto sign = ctx->net->addUnary(*tmp_div->getOutput(0), nvinfer1::UnaryOperation::kSIGN);
             div = add_elementwise(
-                ctx,
-                nvinfer1::ElementWiseOperation::kPROD,
-                floor->getOutput(0),
-                sign->getOutput(0),
-                util::node_info(n));
+                ctx, nvinfer1::ElementWiseOperation::kPROD, floor, sign->getOutput(0), util::node_info(n));
            } else {
              div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
            }
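As a sanity check on the decomposition used here (not part of the commit), a plain C++ sketch confirming trunc(q) == floor(|q|) * sign(q), and that floor is an identity once the quotient is already integral, which is why the converter can skip kFLOOR for Int32/Bool tensors:

```cpp
#include <cassert>
#include <cmath>

int main() {
  for (double q : {-3.5, -3.0, -0.5, 0.0, 2.25, 7.0}) {
    double sign = (q > 0) - (q < 0);
    // The converter's decomposition: trunc(q) == floor(|q|) * sign(q)
    assert(std::floor(std::fabs(q)) * sign == std::trunc(q));
    // For integral q (the Int32 case), floor(|q|) == |q|, so floor may be skipped
    if (q == std::round(q)) {
      assert(std::floor(std::fabs(q)) == std::fabs(q));
    }
  }
  return 0;
}
```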

core/conversion/converters/impl/unary.cpp

+4 -34
@@ -13,40 +13,10 @@ namespace {
 auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
     {"aten::abs(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        auto in = args[0].ITensorOrFreeze(ctx);
-       bool unary_supported_input = in->getType() == nvinfer1::DataType::kFLOAT ||
-           in->getType() == nvinfer1::DataType::kHALF || in->getType() == nvinfer1::DataType::kINT8;
-       if (unary_supported_input) {
-         auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kABS);
-         TORCHTRT_CHECK(unary_layer, "Unable to create abs layer from node: " << *n);
-         unary_layer->setName(util::node_info(n).c_str());
-         auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
-         LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-         return true;
-       } else {
-         LOG_GRAPH(
-             "Tensor is of unsupported type "
-             << in->getType() << " for IUnaryLayer::kABS. Using backup implementation via IElementWise (max(x, -x)");
-         // For types not supported by kABS, use an elementwise implementation abs(x) = max(x, -1 * x)
-         at::Tensor neg_one = torch::full({1}, -1).to(util::TRTDataTypeToScalarType(in->getType()));
-         auto neg_one_const = tensor_to_const(ctx, neg_one);
-         auto neg_layer = add_elementwise(
-             ctx,
-             nvinfer1::ElementWiseOperation::kPROD,
-             in,
-             neg_one_const,
-             util::node_info(n) + std::string("_Negation"));
-         TORCHTRT_CHECK(neg_layer, "Unable to create prod layer from node: " << *n);
-         auto max_layer = add_elementwise(
-             ctx,
-             nvinfer1::ElementWiseOperation::kMAX,
-             in,
-             neg_layer->getOutput(0),
-             util::node_info(n) + std::string("_Max"));
-         TORCHTRT_CHECK(max_layer, "Unable to create max layer from node: " << *n);
-         auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], max_layer->getOutput(0));
-         LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
-         return true;
-       }
+       auto abs_tensor = add_abs(ctx, n, in, util::node_info(n));
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], abs_tensor);
+       LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+       return true;
     }});
 
 auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(

tests/core/conversion/converters/test_element_wise.cpp

+24
@@ -4,6 +4,7 @@
 #include "gtest/gtest.h"
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
+#include "torch/torch.h"
 
 void pointwise_test_helper(
     std::string graph_ir,
@@ -235,6 +236,29 @@ TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
 }
 
+TEST(Converters, ATenDivRoundingTruncWithIntsConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor, %1 : Tensor):
+        %trunc : str = prim::Constant[value="trunc"]()
+        %out : Tensor = aten::div(%0, %1, %trunc)
+        return (%out))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  // Avoid divide-by-zero issues by making denominator >= 1
+  auto in_0 = at::randint(-5, 5, {4, 1, 7, 8}, {at::kCUDA}).to(torch::kInt32);
+  auto in_1 = at::randint(1, 10, {4, 1, 7, 8}, {at::kCUDA}).to(torch::kInt32);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in_0, in_1});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in_0, in_1});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0])));
+}
+
 TEST(Converters, ATenPowTensorConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%x.1 : Tensor, %x2.1 : Tensor):
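The new regression test should run under the repository's existing test setup on a CUDA-capable machine; assuming the usual bazel target layout of this repo (target name assumed, not verified against this commit), something like `bazel test //tests/core/conversion/converters:test_element_wise`.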
