Skip to content

Commit 5bc977d

Browse files
committed
feat: support aten::Int
Signed-off-by: inocsin <[email protected]>
1 parent bd72677 commit 5bc977d

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

Diff for: core/conversion/evaluators/aten.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,30 @@ auto aten_registrations TRTORCH_UNUSED =
383383
"aten::Float.int(int a) -> float",
384384
"aten::Float.bool(bool a) -> float",
385385
})})
386+
.evaluator({c10::Symbol::fromQualString("aten::Int"),
387+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
388+
if (args.at(n->input(0)).IValue()->isInt()) {
389+
auto a = args.at(n->input(0)).unwrapToInt();
390+
return (int)a;
391+
} else if (args.at(n->input(0)).IValue()->isDouble()) {
392+
auto a = args.at(n->input(0)).unwrapToDouble();
393+
return (int)a;
394+
} else if (args.at(n->input(0)).IValue()->isBool()) {
395+
auto a = args.at(n->input(0)).unwrapToBool();
396+
return (int)a;
397+
} else {
398+
TRTORCH_THROW_ERROR(
399+
"Unimplemented data type for aten::Int evaluator: "
400+
<< args.at(n->input(0)).IValue()->type()->str());
401+
return {};
402+
}
403+
},
404+
EvalOptions().validSchemas({
405+
"aten::Int.Scalar(Scalar a) -> int",
406+
"aten::Int.int(int a) -> int",
407+
"aten::Int.bool(bool a) -> int",
408+
"aten::Int.float(float a) -> int",
409+
})})
386410
.evaluator({c10::Symbol::fromQualString("aten::__not__"),
387411
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
388412
auto el = args.at(n->input(0)).unwrapToBool();

Diff for: tests/core/conversion/evaluators/test_aten_evaluators.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -399,4 +399,20 @@ TEST(Evaluators, ATenCopyEvaluatesCorrectly) {
399399
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {in});
400400

401401
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
402+
}
403+
404+
TEST(Evaluators, IntFloatEvaluatesCorrectly) {
405+
const auto graph = R"IR(
406+
graph():
407+
%1 : float = prim::Constant[value=9.3]()
408+
%2 : int = aten::Int(%1)
409+
return (%2))IR";
410+
411+
auto g = std::make_shared<torch::jit::Graph>();
412+
torch::jit::parseIR(graph, g.get());
413+
414+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
415+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
416+
417+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
402418
}

0 commit comments

Comments
 (0)