Skip to content

Commit de269af

Browse files
committed
fix bugs in embedding converter
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 9ff9c22 commit de269af

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

core/conversion/converters/impl/select.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,11 @@ auto select_registrations TRTORCH_UNUSED =
174174
{"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)",
175175
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
176176
auto embeddingTensor = args[0].ITensorOrFreeze(ctx);
177-
auto indicesTensor = args[1].ITensor();
177+
auto indicesTensor = args[1].ITensorOrFreeze(ctx);
178178
// Set datatype for indices tensor to INT32
179-
indicesTensor->setType(nvinfer1::DataType::kINT32);
179+
auto identity = ctx->net->addIdentity(*indicesTensor);
180+
identity->setOutputType(0, nvinfer1::DataType::kINT32);
181+
indicesTensor = identity->getOutput(0);
180182

181183
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from
182184
auto gather_layer = ctx->net->addGather(*embeddingTensor, *indicesTensor, 0);

tests/core/conversion/converters/test_select.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) {
106106
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});
107107

108108
// Run TensorRT
109-
auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kI32);
109+
auto options_trt = torch::TensorOptions().device(torch::kCUDA, 0).dtype(torch::kFloat);
110110
auto trt_in = at::tensor({0, 1, 2}, options_trt);
111111
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});
112112
auto trt = trt_results[0].reshape(jit_results[0].sizes());

0 commit comments

Comments
 (0)