Skip to content

Commit 34f84df

Browse files
authored
Merge pull request #319 from NVIDIA/mul_scalar
Add mul.scalar converter
2 parents f453ede + e0ce9f3 commit 34f84df

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

core/conversion/converters/impl/element_wise.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,21 @@ auto element_wise_registrations TRTORCH_UNUSED =
363363
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
364364
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
365365

366+
mul->setName(util::node_info(n).c_str());
367+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));
368+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
369+
return true;
370+
}})
371+
.pattern({"aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)",
372+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
373+
// TODO: Remove with functionalization
374+
auto self = args[0].ITensorOrFreeze(ctx);
375+
auto otherScalar = args[1].unwrapToScalar().to<float>();
376+
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
377+
auto mul =
378+
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
379+
TRTORCH_CHECK(mul, "Unable to create mul layer from node: " << *n);
380+
366381
mul->setName(util::node_info(n).c_str());
367382
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], mul->getOutput(0));
368383
LOG_DEBUG("Output tensor shape: " << out->getDimensions());

tests/core/conversion/converters/test_element_wise.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ TEST(Converters, ATenMulConvertsCorrectly) {
111111
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
112112
}
113113

114+
TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
115+
const auto graph = R"IR(
116+
graph(%0 : Tensor):
117+
%scalar : float = prim::Constant[value=2.4]()
118+
%1 : Tensor = aten::mul(%0, %scalar)
119+
return (%1))IR";
120+
pointwise_test_helper(graph, true);
121+
}
122+
114123
TEST(Converters, ATenDivConvertsCorrectly) {
115124
const auto graph = R"IR(
116125
graph(%0 : Tensor, %1 : Tensor):

0 commit comments

Comments
 (0)