Skip to content

Commit 9d1946e

Browse files
committed
feat(//core/conversion): Compiler can now create graphs
out of programs that use conditionals if it can be gaurenteed that there is a single code path followed through the course of the program given input information and the graph This means that right now conditionals within loops is not supported but if a program has a bunch of evaluatable cases and those cases produce tensors as long as the program does not need to run both branches conditionally at runtime the program can still be compiled Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 07ba980 commit 9d1946e

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

Diff for: core/conversion/conversion.cpp

+25-7
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,27 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_l
201201
});
202202

203203
for (auto p : input_output_pairs) {
204-
auto input = ctx->evaluated_value_map[p.first];
205-
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
204+
if (ctx->evaluated_value_map.find(p.first) != ctx->evaluated_value_map.end()) {
205+
auto input = ctx->evaluated_value_map[p.first];
206+
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
207+
} else if (ctx->value_tensor_map.find(p.first) != ctx->value_tensor_map.end()) {
208+
auto input = ctx->value_tensor_map[p.first];
209+
ctx->value_tensor_map[p.second] = input;
210+
} else {
211+
TRTORCH_THROW_ERROR("Cannot find Value " << p.first->debugName() << " either evaluated values or tensor maps (MapIValues)");
212+
}
206213
}
207214
}
208215

209-
void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
216+
void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n, bool contained_in_loop = false) {
217+
bool output_type_includes_tensor = false;
218+
for (auto o : n->outputs()) {
219+
if (o->type()->isSubtypeOf(c10::TensorType::get())) {
220+
output_type_includes_tensor = true;
221+
}
222+
}
223+
TRTORCH_CHECK(!(contained_in_loop && output_type_includes_tensor), "TRTorch currently cannot compile conditionals within loops");
224+
210225
auto condition = ctx->evaluated_value_map[n->input(0)].toBool();
211226
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << (int) condition);
212227
auto b = condition ? n->blocks()[0] : n->blocks()[1];
@@ -215,16 +230,19 @@ void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
215230
if (bn->kind() == torch::jit::prim::Loop) {
216231
EvaluateLoopBlock(ctx, bn);
217232
} else if (bn->kind() == torch::jit::prim::If) {
218-
EvaluateConditionalBlock(ctx, bn);
219-
} else {
220-
TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile conditionals that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.")
233+
EvaluateConditionalBlock(ctx, bn, contained_in_loop);
234+
} else if (evaluators::shouldEvalAtConversionTime(bn)) {
221235
auto eval = EvaluateNode(ctx, bn);
222236
if (!eval.value().isTensor()) {
223237
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value());
224238
} else {
225239
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
226240
}
227241
ctx->AssociateValueAndIValue(bn->output(0), eval.value());
242+
} else if (converters::node_is_convertable(bn)) {
243+
AddLayer(ctx, bn);
244+
} else {
245+
TRTORCH_THROW_ERROR("TRTorch is unable to compile this conditional, a converter or evaluator is not available for node " << *bn);
228246
}
229247
}
230248

@@ -251,7 +269,7 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
251269
if (bn->kind() == torch::jit::prim::Loop) {
252270
EvaluateLoopBlock(ctx, n);
253271
} else if (bn->kind() == torch::jit::prim::If) {
254-
EvaluateConditionalBlock(ctx, bn);
272+
EvaluateConditionalBlock(ctx, bn, true);
255273
} else {
256274
TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
257275
auto eval = EvaluateNode(ctx, bn);

0 commit comments

Comments
 (0)