Skip to content

Commit e08e78e

Browse files
gs-olivebowang007
authored and committed
fix: Repair invalid schema arising from lowering pass (#1786)
1 parent 919d0f1 commit e08e78e

File tree

2 files changed

+193
-0
lines changed

2 files changed

+193
-0
lines changed

core/lowering/passes/remove_unnecessary_casts.cpp

+42
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,48 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
138138
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
139139
user->destroy();
140140
break;
141+
case c10::aten::div:
142+
// If the first two entries to aten::div are non-Tensors,
143+
// there cannot be a rounding mode specified (3rd entry)
144+
if (!user->inputs()[0]->type()->isSubtypeOf(c10::TensorType::get()) &&
145+
!user->inputs()[1]->type()->isSubtypeOf(c10::TensorType::get()) &&
146+
user->inputs().size() == 3 &&
147+
user->inputs()[2]->type()->isSubtypeOf(c10::StringType::get()) &&
148+
torch::jit::toIValue(user->inputs()[2]).has_value()) {
149+
// Select the first 2 entries of the inputs, corresponding to the values
150+
auto div_args = user->inputs().slice(0, 2);
151+
152+
// Depending on the rounding mode, create the appropriate nodes
153+
if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "trunc") {
154+
// Truncate case (round result towards 0)
155+
torch::jit::Node* new_node_div;
156+
// Create node which simply divides the two entries
157+
new_node_div = g->create(c10::aten::div, div_args, 1);
158+
new_node_div->insertAfter(user);
159+
new_node_div->outputs()[0]->setType(c10::FloatType::get());
160+
161+
// Create node which casts the result to an integer, effectively truncating
162+
new_node = g->create(c10::aten::Int, new_node_div->outputs(), 1);
163+
new_node->insertAfter(new_node_div);
164+
new_node->outputs()[0]->setType(c10::IntType::get());
165+
166+
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
167+
user->destroy();
168+
break;
169+
170+
} else if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "floor") {
171+
// Floor case (round result down)
172+
// Replace aten::div with aten::floordiv
173+
new_node = g->create(c10::aten::floordiv, div_args, 1);
174+
new_node->insertAfter(user);
175+
new_node->outputs()[0]->setType(c10::IntType::get());
176+
177+
user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
178+
user->destroy();
179+
break;
180+
}
181+
}
182+
141183
default:
142184
new_node = g->create(user->kind(), user->inputs(), 1);
143185
new_node->insertAfter(user);

tests/core/lowering/test_remove_unnecessary_casts.cpp

+151
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "tests/util/util.h"
66
#include "torch/csrc/jit/ir/irparser.h"
77
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
#include "torch/csrc/jit/passes/canonicalize.h"
9+
#include "torch/csrc/jit/passes/constant_pooling.h"
810

911
TEST(LoweringPasses, RemoveUnnecessaryCastIntCorrectly) {
1012
std::string source_graph = R"IR(
@@ -255,6 +257,155 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntValuesAgree) {
255257
ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));
256258
}
257259

