Skip to content

Commit 6d60246

Browse files
authored
Merge pull request #81 from NVIDIA/fuse_addmm_branches
Adds basic scripting support
2 parents 22a4490 + 80d5069 commit 6d60246

File tree

85 files changed

+1235
-458
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the searchbox below for content that may be hidden.

85 files changed

+1235
-458
lines changed

core/compiler.cpp

+10-7
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,16 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
150150
torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
151151
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
152152
for (const torch::jit::script::Method& method : mod.get_methods()) {
153-
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
154-
auto new_g = std::make_shared<torch::jit::Graph>();
155-
AddEngineToGraph(new_mod, new_g, engine);
156-
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
157-
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
158-
new_mod.type()->addMethod(new_method);
159-
new_method->setSchema(schema);
153+
// Don't convert hidden methods
154+
if (method.name().rfind("_", 0)) {
155+
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
156+
auto new_g = std::make_shared<torch::jit::Graph>();
157+
AddEngineToGraph(new_mod, new_g, engine);
158+
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
159+
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
160+
new_mod.type()->addMethod(new_method);
161+
new_method->setSchema(schema);
162+
}
160163
}
161164

162165
return new_mod;

core/conversion/conversion.cpp

+79-7
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,57 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
190190
}
191191
}
192192

193+
void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
194+
std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
195+
std::transform(in_list.begin() + in_offset, in_list.end(), out_list.begin() + out_offset,
196+
std::back_inserter(input_output_pairs),
197+
[](auto in, auto out){
198+
return std::make_pair(in, out);
199+
});
200+
201+
for (auto p : input_output_pairs) {
202+
auto input = ctx->evaluated_value_map[p.first];
203+
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
204+
}
205+
}
206+
207+
// TODO: With functionalization pass we may be able to make this into a regular evaluator later
208+
void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
209+
auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
210+
auto start_cond = ctx->evaluated_value_map[n->input(1)];
211+
ctx->evaluated_value_map[n->blocks()[0]->inputs()[0]] = torch::jit::IValue(0);
212+
auto trip_count = ctx->evaluated_value_map[n->blocks()[0]->inputs()[0]];
213+
214+
MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);
215+
216+
LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n);
217+
LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
218+
LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool());
219+
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
220+
221+
while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) {
222+
MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
223+
for (auto bn : n->blocks()[0]->nodes()) {
224+
auto eval = EvaluateNode(ctx, bn);
225+
if (eval) {
226+
if (!eval.value().isTensor()) {
227+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value());
228+
} else {
229+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
230+
}
231+
ctx->AssociateValueAndIValue(bn->output(0), eval.value());
232+
}
233+
}
234+
235+
MapIValues(ctx, n->blocks()[0]->outputs(), n->outputs(), 1, 0);
236+
start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]];
237+
auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1);
238+
trip_count.swap(new_trip_count);
239+
LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool());
240+
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
241+
}
242+
}
243+
193244
void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
194245
LOG_INFO(ctx->logger, "Converting Block");
195246

