diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index c883b6989c..f44f55a02f 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -174,9 +174,11 @@ auto select_registrations TRTORCH_UNUSED = {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto embeddingTensor = args[0].ITensorOrFreeze(ctx); - auto indicesTensor = args[1].ITensor(); + auto indicesTensor = args[1].ITensorOrFreeze(ctx); // Set datatype for indices tensor to INT32 - indicesTensor->setType(nvinfer1::DataType::kINT32); + auto identity = ctx->net->addIdentity(*indicesTensor); + identity->setOutputType(0, nvinfer1::DataType::kINT32); + indicesTensor = identity->getOutput(0); // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from auto gather_layer = ctx->net->addGather(*embeddingTensor, *indicesTensor, 0); diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index 44d6669e5e..9e8a475ce8 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -106,7 +106,7 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) { auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in}); // Run TensorRT - auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kI32); + auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kFloat); auto trt_in = at::tensor({0, 1, 2}, options_trt); auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in}); auto trt = trt_results[0].reshape(jit_results[0].sizes());