Skip to content

Commit 619e345

Browse files
committed
feat(//core/conversion/evaluators): Allow ITensors to be wrapped in
IValues

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 8c26a1b commit 619e345

File tree

12 files changed

+159
-48
lines changed

12 files changed

+159
-48
lines changed

Diff for: core/conversion/BUILD

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ cc_library(
1919
],
2020
deps = [
2121
"@tensorrt//:nvinfer",
22-
"//core/conversion/arg",
22+
"//core/conversion/var",
2323
"//core/conversion/conversionctx",
2424
"//core/conversion/converters",
2525
"//core/conversion/evaluators",

Diff for: core/conversion/conversion.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include <sstream>
22

33
#include "core/util/prelude.h"
4-
#include "core/conversion/arg/Arg.h"
4+
#include "core/conversion/var/Var.h"
55
#include "core/conversion/conversion.h"
66
#include "core/conversion/converters/converters.h"
77
#include "core/conversion/evaluators/evaluators.h"
@@ -35,6 +35,8 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
3535
}
3636
if (ctx->evaluated_value_map.find(eval_in) != ctx->evaluated_value_map.end()) {
3737
eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
38+
} else if (ctx->value_tensor_map.find(eval_in) != ctx->value_tensor_map.end()) {
39+
eval_args[eval_in] = ctx->value_tensor_map[eval_in];
3840
} else if (evaluators::shouldEvalAtConversionTime(eval_in->node())) {
3941
auto result = EvaluateNode(ctx, eval_in->node(), level++, limit);
4042
if (result) {
@@ -82,8 +84,8 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
8284
ctx->AssociateValueAndIValue(input, eval.value());
8385
node_args.push_back(&(ctx->evaluated_value_map[input]));
8486
} else {
85-
LOG_DEBUG(ctx->logger, "Found the value is None");;
86-
node_args.push_back(Arg());
87+
LOG_DEBUG(ctx->logger, "Found the value is None");
88+
node_args.push_back(Var());
8789
}
8890
} else {
8991
// Node input has not been converted yet or is a prim op

Diff for: core/conversion/conversionctx/ConversionCtx.h

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <unordered_map>
55
#include <memory>
66

7-
//#include "ATen/ATen.h"
87
#include "torch/csrc/jit/ir/ir.h"
98
#include "NvInfer.h"
109

Diff for: core/conversion/converters/converters.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@
77
#include "ATen/core/function_schema.h"
88

99
#include "core/util/prelude.h"
10-
#include "core/conversion/arg/Arg.h"
10+
#include "core/conversion/var/Var.h"
1111
#include "core/conversion/conversionctx/ConversionCtx.h"
1212

1313
namespace trtorch {
1414
namespace core {
1515
namespace conversion {
1616
namespace converters {
1717

18-
typedef std::vector<Arg> args;
18+
typedef std::vector<Var> args;
1919
typedef std::function<bool(ConversionCtx*, const torch::jit::Node*, args&)> OpConverter;
2020
struct ConversionPattern {
2121
std::string signature;

Diff for: core/conversion/evaluators/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ cc_library(
1818
],
1919
deps = [
2020
"//core/util:prelude",
21+
"//core/conversion/var",
22+
"//core/conversion/tensorcontainer",
2123
] + select({
2224
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
2325
"//conditions:default": ["@libtorch//:libtorch"],

Diff for: core/conversion/evaluators/NodeEvaluatorRegistry.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ bool shouldEvalAtConversionTime(const torch::jit::Node* n) {
5757
return get_evaluator_registry().EvalAtConversionTime(n);
5858
}
5959

60-
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, const kwargs& args) {
60+
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
6161
auto evaluator = get_evaluator_registry().GetEvaluator(n->kind());
6262
return evaluator(n, args);
6363
}

Diff for: core/conversion/evaluators/evaluators.h

+24-8
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,45 @@
22

33
#include <string>
44
#include <map>
5+
#include <set>
56

67
#include "torch/csrc/jit/ir/ir.h"
78

9+
#include "core/conversion/tensorcontainer/TensorContainer.h"
10+
#include "core/conversion/var/Var.h"
11+
812
namespace trtorch {
913
namespace core {
1014
namespace conversion {
1115
namespace evaluators {
1216

13-
typedef std::map<const torch::jit::Value*, const torch::jit::IValue*> kwargs;
14-
15-
// NOTE: The input args are a dictionary of Value -> IValue this means
16-
// inputs will not be repeated. We did this so writing encoders
17-
// is similar to converters (where a dictionary makes more sense)
18-
// This mean that you should iterate over node inputs vs. the args
17+
typedef std::map<const torch::jit::Value*, Var> kwargs;
18+
19+
inline bool constTypesOnly(kwargs& args) {
20+
std::set<Var::Type> types;
21+
for (auto a : args) {
22+
if (a.second.type() == Var::kITensor) {
23+
return false;
24+
}
25+
}
26+
return true;
27+
}
28+
29+
// NOTE: The input args are a dictionary of Value -> Var this means
30+
// inputs will not be repeated. We did this because while in the case
31+
// of converters we have the function schema to lay out argument order,
32+
// evaluators don't use the schema; they use the node kind as the key, so it's easier
33+
// to use the node itself to pull out arguments.
34+
// This means that you should iterate over node inputs vs. the args
1935
// when writing evaluators
20-
typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, const kwargs&)> NodeEvaluator;
36+
typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, kwargs&)> NodeEvaluator;
2137

2238
struct EvalRegistration {
2339
torch::jit::NodeKind kind;
2440
NodeEvaluator evaluator;
2541
};
2642

27-
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, const kwargs& args);
43+
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
2844
bool shouldEvalAtConversionTime(const torch::jit::Node* n);
2945
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
3046
void register_node_evaluator(EvalRegistration r);

Diff for: core/conversion/evaluators/prim.cpp

+46-31
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "ATen/core/ivalue.h"
55
#include "ATen/core/List.h"
66
#include "ATen/core/stack.h"
7+
#include "c10/util/intrusive_ptr.h"
78

89
#include "core/conversion/evaluators/evaluators.h"
910

@@ -16,51 +17,65 @@ namespace {
1617
auto prim_registrations = RegisterNodeEvaluators()
1718
.evaluator({
1819
torch::jit::prim::Constant,
19-
[](const torch::jit::Node* n, const kwargs& args) -> c10::optional<torch::jit::IValue> {
20+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
2021
if (n->output()->type()->kind() == at::FunctionType::Kind) {
2122
return {};
2223
}
2324
return torch::jit::toIValue(n->output());
2425
}
2526
}).evaluator({
2627
torch::jit::prim::ListConstruct,
27-
[](const torch::jit::Node* n, const kwargs& args) -> c10::optional<torch::jit::IValue> {
28+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
2829
const auto num_inputs = n->inputs().size();
29-
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
30-
if (torch::jit::IntType::get() == lt->getElementType()) {
31-
c10::List<int64_t> list;
32-
list.reserve(num_inputs);
33-
for (auto in : n->inputs()) {
34-
list.emplace_back(std::move(args.at(in)->to<int64_t>()));
35-
}
36-
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
37-
} else if (torch::jit::FloatType::get() == lt->getElementType()) {
38-
c10::List<double> list;
39-
list.reserve(num_inputs);
40-
for (auto in : n->inputs()) {
41-
list.emplace_back(std::move(args.at(in)->to<double>()));
30+
if (constTypesOnly(args)) {
31+
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
32+
if (torch::jit::IntType::get() == lt->getElementType()) {
33+
c10::List<int64_t> list;
34+
list.reserve(num_inputs);
35+
for (auto in : n->inputs()) {
36+
list.emplace_back(std::move(args.at(in).unwrapToInt()));
37+
}
38+
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
39+
} else if (torch::jit::FloatType::get() == lt->getElementType()) {
40+
c10::List<double> list;
41+
list.reserve(num_inputs);
42+
for (auto in : n->inputs()) {
43+
list.emplace_back(std::move(args.at(in).unwrapToDouble()));
44+
}
45+
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
46+
} else if (lt->getElementType() == torch::jit::BoolType::get()) {
47+
c10::List<bool> list;
48+
list.reserve(num_inputs);
49+
for (auto in : n->inputs()) {
50+
list.emplace_back(std::move(args.at(in).unwrapToBool()));
51+
}
52+
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
53+
} else if (lt->getElementType()->isSubtypeOf(torch::jit::TensorType::get())) {
54+
c10::List<at::Tensor> list;
55+
list.reserve(num_inputs);
56+
for (auto in : n->inputs()) {
57+
if (args.at(in).isIValue()) {
58+
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
59+
}
60+
}
61+
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
62+
} else {
63+
c10::TypePtr elementType = lt->getElementType();
64+
auto list = c10::impl::GenericList(elementType);
65+
list.reserve(num_inputs);
66+
for (auto in : n->inputs()) {
67+
list.emplace_back(std::move(*(args.at(in).IValue())));
68+
}
69+
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
4270
}
43-
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
44-
} else if (lt->getElementType() == torch::jit::BoolType::get()) {
45-
c10::List<bool> list;
46-
list.reserve(num_inputs);
47-
for (auto in : n->inputs()) {
48-
list.emplace_back(std::move(args.at(in)->to<bool>()));
49-
}
50-
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
51-
} else if (lt->getElementType()->isSubtypeOf(torch::jit::TensorType::get())) {
52-
c10::List<at::Tensor> list;
53-
list.reserve(num_inputs);
54-
for (auto in : n->inputs()) {
55-
list.emplace_back(std::move(args.at(in)->toTensor()));
56-
}
57-
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
5871
} else {
72+
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
5973
c10::TypePtr elementType = lt->getElementType();
6074
auto list = c10::impl::GenericList(elementType);
6175
list.reserve(num_inputs);
6276
for (auto in : n->inputs()) {
63-
list.emplace_back(std::move(*(args.at(in))));
77+
auto x = torch::make_custom_class<TensorContainer>(reinterpret_cast<int64_t>(args.at(in).ITensor()));
78+
list.emplace_back(std::move(x));
6479
}
6580
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
6681
}

Diff for: core/conversion/tensorcontainer/BUILD

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package(default_visibility = ["//visibility:public"])
2+
3+
config_setting(
4+
name = "use_pre_cxx11_abi",
5+
values = {
6+
"define": "abi=pre_cxx11_abi",
7+
}
8+
)
9+
10+
cc_library(
11+
name = "tensorcontainer",
12+
hdrs = [
13+
"TensorContainer.h",
14+
],
15+
srcs = [
16+
"TensorContainer.cpp",
17+
],
18+
deps = [
19+
"@tensorrt//:nvinfer",
20+
"//core/util:prelude",
21+
] + select({
22+
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
23+
"//conditions:default": ["@libtorch//:libtorch"],
24+
}),
25+
alwayslink = True,
26+
)
27+
28+
load("@rules_pkg//:pkg.bzl", "pkg_tar")
29+
30+
pkg_tar(
31+
name = "include",
32+
package_dir = "core/conversion/tensorcontainer/",
33+
srcs = [
34+
"TensorContainer.h",
35+
],
36+
)

Diff for: core/conversion/tensorcontainer/TensorContainer.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#include "core/conversion/tensorcontainer/TensorContainer.h"
2+
3+
namespace trtorch {
4+
namespace core {
5+
namespace conversion {
6+
namespace {
7+
8+
static auto tensor_container =
9+
torch::class_<TensorContainer>("_eval_ivalue_types", "TensorContainer")
10+
.def(torch::init<int64_t>())
11+
.def("clone", &TensorContainer::clone);
12+
13+
} // namespace
14+
} // conversion
15+
} // core
16+
} // trtorch

Diff for: core/conversion/tensorcontainer/TensorContainer.h

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include "NvInfer.h"
4+
#include "torch/custom_class.h"
5+
6+
namespace trtorch {
7+
namespace core {
8+
namespace conversion {
9+
10+
struct TensorContainer : torch::CustomClassHolder {
11+
int64_t tensor_;
12+
TensorContainer(int64_t init) : tensor_(init) {}
13+
14+
c10::intrusive_ptr<TensorContainer> clone() const {
15+
return c10::make_intrusive<TensorContainer>(tensor_);
16+
}
17+
18+
nvinfer1::ITensor* tensor() {
19+
return reinterpret_cast<nvinfer1::ITensor*>(tensor_);
20+
}
21+
};
22+
23+
} // conversion
24+
} // core
25+
} // trtorch

Diff for: core/execution/register_trt_op.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
2222
auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
2323
auto shape = core::util::toVec(dims);
2424
contig_inputs.push_back(inputs[i].view(shape).contiguous());
25-
LOG_DEBUG("In shape: " << shape);
25+
LOG_DEBUG("Input shape: " << dims);
2626
ctx->setBindingDimensions(i, dims);
2727
gpu_handles.push_back(contig_inputs.back().data_ptr());
2828
}

0 commit comments

Comments
 (0)