Skip to content

Commit a22e99b

Browse files
committed
feat(//core/conversion): Handle adding and wrapping ITensors as
arguments of append, and unwrapping singular ITensors as outputs of evaluators

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent a7d2b5e commit a22e99b

File tree

3 files changed

+105
-4
lines changed

3 files changed

+105
-4
lines changed

Diff for: core/conversion/conversion.cpp

+14-2
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,15 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
4545
if (result) {
4646
// WARN: If the converter returns None then should pass through
4747
// but if repeated dep this section will get called each time
48-
ctx->evaluated_value_map[eval_in] = std::move(result.value());
49-
eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
48+
auto val = result.value();
49+
if (val.isCustomClass()){
50+
auto cont = val.toCustomClass<TensorContainer>();
51+
ctx->AssociateValueAndTensor(eval_in, cont->tensor());
52+
eval_args[eval_in] = ctx->value_tensor_map[eval_in];
53+
} else {
54+
ctx->AssociateValueAndIValue(eval_in, val);
55+
eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
56+
}
5057
}
5158
} else {
5259
TRTORCH_THROW_ERROR(
@@ -374,6 +381,11 @@ void ConvertBlockToNetDef(
374381
} else {
375382
TRTORCH_THROW_ERROR("Unsupported return type for evaluated node");
376383
}
384+
} else if (eval.value().isCustomClass()) {
385+
auto container = eval.value().toCustomClass<TensorContainer>();
386+
auto tensor = container->tensor();
387+
LOG_DEBUG(ctx->logger, "Found the value to be an ITensor of shape: " << tensor->getDimensions());
388+
ctx->AssociateValueAndTensor(n->output(0), tensor);
377389
} else if (!eval.value().isTensor()) {
378390
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
379391
ctx->AssociateValueAndIValue(n->output(0), eval.value());

Diff for: core/conversion/evaluators/aten.cpp

+10-2
Original file line number | Diff line number | Diff line change
@@ -216,9 +216,17 @@ auto aten_registrations TRTORCH_UNUSED =
216216
.evaluator({c10::Symbol::fromQualString("aten::append"),
217217
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
218218
auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
219-
auto el = args.at(n->input(1)).IValue();
220219

221-
list.push_back(std::move(*el));
220+
if (args.at(n->input(1)).isITensor()) {
221+
auto tensor_holder = TensorContainer();
222+
tensor_holder.hold_tensor(args.at(n->input(1)).ITensor());
223+
auto el = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
224+
list.push_back(std::move(el));
225+
} else {
226+
auto el = args.at(n->input(1)).IValue();
227+
list.push_back(std::move(*el));
228+
}
229+
222230
return list;
223231
},
224232
EvalOptions().validSchemas({

Diff for: tests/core/conversion/evaluators/test_aten_evaluators.cpp

+81
Original file line number | Diff line number | Diff line change
@@ -235,4 +235,85 @@ TEST(Evaluators, FloorFloatIntEvaluatesCorrectly) {
235235
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
236236

237237
ASSERT_TRUE(jit_results[0] == trt_results[0]);
238+
}
239+
240+
TEST(Evaluators, ATenAppendWithITensorEvaluatesCorrectly) {
241+
const auto graph = R"IR(
242+
graph(%0 : Tensor, %1 : Tensor):
243+
%2 : int = prim::Constant[value=0]()
244+
%3 : Tensor[] = prim::ListConstruct(%0)
245+
%4 : Tensor[] = aten::append(%3, %1)
246+
%5 : Tensor = aten::cat(%4, %2)
247+
return (%5))IR";
248+
249+
auto g = std::make_shared<torch::jit::Graph>();
250+
torch::jit::parseIR(graph, &*g);
251+
252+
auto in0 = at::randint(1, 10, {3, 3}, {at::kCUDA});
253+
auto in1 = at::randint(1, 10, {3, 3}, {at::kCUDA});
254+
255+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
256+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in0, in1});
257+
258+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
259+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0, in1});
260+
261+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
262+
}
263+
264+
TEST(Evaluators, ATenAppendWithTensorEvaluatesCorrectly) {
265+
const auto graph = R"IR(
266+
graph(%0 : Tensor):
267+
%1 : int[] = prim::Constant[value=[3,3]]()
268+
%2 : None = prim::Constant() # :0:0
269+
%20 : Device = prim::Constant[value="cuda"]()
270+
%3 : Tensor = aten::zeros(%1, %2, %2, %20, %2)
271+
%4 : Tensor = aten::zeros(%1, %2, %2, %20, %2)
272+
%5 : int = prim::Constant[value=0]()
273+
%15 : int = prim::Constant[value=1]()
274+
%6 : Tensor[] = prim::ListConstruct(%3)
275+
%7 : Tensor[] = aten::append(%6, %4)
276+
%8 : Tensor = aten::cat(%7, %5)
277+
%9 : Tensor = aten::add(%8, %0, %15)
278+
return (%9))IR";
279+
280+
auto g = std::make_shared<torch::jit::Graph>();
281+
torch::jit::parseIR(graph, &*g);
282+
283+
auto in0 = at::randint(1, 10, {6, 3}, {at::kCUDA});
284+
285+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
286+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in0});
287+
288+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
289+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0});
290+
291+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
292+
}
293+
294+
TEST(Evaluators, ATenAppendWithITensorAndTensorEvaluatesCorrectly) {
295+
const auto graph = R"IR(
296+
graph(%0 : Tensor):
297+
%1 : int[] = aten::size(%0)
298+
%2 : None = prim::Constant() # :0:0
299+
%20 : Device = prim::Constant[value="cuda"]()
300+
%3 : Tensor = aten::zeros(%1, %2, %2, %20, %2)
301+
%4 : int = prim::Constant[value=0]()
302+
%5 : Tensor[] = prim::ListConstruct(%0)
303+
%6 : Tensor[] = aten::append(%5, %3)
304+
%7 : Tensor = aten::cat(%6, %4)
305+
return (%7))IR";
306+
307+
auto g = std::make_shared<torch::jit::Graph>();
308+
torch::jit::parseIR(graph, &*g);
309+
310+
auto in0 = at::randint(1, 10, {3, 3}, {at::kCUDA});
311+
312+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
313+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in0});
314+
315+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
316+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0});
317+
318+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
238319
}

0 commit comments

Comments (0)