Commit a86ac93

fix: Add lowering pass to resolve aten::Int.Tensor
- Implement a lowering pass which detects canonical `aten::Int.Tensor` cases and recursively replaces input Value pointers until all 0D tensors have been resolved to their scalar components
- The lowering pass is specialized to replacing strictly integer-typed Value pointers and can only trace through aten::mul and aten::floor_divide operators, which are two of the most common use cases
- The lowering pass traverses the graph until one of three base cases is encountered (or an invalid Value type is detected). These cases are `prim::NumToTensor`, `prim::Constant` (0D tensor), or simple integers. It then replaces the child nodes with the integer equivalents of the produced Tensors
- Added extensive testing of the new capabilities for accuracy, robustness, and functionality
1 parent 98376b3 commit a86ac93
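
The pass is exposed as torch_tensorrt::core::lowering::passes::ReplaceAtenInt and can be run against a standalone TorchScript graph. Below is a minimal invocation sketch (not part of the commit; the IR string is illustrative and mirrors the pattern exercised in the new tests):

#include <memory>
#include <string>

#include "core/lowering/passes/passes.h"
#include "torch/csrc/jit/ir/irparser.h"

void ReplaceAtenIntExample() {
  // Canonical aten::Int.Tensor pattern: a scalar boxed via prim::NumToTensor,
  // arithmetic carried out on 0D tensors, then aten::Int unboxing the result
  std::string graph_ir = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %1: Tensor = prim::NumToTensor(%0)
      %2: Tensor = aten::mul(%1, %1)
      %3: int = aten::Int(%2)
      return (%3))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph_ir, g.get());

  // After this call, the 0D-tensor chain is replaced by scalar arithmetic
  // (an int-typed aten::mul over %0) and dead nodes are eliminated
  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(g);
}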

4 files changed: +298 -0 lines

core/lowering/lowering.cpp (+1)

@@ -142,6 +142,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
   passes::AliasOperators(g);
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
+  passes::ReplaceAtenInt(g);
   passes::RemoveUnnecessaryCasts(g);
   passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
   passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());

core/lowering/passes/passes.h (+1)

@@ -31,6 +31,7 @@ void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
 void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
+void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g);
 void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/remove_unnecessary_casts.cpp (+144)

@@ -1,4 +1,5 @@
 #include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"

 #include "core/util/prelude.h"
@@ -211,6 +212,149 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
   LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
 }

+// Schemas for Aten::Int which can be replaced by scalar equivalents
+const std::unordered_set<c10::Symbol> AtenIntReplacementNodeKinds = {
+    torch::jit::aten::mul,
+    torch::jit::aten::floor_divide,
+};
+
+torch::jit::Value* Validate0DTensor(torch::jit::Value* value) {
+  // Validates that the input Value* is a 0D Tensor (or int/float)
+  // Return the stored int/float Value* if so, otherwise null
+  torch::jit::Value* enclosed_scalar_value = nullptr;
+
+  // Regular Int/Float case
+  if (value->type()->isSubtypeOf(c10::IntType::get()) || value->type()->isSubtypeOf(c10::FloatType::get())) {
+    enclosed_scalar_value = value;
+    return enclosed_scalar_value;
+  }
+
+  // Constant Tensor case
+  if (value->node()->kind() == torch::jit::prim::Constant && value->type()->isSubtypeOf(c10::TensorType::get())) {
+    // Retrieve the Tensor stored in the constant
+    at::Tensor t = *torch::jit::constant_as<at::Tensor>(value);
+    // Validate the shape of the Tensor is 0D (single-element) and integral
+    if (t.sizes() == std::vector<int64_t>({}) && t.item().isIntegral()) {
+      // Extract the stored value, add it to the graph as a constant
+      torch::jit::WithInsertPoint guard(value->node());
+      auto new_const_val = value->owningGraph()->insertConstant(t.item(), c10::nullopt, value->node()->scope());
+      new_const_val->copyMetadata(value);
+      new_const_val->setType(c10::IntType::get());
+      enclosed_scalar_value = new_const_val;
+      return enclosed_scalar_value;
+    } else {
+      LOG_DEBUG("In aten::Int.Tensor removal, encountered a const which was either not 0D or not integral");
+    }
+  }
+
+  // NumToTensor case
+  if (value->node()->kind() == torch::jit::prim::NumToTensor && value->type()->isSubtypeOf(c10::TensorType::get())) {
+    // Input to NumToTensor is the relevant scalar
+    enclosed_scalar_value = value->node()->input();
+    return enclosed_scalar_value;
+  }
+
+  return enclosed_scalar_value;
+}
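
In IR terms, the three base cases resolved above look roughly like the following (an illustrative sketch; value names are hypothetical):

// 1. Plain scalar: the Value* is returned unchanged
//      %s : int = prim::Constant[value=7]()
// 2. 0D integral constant Tensor: unpacked into a new int constant via insertConstant
//      %t : Tensor = prim::Constant[value=[8]]()
// 3. NumToTensor: the wrapped scalar input (%s here) is returned directly
//      %n : Tensor = prim::NumToTensor(%s)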
+
+torch::jit::Value* TracebackAndEliminate0DTensors(torch::jit::Node* node) {
+  // Trace back through a node and all parents to eliminate 0D Tensors
+  // and update schemas to their scalar alternatives, returning final
+  // Value* to user
+
+  // Requires valid schema with at least two inputs
+  if (AtenIntReplacementNodeKinds.find(node->kind()) == AtenIntReplacementNodeKinds.end() ||
+      node->inputs().size() < 2) {
+    LOG_DEBUG(
+        "Encountered node " << node->kind().toQualString()
+                            << " which is unsupported in the aten::Int.Tensor replacement lowering pass.");
+    return nullptr;
+  }
+
+  // Validate that the first and second function inputs are 0D tensors or scalars
+  torch::jit::Value* first_input_scalar_value = Validate0DTensor(node->inputs()[0]);
+  torch::jit::Value* second_input_scalar_value = Validate0DTensor(node->inputs()[1]);
+
+  // If the first input is not a scalar, recursively trace back on parent nodes
+  if (!first_input_scalar_value) {
+    LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[0]->node()->kind().toQualString());
+    first_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[0]->node());
+  }
+
+  // If the second input is not a scalar, recursively trace back on parent nodes
+  if (!second_input_scalar_value) {
+    LOG_DEBUG("In aten::Int.Tensor lowering, now tracing " << node->inputs()[1]->node()->kind().toQualString());
+    second_input_scalar_value = TracebackAndEliminate0DTensors(node->inputs()[1]->node());
+  }
+
+  if (!first_input_scalar_value || !second_input_scalar_value) {
+    LOG_DEBUG(
+        "In aten::Int.Tensor lowering, recursive trace through node input "
+        << "parents failed to return a Scalar value for at least one parent node.");
+    return nullptr;
+  }
+
+  // Set default insert point at node
+  torch::jit::WithInsertPoint guard(node);
+  torch::jit::Node* new_node;
+
+  switch (node->kind()) {
+    // In the aten::floor_divide case, the schema syntax changes, so a new node
+    // must be inserted
+    case torch::jit::aten::floor_divide:
+      new_node = node->owningGraph()->create(
+          torch::jit::aten::floordiv, {first_input_scalar_value, second_input_scalar_value}, 1);
+      new_node->insertAfter(node);
+      new_node->output()->setType(c10::IntType::get());
+      return new_node->output();
+
+    // In the aten::mul case, the schema syntax is the same, so we can use the existing schema
+    // with new inputs
+    default:
+      new_node = node->owningGraph()->create(node->kind(), {first_input_scalar_value, second_input_scalar_value}, 1);
+      new_node->insertAfter(node);
+      new_node->output()->setType(c10::IntType::get());
+      return new_node->output();
+  }
+}
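
The two switch branches differ only in the operator symbol emitted: scalar multiplication reuses the aten::mul name, while tensor floor_divide maps to the differently named scalar operator aten::floordiv. Roughly (a sketch with illustrative value names, where %a and %b are the scalars recovered from the 0D tensors %ta and %tb):

//   %q : Tensor = aten::floor_divide(%ta, %tb)  -->  %q : int = aten::floordiv(%a, %b)
//   %p : Tensor = aten::mul(%ta, %tb)           -->  %p : int = aten::mul(%a, %b)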
+
+void ReplaceAtenInt(std::shared_ptr<torch::jit::Graph>& g) {
+  // Find all nodes with the aten::Int.Tensor schema and replace those
+  // by tracing through the node and resolving the use of 0D tensors
+  // to their corresponding scalar alternatives
+
+  // Iterate over all nodes in the graph
+  for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) {
+    // Validate schema requirements for aten::Int.Tensor
+    if (it->kind() == torch::jit::aten::Int && it->inputs().size() == 1 &&
+        it->input()->type()->isSubtypeOf(c10::TensorType::get())) {
+      LOG_DEBUG("Found an aten::Int.Tensor case, attempting to resolve input scalars.");
+
+      // If the parent node's schema is of a supported type, trace back through the graph
+      if (AtenIntReplacementNodeKinds.find(it->input()->node()->kind()) != AtenIntReplacementNodeKinds.end()) {
+        LOG_DEBUG(
+            "Tracing parent node " << it->input()->node()->kind().toQualString()
+                                   << " to eliminate 0D Tensors for aten::Int.Tensor case.");
+        auto scalar_input_value = TracebackAndEliminate0DTensors(it->input()->node());
+        if (scalar_input_value) {
+          it->output()->replaceAllUsesWith(scalar_input_value);
+          LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case succeeded.");
+        } else {
+          LOG_DEBUG("Tracing parent nodes for aten::Int.Tensor case failed.");
+        }
+      } else {
+        LOG_DEBUG(
+            "Parent node schema " << it->input()->node()->kind().toQualString()
+                                  << " is currently unsupported for aten::Int.Tensor case.");
+      }
+    }
+  }
+
+  // Clean up remnant operators in graph
+  torch::jit::EliminateDeadCode(g);
+  LOG_GRAPH("Post removing aten::Int.Tensor operations: " << *g);
+}
 } // namespace passes
 } // namespace lowering
 } // namespace core
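
Since the pass finishes with EliminateDeadCode, a quick way to confirm a graph was fully resolved is to scan for surviving aten::Int nodes. A minimal helper sketch (hypothetical, not part of the commit):

#include <memory>

#include "torch/csrc/jit/ir/ir.h"

// Returns true if any aten::Int node remains in the graph's top-level block
bool HasAtenInt(const std::shared_ptr<torch::jit::Graph>& g) {
  for (auto node : g->block()->nodes()) {
    if (node->kind() == torch::jit::aten::Int) {
      return true;
    }
  }
  return false;
}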

tests/core/lowering/test_remove_unnecessary_casts.cpp (+152)

@@ -437,3 +437,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
 }
+
+TEST(LoweringPasses, RemoveAtenIntTensorValuesAgree) {
+  std::string source_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %11: int = prim::Constant[value=7]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %7: Tensor = aten::mul(%3, %4)
+      %8: Tensor = aten::mul(%7, %1)
+      %50: int = aten::Int(%8)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %1: int = prim::Constant[value=7]()
+      %4: int = aten::floordiv(%1, %0)
+      %7: int = aten::mul(%0, %4)
+      %40: int = aten::mul(%7, %1)
+      %4: Tensor = prim::NumToTensor(%40)
+      return (%4))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+
+  // Ensure the lowering pass transforms the first graph into the second
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph_no_inputs, sg.get());
+
+  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph_no_inputs, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveAtenIntSizeTensorValuesAgree) {
+  std::string source_graph_no_inputs = R"IR(
+    graph(%x.0: Tensor):
+      %10: int = prim::Constant[value=0]()
+      %100: int = aten::size(%x.0, %10)
+      %0: Tensor = prim::NumToTensor(%100)
+      %11: int = prim::Constant[value=9]()
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::floor_divide(%1, %0)
+      %7: Tensor = aten::mul(%0, %4)
+      %8: Tensor = aten::mul(%7, %1)
+      %50: int = aten::Int(%8)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph(%x.0: Tensor):
+      %10: int = prim::Constant[value=0]()
+      %0: int = aten::size(%x.0, %10)
+      %1: int = prim::Constant[value=9]()
+      %4: int = aten::floordiv(%1, %0)
+      %7: int = aten::mul(%0, %4)
+      %40: int = aten::mul(%7, %1)
+      %4: Tensor = prim::NumToTensor(%40)
+      return (%4))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  auto in_0 = at::rand({2, 3, 5, 5}, {at::kCUDA});
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {in_0});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {in_0});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+
+  // Ensure the lowering pass transforms the first graph into the second
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph_no_inputs, sg.get());
+
+  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph_no_inputs, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+TEST(LoweringPasses, RemoveAtenIntConstTensorValuesAgree) {
+  // Ensure the lowering pass transforms the first graph into the second
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[8]]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::floor_divide(%3, %1)
+      %5: int = aten::Int(%4)
+      return (%5))IR";
+
+  std::string target_graph = R"IR(
+    graph(%0 : int):
+      %1 : Tensor = prim::Constant[value=[8]]()
+      %2 : int = prim::Constant[value=8]()
+      %3 : int = aten::floordiv(%0, %2)
+      return (%3))IR";
+
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+
+  // Manually enter 0D tensor constant for the source graph
+  auto first_op_sg = *(sg->block()->nodes().begin());
+  torch::jit::Value* r_sg = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op_sg->scope());
+  r_sg->copyMetadata(first_op_sg->output());
+  r_sg->setType(c10::TensorType::get());
+  first_op_sg->output()->replaceAllUsesWith(r_sg);
+  first_op_sg->destroy();
+
+  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);
+  torch::jit::ConstantPooling(sg);
+  sg = torch::jit::Canonicalize(sg, false);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  // Manually enter 0D tensor constant for the target graph
+  auto first_op_tg = *(tg->block()->nodes().begin());
+  torch::jit::Value* r_tg = tg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op_tg->scope());
+  r_tg->copyMetadata(first_op_tg->output());
+  r_tg->setType(c10::TensorType::get());
+  first_op_tg->output()->replaceAllUsesWith(r_tg);
+  first_op_tg->destroy();
+
+  torch::jit::ConstantPooling(tg);
+  tg = torch::jit::Canonicalize(tg, false);
+
+  // Validate identical graphs after pooling constants and canonicalizing
+  ASSERT_TRUE((tg->toString() == sg->toString()));
+}
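
The three tests above cover the supported parent operators (aten::mul, aten::floor_divide) and all three scalar base cases. A complementary negative test, sketched here under the same test utilities (hypothetical, not part of the commit), would confirm that an aten::Int fed by an unsupported parent such as aten::add is left in place:

TEST(LoweringPasses, ReplaceAtenIntUnsupportedParentLeftIntact) {
  std::string source_graph = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %1: Tensor = prim::NumToTensor(%0)
      %2: Tensor = aten::add(%1, %1, %0)
      %3: int = aten::Int(%2)
      %4: Tensor = prim::NumToTensor(%3)
      return (%4))IR";

  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, sg.get());
  torch_tensorrt::core::lowering::passes::ReplaceAtenInt(sg);

  // aten::add is not in AtenIntReplacementNodeKinds, so the aten::Int
  // node should survive the pass
  bool found_aten_int = false;
  for (auto node : sg->block()->nodes()) {
    found_aten_int |= (node->kind() == torch::jit::aten::Int);
  }
  ASSERT_TRUE(found_aten_int);
}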
