Skip to content

Commit 07ba980

Browse files
committed
fix(//core/conversion/evaluators): A couple fixes for evaluators
- Fixes aten::append to correctly append values and not pointers - Fixes prim::RaiseException and aten::warn to print out strings Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 6421f3d commit 07ba980

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

Diff for: core/conversion/evaluators/aten.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ auto aten_registrations TRTORCH_UNUSED = RegisterNodeEvaluators()
207207
auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
208208
auto el = args.at(n->input(1)).IValue();
209209

210-
list.push_back(std::move(el));
210+
list.push_back(std::move(*el));
211211
return list;
212212
},
213213
EvalOptions().validSchemas({
@@ -430,16 +430,16 @@ auto aten_registrations TRTORCH_UNUSED = RegisterNodeEvaluators()
430430
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
431431
auto el = args.at(n->input(0)).unwrapToDouble();
432432

433-
return std::floor(el);
433+
return static_cast<int64_t>(std::floor(el));
434434
},
435435
EvalOptions().validSchemas({
436436
"aten::floor.float(float a) -> (int)",
437437
})
438438
}).evaluator({
439439
c10::Symbol::fromQualString("aten::warn"),
440440
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
441-
auto warning = args.at(n->input(0)).IValue()->toString();
442-
LOG_WARNING(warning);
441+
auto warning = args.at(n->input(0)).IValue();
442+
LOG_WARNING("Warning from TorchScript: " << *warning);
443443
return {};
444444
},
445445
EvalOptions()

Diff for: core/conversion/evaluators/prim.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ auto prim_registrations = RegisterNodeEvaluators()
242242
}).evaluator({
243243
c10::Symbol::fromQualString("prim::RaiseException"),
244244
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
245-
auto exception = args.at(n->input(0)).IValue()->toString();
246-
TRTORCH_THROW_ERROR(exception);
245+
auto exception = args.at(n->input(0)).IValue();
246+
TRTORCH_THROW_ERROR("Error from TorchScript: " << *exception);
247247
return {};
248248
}
249249
});

0 commit comments

Comments
 (0)