260+
TEST(LoweringPasses, RemoveSingleUse0DTensorsDivTruncIntValuesAgree) {
  // Part 1: Ensure the source and target graphs have equivalent outputs
  // (Source and Target are computing equivalent values).
  // Source: aten::div with rounding_mode="trunc" on 0-D tensors, then aten::Int.
  // Target: scalar aten::div followed by aten::Int (truncation toward zero).
  std::string source_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %11: int = prim::Constant[value=-3]()
      %234 : str = prim::Constant[value="trunc"]()
      %3: Tensor = prim::NumToTensor(%0)
      %1: Tensor = prim::NumToTensor(%11)
      %4: Tensor = aten::div(%1, %3, %234)
      %50: int = aten::Int(%4)
      %5: Tensor = prim::NumToTensor(%50)
      return (%5))IR";
  std::string target_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %1: int = prim::Constant[value=-3]()
      %40: float = aten::div(%1, %0)
      %41: int = aten::Int(%40)
      %4: Tensor = prim::NumToTensor(%41)
      return (%4))IR";

  auto g_in = std::make_shared<torch::jit::Graph>();
  auto g_out = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
  torch::jit::parseIR(target_graph_no_inputs, g_out.get());

  // Evaluate both graphs in the JIT interpreter and require exact agreement.
  // Negative operand (-3) is deliberate: "trunc" and "floor" differ there.
  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});

  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));

  // Part 2: Ensure the lowering pass transforms the first graph into the second
  std::string source_graph = R"IR(
    graph(%0: int):
      %1: Tensor = prim::Constant[value=[8]]()
      %3: Tensor = prim::NumToTensor(%0)
      %234: str = prim::Constant[value="trunc"]()
      %4: Tensor = aten::div(%3, %1, %234)
      %5: int = aten::Int(%4)
      return (%5))IR";

  std::string target_graph = R"IR(
    graph(%0 : int):
      %1 : str = prim::Constant[value="trunc"]()
      %2 : int = prim::Constant[value=8]()
      %3 : float = aten::div(%0, %2)
      %4 : int = aten::Int(%3)
      return (%4))IR";

  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, &*sg);

  // The IR parser cannot express a 0-D tensor constant directly, so replace
  // the placeholder constant node with a real scalar-tensor constant (8).
  auto first_op = *(sg->block()->nodes().begin());
  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
  r->copyMetadata(first_op->output());
  r->setType(c10::TensorType::get());
  first_op->output()->replaceAllUsesWith(r);
  first_op->destroy();

  // Run the pass under test, then normalize both graphs before comparison.
  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
  torch::jit::ConstantPooling(sg);
  sg = torch::jit::Canonicalize(sg, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, &*tg);
  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == sg->toString()));
}
334+
335+
TEST(LoweringPasses, RemoveSingleUse0DTensorsDivFloorIntValuesAgree) {
  // Part 1: Ensure the source and target graphs have equivalent outputs
  // (Source and Target are computing equivalent values).
  // Source: aten::div with rounding_mode="floor" on 0-D tensors, then aten::Int.
  // Target: scalar aten::floordiv (rounds toward negative infinity).
  std::string source_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %11: int = prim::Constant[value=-3]()
      %234 : str = prim::Constant[value="floor"]()
      %3: Tensor = prim::NumToTensor(%0)
      %1: Tensor = prim::NumToTensor(%11)
      %4: Tensor = aten::div(%1, %3, %234)
      %50: int = aten::Int(%4)
      %5: Tensor = prim::NumToTensor(%50)
      return (%5))IR";
  std::string target_graph_no_inputs = R"IR(
    graph():
      %0: int = prim::Constant[value=2]()
      %1: int = prim::Constant[value=-3]()
      %40: int = aten::floordiv(%1, %0)
      %41: int = aten::Int(%40)
      %4: Tensor = prim::NumToTensor(%41)
      return (%4))IR";

  auto g_in = std::make_shared<torch::jit::Graph>();
  auto g_out = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
  torch::jit::parseIR(target_graph_no_inputs, g_out.get());

  // Evaluate both graphs in the JIT interpreter and require exact agreement.
  // Negative operand (-3) is deliberate: "floor" and "trunc" differ there.
  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});

  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));

  // Part 2: Ensure the lowering pass transforms the first graph into the second
  std::string source_graph = R"IR(
    graph(%0: int):
      %1: Tensor = prim::Constant[value=[8]]()
      %3: Tensor = prim::NumToTensor(%0)
      %234: str = prim::Constant[value="floor"]()
      %4: Tensor = aten::div(%3, %1, %234)
      %5: int = aten::Int(%4)
      return (%5))IR";

  std::string target_graph = R"IR(
    graph(%0 : int):
      %1 : str = prim::Constant[value="floor"]()
      %2 : int = prim::Constant[value=8]()
      %3 : int = aten::floordiv(%0, %2)
      return (%3))IR";

  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, &*sg);

  // The IR parser cannot express a 0-D tensor constant directly, so replace
  // the placeholder constant node with a real scalar-tensor constant (8).
  auto first_op = *(sg->block()->nodes().begin());
  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8), c10::nullopt, first_op->scope());
  r->copyMetadata(first_op->output());
  r->setType(c10::TensorType::get());
  first_op->output()->replaceAllUsesWith(r);
  first_op->destroy();

  // Run the pass under test, then normalize both graphs before comparison.
  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
  torch::jit::ConstantPooling(sg);
  sg = torch::jit::Canonicalize(sg, false);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, &*tg);
  torch::jit::ConstantPooling(tg);
  tg = torch::jit::Canonicalize(tg, false);

  // Validate identical graphs after pooling constants and canonicalizing
  ASSERT_TRUE((tg->toString() == sg->toString()));
}
408+
258409
TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
259410
std::string source_graph_no_inputs = R"IR(
260411
graph():

0 commit comments

Comments
 (0)