Skip to content

Commit c1ac4f0

Browse files
[fix] Disambiguate element-wise cast layer names (#1630)
1 parent 1668bfa commit c1ac4f0

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

core/conversion/converters/converter_util.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,10 @@ nvinfer1::ILayer* add_elementwise(
8585
const std::string& name) {
8686
if (self->getType() == nvinfer1::DataType::kFLOAT && other->getType() == nvinfer1::DataType::kINT32) {
8787
LOG_DEBUG("Type mismatch, casting other to " << self->getType());
88-
other = castITensor(ctx, other, self->getType());
88+
other = castITensor(ctx, other, self->getType(), name);
8989
} else if (self->getType() == nvinfer1::DataType::kINT32 && other->getType() == nvinfer1::DataType::kFLOAT) {
9090
LOG_DEBUG("Type mismatch, casting self to " << other->getType());
91-
self = castITensor(ctx, self, other->getType());
91+
self = castITensor(ctx, self, other->getType(), name);
9292
}
9393
// ensure self to have larger number of dimension
9494
bool swapSelfOther = false;
@@ -106,13 +106,13 @@ nvinfer1::ILayer* add_elementwise(
106106
LOG_DEBUG(
107107
"Element-wise op type promotion adding cast from " << self->getType() << " to " << promo_type << " for layer "
108108
<< name);
109-
self = castITensor(ctx, self, promo_type);
109+
self = castITensor(ctx, self, promo_type, name);
110110
}
111111
if (other->getType() != promo_type) {
112112
LOG_DEBUG(
113113
"Element-wise op type promotion adding cast from " << other->getType() << " to " << promo_type
114114
<< " for layer " << name);
115-
other = castITensor(ctx, other, promo_type);
115+
other = castITensor(ctx, other, promo_type, name);
116116
}
117117
}
118118

tests/core/conversion/converters/test_add_sub_mul.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,14 @@ TEST(Converters, ATenPowScalarConvertsCorrectly) {
178178
return (%3))IR";
179179
pointwise_test_helper(graph, true);
180180
}
181+
182+
TEST(Converters, ElementWiseTypePromotionDisambiguatesCastNames) {
183+
const auto graph = R"IR(
184+
graph(%0 : Tensor, %1 : Tensor):
185+
%2 : int = prim::Constant[value=1]()
186+
%3 : Tensor = aten::add(%0, %1, %2)
187+
%4 : Tensor = aten::add(%0, %1, %2)
188+
%5 : Tensor = aten::add(%3, %4, %2)
189+
return (%5))IR";
190+
pointwise_test_helper(graph, false, false, {4, 3, 3, 3}, {4, 3, 3, 3}, false, at::kInt, at::kFloat);
191+
}

0 commit comments

Comments
 (0)