diff --git a/core/conversion/converters/impl/element_wise.cpp b/core/conversion/converters/impl/element_wise.cpp
index ab4ef4c138..2f0c3a9d13 100644
--- a/core/conversion/converters/impl/element_wise.cpp
+++ b/core/conversion/converters/impl/element_wise.cpp
@@ -425,8 +425,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
               // TODO: Remove with functionalization
               auto self = args[0].ITensorOrFreeze(ctx);
-              auto otherScalar = args[1].unwrapToScalar().to<float>();
-              auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+              auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
               auto mul =
                   add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, self, other, util::node_info(n));
               TORCHTRT_CHECK(mul, "Unable to create mul layer from node: " << *n);
diff --git a/tests/core/conversion/converters/test_element_wise.cpp b/tests/core/conversion/converters/test_element_wise.cpp
index 59297ff861..994fb25811 100644
--- a/tests/core/conversion/converters/test_element_wise.cpp
+++ b/tests/core/conversion/converters/test_element_wise.cpp
@@ -11,7 +11,8 @@ void pointwise_test_helper(
     bool dynamicInput = false,
     std::vector<int64_t> shape1 = {5},
     std::vector<int64_t> shape2 = {5},
-    bool negative_input = false) {
+    bool negative_input = false,
+    bool int_tensors = false) {
   auto g = std::make_shared<torch::jit::Graph>();
   torch::jit::parseIR(graph_ir, g.get());
 
@@ -26,6 +27,11 @@ void pointwise_test_helper(
   if (!singleInput) {
     torch_inputs.push_back(at::randint(1, 5, shape2, {at::kCUDA}));
   }
+  if(int_tensors){
+    for(size_t i = 0UL; i < torch_inputs.size(); ++i){
+      torch_inputs[i] = torch_inputs[i].to(at::kInt);
+    }
+  }
   auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
   auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, torch_inputs);
 
@@ -126,6 +132,15 @@ TEST(Converters, ATenMulWithScalarConvertsCorrectly) {
   pointwise_test_helper(graph, true);
 }
 
+TEST(Converters, ATenMulWithIntScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %scalar : int = prim::Constant[value=2]()
+        %1 : Tensor = aten::mul(%0, %scalar)
+        return (%1))IR";
+  pointwise_test_helper(graph, true, false, {5}, {5}, false, true);
+}
+
 TEST(Converters, ATenDivConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%0 : Tensor, %1 : Tensor):
diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp
index fe211f2baf..04e0bd4811 100644
--- a/tests/util/run_graph_engine.cpp
+++ b/tests/util/run_graph_engine.cpp
@@ -4,6 +4,7 @@
 #include "core/ir/ir.h"
 #include "core/runtime/runtime.h"
 #include "core/util/prelude.h"
+#include "core/util/trt_util.h"
 #include "cuda_runtime_api.h"
 #include "torch/csrc/jit/ir/ir.h"
 #include "torch/csrc/jit/ir/irparser.h"
@@ -19,7 +20,7 @@ namespace util {
 std::vector<core::ir::Input> toInputs(std::vector<at::Tensor> ten) {
   std::vector<core::ir::Input> a;
   for (auto i : ten) {
-    a.push_back(core::ir::Input(core::util::toVec(i.sizes())));
+    a.push_back(core::ir::Input(core::util::toVec(i.sizes()), core::util::ScalarTypeToTRTDataType(i.scalar_type())));
   }
   return a;
 }
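
The converter change relies on `scalar_to_tensor` keeping the scalar's original dtype instead of first coercing it through `to<float>()`, so an `int` scalar multiplied with an `at::kInt` tensor keeps matching types; since the helper takes the conversion context, it presumably returns the constant as a TRT tensor, much like the old `tensor_to_const` path. The snippet below is a minimal standalone sketch of that dtype-preserving idea using only libtorch; `scalar_to_tensor_sketch` is a hypothetical name for illustration, not the Torch-TensorRT helper.

```cpp
// Minimal sketch (assumption, not the Torch-TensorRT implementation):
// turn a torch Scalar into a 1-element tensor while preserving whether it
// was integral or floating point, mirroring why the converter fix works.
#include <torch/torch.h>

#include <iostream>

at::Tensor scalar_to_tensor_sketch(const c10::Scalar& s) {
  if (s.isIntegral(/*includeBool=*/false)) {
    // Keep integer scalars as kInt so an int tensor operand is not promoted to float.
    return torch::tensor({s.to<int64_t>()}, torch::kInt);
  }
  // Previous behavior for everything else: go through a float tensor.
  return torch::tensor({s.to<double>()}, torch::kFloat);
}

int main() {
  std::cout << scalar_to_tensor_sketch(c10::Scalar(2)).dtype() << std::endl;    // prints: Int
  std::cout << scalar_to_tensor_sketch(c10::Scalar(2.4)).dtype() << std::endl;  // prints: Float
  return 0;
}
```

The test-side changes follow the same principle: `int_tensors` casts the test inputs to `at::kInt`, and `toInputs` now forwards each tensor's scalar type through `ScalarTypeToTRTDataType` so the engine is built with the same input dtypes the JIT graph sees.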