@@ -202,7 +253,19 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
202253
for (const auto n : nodes) {
203254
bool to_eval = evaluators::shouldEvalAtConversionTime(n);
204255
bool blacklisted = isNodeConversionBlacklisted(n);
205-
if (!to_eval && !blacklisted) {
256+
if (n->kind() == torch::jit::prim::Loop) {
257+
EvaluateLoopBlock(ctx, n);
258+
} else if (to_eval) {
259+
auto eval = EvaluateNode(ctx, n);
260+
if (eval) {
261+
if (!eval.value().isTensor()) {
262+
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
263+
} else {
264+
LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
265+
}
266+
ctx->AssociateValueAndIValue(n->output(0), eval.value());
267+
}
268+
} else if (!blacklisted) {
206269
// Should error out if something fails
207270
AddLayer(ctx, n);
208271
} else {
@@ -237,22 +300,29 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
237300
return engine;
238301
}
239302

240-
bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
241-
bool supported = true;
303+
std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b ) {
242304
std::set<std::string> unsupported_ops;
243305
for (const auto n : b->nodes()) {
244-
if (!OpSupported(n)) {
306+
if (!OpSupported(n) && n->kind() != torch::jit::prim::Loop) {
245307
auto schema = n->maybeSchema();
246308
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
247309
<< " (conversion.VerifyCoverterSupportForBlock");
248310
std::stringstream ss;
249311
ss << *schema;
250312
unsupported_ops.insert(ss.str());
251-
supported = false;
313+
}
314+
for (const auto sub_b : n->blocks()) {
315+
auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock(sub_b);
316+
unsupported_ops.insert(sub_b_unsupported_ops.begin(), sub_b_unsupported_ops.end());
252317
}
253318
}
319+
return unsupported_ops;
320+
}
321+
322+
bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
323+
auto unsupported_ops = GetUnsupportedOpsInBlock(b);
254324

255-
if (!supported) {
325+
if (unsupported_ops.size() != 0) {
256326
std::stringstream unsupported_msg;
257327
unsupported_msg << "Method requested cannot be compiled by TRTorch.\nUnsupported operators listed below:" << std::endl;
258328
for (auto s : unsupported_ops) {
@@ -261,8 +331,10 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
261331
unsupported_msg << "You can either implement converters for these ops in your application or request implementation" << std::endl;
262332
unsupported_msg << "https://www.github.com/nvidia/TRTorch/issues" << std::endl;
263333
LOG_ERROR(unsupported_msg.str());
334+
return false;
335+
} else {
336+
return true;
264337
}
265-
return supported;
266338
}
267339

268340
} // namespace conversion

core/conversion/conversion_blacklist.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
1515
"aten::backward",
1616
"aten::save",
1717
"aten::contiguous",
18+
"aten::to",
1819
"prim::RaiseException",
1920
"prim::Print",
2021
"prim::device",

core/conversion/converters/BUILD

-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ cc_library(
2525
"impl/matrix_multiply.cpp",
2626
"impl/pooling.cpp",
2727
"impl/reduce.cpp",
28-
"impl/shape.cpp",
2928
"impl/shuffle.cpp",
3029
"impl/softmax.cpp",
3130
"impl/unary.cpp",

core/conversion/converters/impl/shape.cpp

-32
This file was deleted.

core/conversion/evaluators/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library(
1515
srcs = [
1616
"NodeEvaluatorRegistry.cpp",
1717
"prim.cpp",
18+
"aten.cpp"
1819
],
1920
deps = [
2021
"//core/util:prelude",

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

+44-12
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,59 @@ namespace core {
1515
namespace conversion {
1616
namespace evaluators {
1717
namespace {
18-
using EvaluatorLUT = std::unordered_map<torch::jit::NodeKind, NodeEvaluator>;
18+
using EvaluatorLUT = std::unordered_map<torch::jit::NodeKind, EvalRegistration>;
19+
20+
bool FindInVec(std::vector<c10::OperatorName>& names, c10::OperatorName target) {
21+
for (auto n : names) {
22+
if (n == target) {
23+
return true;
24+
}
25+
}
26+
return false;
27+
}
1928

2029
class NodeEvaluatorRegistry {
2130
public:
22-
void RegisterEvaluator(torch::jit::NodeKind node_kind, NodeEvaluator& evaluator) {
31+
void RegisterEvaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
2332
LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
24-
evaluator_lut_[node_kind] = std::move(evaluator);
33+
evaluator_lut_[node_kind] = std::move(eval_reg);
2534
}
2635

27-
NodeEvaluator GetEvaluator(const torch::jit::NodeKind node_kind) {
36+
NodeEvaluator FindEvaluator(const torch::jit::Node* n) {
37+
auto node_kind = n->kind();
2838
auto iter = evaluator_lut_.find(node_kind);
2939
if (iter == evaluator_lut_.end()) {
30-
LOG_ERROR("Requested evaluator for " << node_kind.toQualString() << ", but no such evaluator was found");
3140
return nullptr;
3241
}
33-
return iter->second;
42+
auto eval_reg = iter->second;
43+
if (eval_reg.options.use()) {
44+
for (auto o : n->outputs()) {
45+
if (eval_reg.options.blacklisted_output_types.find(o->type()) != eval_reg.options.blacklisted_output_types.end()) {
46+
return nullptr;
47+
}
48+
}
49+
50+
if (eval_reg.options.valid_schemas.size() != 0) {
51+
auto schema = n->maybeSchema();
52+
TRTORCH_CHECK(schema, "Evaluator for " << node_kind.toQualString() << "only runs on certain schemas, but schema for node is not retrievable");
53+
if (!FindInVec(eval_reg.options.valid_schemas, schema->operator_name())) {
54+
return nullptr;
55+
}
56+
}
57+
}
58+
59+
return eval_reg.evaluator;
60+
}
61+
62+
NodeEvaluator GetEvaluator(const torch::jit::Node* n) {
63+
auto evaluator = FindEvaluator(n);
64+
TRTORCH_CHECK(evaluator, "Requested evaluator for " << n->kind().toQualString() << ", but no such evaluator was found");
65+
return evaluator;
3466
}
3567

3668
bool EvalAtConversionTime(const torch::jit::Node* n) {
37-
auto eval_at_conversion = evaluator_lut_.find(n->kind());
38-
if (eval_at_conversion == evaluator_lut_.end()) {
69+
auto evaluator = FindEvaluator(n);
70+
if (evaluator == nullptr) {
3971
return false;
4072
} else {
4173
return true;
@@ -58,16 +90,16 @@ bool shouldEvalAtConversionTime(const torch::jit::Node* n) {
5890
}
5991

6092
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
61-
auto evaluator = get_evaluator_registry().GetEvaluator(n->kind());
93+
auto evaluator = get_evaluator_registry().GetEvaluator(n);
6294
return evaluator(n, args);
6395
}
6496

65-
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator) {
66-
get_evaluator_registry().RegisterEvaluator(node_kind, evaluator);
97+
void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
98+
get_evaluator_registry().RegisterEvaluator(node_kind, std::move(eval_reg));
6799
}
68100

69101
void register_node_evaluator(EvalRegistration r) {
70-
register_node_evaluator(r.kind, r.evaluator);
102+
register_node_evaluator(r.kind, std::move(r));
71103
}
72104

73105
RegisterNodeEvaluators&& RegisterNodeEvaluators::evaluator(EvalRegistration r) && {

0 commit comments

Comments (0)