
Commit 6421f3d

feat(//core/conversion): Evaluation of static conditionals works now
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 7466b8a commit 6421f3d

File tree
core/conversion/conversion.cpp (1 file changed, +44 -10 lines)


Diff for: core/conversion/conversion.cpp

+44 -10
@@ -190,6 +190,8 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
     }
 }
 
+void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n);
+
 void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
     std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
     std::transform(in_list.begin() + in_offset, in_list.end(), out_list.begin() + out_offset,
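MapIValues is only partially visible in this hunk; judging by its call sites, it zips element i + in_offset of in_list with element i + out_offset of out_list and copies the evaluated IValue of each input value onto its paired output value. A minimal STL-only sketch of that zip-with-offsets idea, using ints in place of torch::jit::Value* and a plain std::map in place of the conversion context's evaluated_value_map (toy names, not the TRTorch API):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <map>
#include <utility>
#include <vector>

// Toy stand-ins: integer keys instead of torch::jit::Value*, ints instead of IValues.
using Value = int;
std::map<Value, int> evaluated_value_map = {{10, 100}, {11, 111}, {12, 122}};

void MapValues(const std::vector<Value>& in_list, const std::vector<Value>& out_list,
               std::size_t in_offset, std::size_t out_offset) {
    std::vector<std::pair<Value, Value>> input_output_pairs;
    std::transform(in_list.begin() + in_offset, in_list.end(),
                   out_list.begin() + out_offset,
                   std::back_inserter(input_output_pairs),
                   [](Value in, Value out) { return std::make_pair(in, out); });
    for (const auto& p : input_output_pairs) {
        // Whatever the input value evaluated to becomes the evaluation of the paired output value.
        evaluated_value_map[p.second] = evaluated_value_map[p.first];
    }
}

int main() {
    // e.g. a loop node: inputs {max_trip_count, start_cond, v0}, outputs {o0}
    std::vector<Value> node_inputs = {10, 11, 12};
    std::vector<Value> node_outputs = {20};
    MapValues(node_inputs, node_outputs, /*in_offset=*/2, /*out_offset=*/0);
    assert(evaluated_value_map[20] == evaluated_value_map[12]);
    return 0;
}

The in_offset of 2 in the example mirrors how EvaluateLoopBlock (below) skips a loop node's first two inputs (max trip count and start condition) when seeding its outputs.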
@@ -204,6 +206,31 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_l
     }
 }
 
+void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
+    auto condition = ctx->evaluated_value_map[n->input(0)].toBool();
+    LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << (int) condition);
+    auto b = condition ? n->blocks()[0] : n->blocks()[1];
+
+    for (const auto bn : b->nodes()) {
+        if (bn->kind() == torch::jit::prim::Loop) {
+            EvaluateLoopBlock(ctx, bn);
+        } else if (bn->kind() == torch::jit::prim::If) {
+            EvaluateConditionalBlock(ctx, bn);
+        } else {
+            TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile conditionals that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.")
+            auto eval = EvaluateNode(ctx, bn);
+            if (!eval.value().isTensor()) {
+                LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value());
+            } else {
+                LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
+            }
+            ctx->AssociateValueAndIValue(bn->output(0), eval.value());
+        }
+    }
+
+    MapIValues(ctx, b->outputs(), n->outputs(), 0, 0);
+}
+
 // TODO: With functionalization pass we may be able to make this into a regular evaluator later
 void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
     auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
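Taken together with the loop changes below, this is the core of the commit: when a graph contains a prim::If whose condition is already known at conversion time, TRTorch walks only the taken branch, evaluates its nodes (recursing into nested loops and conditionals), and then maps the branch's outputs onto the If node's outputs. Because conditionals can contain loops and loops can contain conditionals, the two block evaluators are mutually recursive, which is what the forward declaration of EvaluateLoopBlock in the first hunk is for. A compilable toy model of just that dispatch structure (made-up Node/Kind types, not torch::jit):

#include <vector>

enum class Kind { If, Loop, Other };

struct Node {
    Kind kind;
    // If: {true-branch, false-branch}; Loop: {body}. Recursive via std::vector (fine in C++17).
    std::vector<std::vector<Node>> blocks;
};

int evalLoopBlock(const Node& n);  // forward declaration, as in the patch

int evalConditionalBlock(const Node& n) {
    bool condition = true;  // the real code reads this from ctx->evaluated_value_map
    const auto& taken = condition ? n.blocks[0] : n.blocks[1];
    int evaluated = 0;
    for (const auto& bn : taken) {
        if (bn.kind == Kind::Loop) {
            evaluated += evalLoopBlock(bn);
        } else if (bn.kind == Kind::If) {
            evaluated += evalConditionalBlock(bn);
        } else {
            evaluated += 1;  // a plain node evaluated at conversion time
        }
    }
    return evaluated;
}

int evalLoopBlock(const Node& n) {
    int evaluated = 0;
    for (const auto& bn : n.blocks[0]) {  // a single trip, for illustration only
        if (bn.kind == Kind::If) {
            evaluated += evalConditionalBlock(bn);
        } else if (bn.kind == Kind::Loop) {
            evaluated += evalLoopBlock(bn);
        } else {
            evaluated += 1;
        }
    }
    return evaluated;
}

int main() {
    // An If whose taken branch holds a Loop whose body holds one plain node.
    Node plain{Kind::Other, {}};
    Node loop{Kind::Loop, {{plain}}};
    Node cond{Kind::If, {{loop}, {}}};
    return evalConditionalBlock(cond) == 1 ? 0 : 1;
}

Everything that is not control flow must pass evaluators::shouldEvalAtConversionTime, so a branch that contains a node needing a TensorRT converter still fails with the TRTORCH_CHECK message above.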
@@ -213,16 +240,21 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
 
     MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);
 
-    LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n);
-    LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
-    LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool());
-    LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
+    LOG_DEBUG(ctx->logger, "(Loop Evaluation) Evaluating loop " << *n);
+    LOG_DEBUG(ctx->logger, "(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
+    LOG_DEBUG(ctx->logger, "(Loop Evaluation) Start Condition: " << start_cond.toBool());
+    LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
 
     while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) {
         MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
         for (auto bn : n->blocks()[0]->nodes()) {
-            auto eval = EvaluateNode(ctx, bn);
-            if (eval) {
+            if (bn->kind() == torch::jit::prim::Loop) {
+                EvaluateLoopBlock(ctx, n);
+            } else if (bn->kind() == torch::jit::prim::If) {
+                EvaluateConditionalBlock(ctx, bn);
+            } else {
+                TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
+                auto eval = EvaluateNode(ctx, bn);
                 if (!eval.value().isTensor()) {
                     LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value());
                 } else {
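For orientation, the MapIValues offsets in this function follow the standard TorchScript prim::Loop layout: the node's inputs are (max_trip_count, initial_condition, carried values...), the body block's parameters are (trip_count, carried values...), and the block returns (continue_condition, new carried values...). Seeding the node outputs from its inputs therefore skips two inputs (offsets 2, 0), and feeding the current outputs back into the block skips the block's trip-count parameter (offsets 0, 1). A hypothetical, flattened-out version of what this conversion-time unrolling computes for a simple counted loop:

#include <cassert>
#include <cstdint>

int main() {
    const int64_t max_trip_count = 4;  // node input 0
    bool cond = true;                  // node input 1 (start condition)
    int64_t acc = 0;                   // node input 2, the loop-carried value
    int64_t trip_count = 0;

    while (cond && trip_count < max_trip_count) {
        // "block body": yields (continue condition, new carried values)
        acc = acc + trip_count;
        cond = true;                   // condition stays evaluatable, so the loop is static
        trip_count += 1;
    }

    assert(acc == 0 + 1 + 2 + 3);      // the converter ends up knowing the loop's result
    return 0;
}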
@@ -236,8 +268,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
         start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]];
         auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1);
         trip_count.swap(new_trip_count);
-        LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool());
-        LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
+        LOG_DEBUG(ctx->logger, "(Loop Evaluation) Condition: " << start_cond.toBool());
+        LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
     }
 }
 
@@ -255,6 +287,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
         bool blacklisted = isNodeConversionBlacklisted(n);
         if (n->kind() == torch::jit::prim::Loop) {
             EvaluateLoopBlock(ctx, n);
+        } else if (n->kind() == torch::jit::prim::If) {
+            EvaluateConditionalBlock(ctx, n);
         } else if (to_eval) {
             auto eval = EvaluateNode(ctx, n);
             if (eval) {
@@ -303,10 +337,10 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
 std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b ) {
     std::set<std::string> unsupported_ops;
     for (const auto n : b->nodes()) {
-        if (n->kind() != torch::jit::prim::Loop && !OpSupported(n)) {
+        if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
             auto schema = n->maybeSchema();
             TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
-                                  << " (conversion.VerifyCoverterSupportForBlock");
+                                  << " (conversion.VerifyCoverterSupportForBlock)");
             std::stringstream ss;
             ss << *schema;
             unsupported_ops.insert(ss.str());
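Since control-flow nodes are now handled structurally rather than by converters, the supported-op scan skips both prim::Loop and prim::If and only reports the operators inside their blocks. A toy rendition of that filter, with plain strings standing in for the node-kind symbols and schema strings (aten::mystery_op is made up for the example):

#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
    std::vector<std::string> kinds = {"prim::Loop", "prim::If", "aten::mystery_op", "aten::mystery_op"};
    std::set<std::string> unsupported_ops;
    for (const auto& kind : kinds) {
        bool op_supported = false;  // pretend no converter or evaluator is registered
        if (kind != "prim::Loop" && kind != "prim::If" && !op_supported) {
            unsupported_ops.insert(kind);  // the real code inserts the node's schema string
        }
    }
    for (const auto& op : unsupported_ops) {
        std::cout << op << '\n';  // prints "aten::mystery_op" once; std::set deduplicates
    }
    return 0;
}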
