Commit 5e39222

mfeliz-cruise authored and narendasan committed
Fix int/int=float division
1 parent 6c9832a commit 5e39222
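
Torch's true division (aten::div with no rounding mode) promotes two integer operands to a floating-point result, whereas TensorRT's ElementWise kDIV keeps two INT32 inputs in INT32. This commit aligns the converter with the Torch behavior by casting INT32 operands to FLOAT before emitting the division. A minimal, illustrative libtorch snippet (not part of this change) showing the semantics being matched:

#include <torch/torch.h>
#include <iostream>

int main() {
  // Two INT32 tensors: true division promotes the result to float.
  auto a = torch::tensor({1, 2, 3}, torch::kInt32);
  auto b = torch::tensor({2, 2, 2}, torch::kInt32);
  auto c = torch::div(a, b);     // aten::div.Tensor, no rounding mode
  std::cout << c << std::endl;   // 0.5000, 1.0000, 1.5000  [ CPUFloatType{3} ]
  return 0;
}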

2 files changed (+45, -34 lines)


core/conversion/converters/impl/element_wise.cpp

+30, -34 lines
@@ -26,6 +26,31 @@ nvinfer1::ITensor* clamp_util(
   return clamp_layer_out;
 }
 
+void cast_int_int_div_tensors(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor*& a,
+    nvinfer1::ITensor*& b) {
+  // Torch automatically produces a float for int/int division
+  if (a->getType() == nvinfer1::DataType::kINT32 && b->getType() == nvinfer1::DataType::kINT32) {
+    a = castITensor(ctx, a, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_a_cast");
+    b = castITensor(ctx, b, nvinfer1::DataType::kFLOAT, util::node_info(n) + "_b_cast");
+  }
+}
+
+bool element_wise_divide_implementation(
+    ConversionCtx* ctx,
+    const torch::jit::Node* n,
+    nvinfer1::ITensor* a,
+    nvinfer1::ITensor* b) {
+  cast_int_int_div_tensors(ctx, n, a, b);
+  auto element_wise = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, a, b, util::node_info(n));
+  TORCHTRT_CHECK(element_wise, "Unable to create element_wise layer from node: " << *n);
+  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], element_wise->getOutput(0));
+  LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+  return true;
+}
+
 auto element_wise_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(
@@ -296,18 +321,9 @@ auto element_wise_registrations TORCHTRT_UNUSED =
         .pattern(
             {"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               // Should implement self / other
               auto self = args[0].ITensorOrFreeze(ctx);
               auto other = args[1].ITensorOrFreeze(ctx);
-               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
-
-               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-
-               div->setName(util::node_info(n).c_str());
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
-
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
-               return true;
+               return element_wise_divide_implementation(ctx, n, self, other);
             }})
         .pattern(
             {"aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> (Tensor)",
@@ -349,6 +365,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
                 div = add_elementwise(
                     ctx, nvinfer1::ElementWiseOperation::kPROD, floor, sign->getOutput(0), util::node_info(n));
               } else {
+                cast_int_int_div_tensors(ctx, n, self, other);
                 div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
               }
 
@@ -365,42 +382,21 @@ auto element_wise_registrations TORCHTRT_UNUSED =
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
               auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
-               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
-               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-
-               div->setName(util::node_info(n).c_str());
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
-               return true;
+               return element_wise_divide_implementation(ctx, n, self, other);
             }})
         .pattern(
             {"aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-               // TODO: Remove with functionalization
               auto self = args[0].ITensorOrFreeze(ctx);
               auto other = args[1].ITensorOrFreeze(ctx);
-               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
-
-               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-
-               div->setName(util::node_info(n).c_str());
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
-
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
-               return true;
+               return element_wise_divide_implementation(ctx, n, self, other);
             }})
         .pattern(
             {"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               auto self = args[0].ITensorOrFreeze(ctx);
               auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
-               auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
-               TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);
-
-               div->setName(util::node_info(n).c_str());
-               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
-               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
-               return true;
+               return element_wise_divide_implementation(ctx, n, self, other);
             }})
         .pattern(
             {"aten::square(Tensor self) -> Tensor",

tests/core/conversion/converters/test_div.cpp

+15, -0 lines
@@ -18,6 +18,7 @@ TEST(Converters, ATenDivConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4});
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }
 
 TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
@@ -29,6 +30,16 @@ TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
   pointwise_test_helper(graph, true);
 }
 
+TEST(Converters, ATenDivWithScalarIntConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %scalar : int = prim::Constant[value=2]()
+        %1 : Tensor = aten::div(%0, %scalar)
+        return (%1))IR";
+  pointwise_test_helper(graph, true);
+  pointwise_test_helper(graph, true, false, {5}, {1}, false, at::kInt);
+}
+
 TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor, %1 : Tensor):
@@ -42,6 +53,7 @@ TEST(Converters, ATenDivRoundingFloorConvertsCorrectly) {
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }
 
 TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
@@ -57,6 +69,7 @@ TEST(Converters, ATenDivRoundingTruncConvertsCorrectly) {
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }
 
 TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
@@ -70,6 +83,7 @@ TEST(Converters, ATenDivRoundingNoneConvertsCorrectly) {
   pointwise_test_helper(graph, false, false, {4}, {3, 4}, true);
   pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3}, true);
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3}, true);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }
 
 TEST(Converters, ATenDivRoundingTruncWithIntsConvertsCorrectly) {
@@ -107,6 +121,7 @@ TEST(Converters, ATenFloorDivideConvertsCorrectly) {
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kFloat, at::kInt);
   pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kFloat);
+  pointwise_test_helper(graph, false, true, {5}, {5}, false, at::kInt, at::kInt);
 }
 
 TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
