feat: support int64 <=> int32 auto conversion #1407
@@ -1,3 +1,4 @@
+#include <queue>
 #include "ATen/ATen.h"
 #include "torch/csrc/jit/api/module.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
@@ -57,6 +58,62 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
   return ivalue_map;
 }

+torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
+  std::queue<torch::jit::Value*> q;
+  q.push(val);
+  std::unordered_set<torch::jit::Node*> visited;
+  while (!q.empty()) {
+    auto cur_val = q.front();
+    q.pop();
+    auto node = cur_val->node();
+    if (node->kind().toQualString() == std::string("aten::to")) {
Review thread:

gs-olive: May need an additional check to ensure that the cast targets an integer type. A check could be something like an additional

    (node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType) ||
    (node->inputs()[2]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType)

bowang007: Hey @gs-olive Any reproducer for this?

gs-olive: Hi @bowang007 - as an update, while this is no longer throwing an error on my end, my thought was that we do need this check you have, but maybe it should be more stringent - something like:

    if ((node->kind().toQualString() == std::string("aten::to")) &&
        ((node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType) ||
         (node->inputs()[2]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType))) {

This is because, in the case where the …

gs-olive: @bowang007 Please let me know what you think about the comment in the thread above.

bowang007: @gs-olive I got your point now, let me update this part.
+      return node;
+    }
+    if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
+      visited.insert(node);
+      for (auto input : node->inputs()) {
+        q.push(input);
+      }
+    }
+  }
+  return nullptr;
+}
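Note: for intuition, a sketch of this traversal on the kind of IR the tests below exercise (value names borrowed from the first test). Starting a breadth-first walk from a segment-boundary value such as %10, the scatter node is popped first and its inputs are enqueued; then %7 is popped and its producer, the aten::to node, is returned. prim::Constant producers are never expanded, so the walk stays within meaningful compute nodes.

    %7 : Tensor = aten::to(%1, %2, %3, %3, %4)
    %10 : Tensor = aten::scatter(%7, %5, %7, %5)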
+
+torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input) {
+  torch::jit::Node* cast_node = getUpstreamCastNode(seg_block.raw_inputs()[index]);
Review comment: Add an if/else here to use …
+
+  auto g = seg_block.g();
+  // if we can find an upstream aten::to node, we use its parameters to create the new cast node
+  if (cast_node) {
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
+    value_map.insert({cast_node->inputs()[0], g->inputs()[index]});
Review comment: May need an if/else here to check if the insert should be …
+
+    if (!is_input) {
+      // if this value is an output, we need to cast it to int32
+      auto const_val = g->insertConstant(3);
+      value_map.insert({cast_node->inputs()[1], const_val});
Review comment: Throws an error when the upstream aten::to node takes a Device as its second input. The check here could be something like an if/else checking the debugName at the second index, as in:

    if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
      value_map.insert({cast_node->inputs()[2], const_val});
    } else {
      value_map.insert({cast_node->inputs()[1], const_val});
    }
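Note: for context on the overload issue raised above, these are the two aten::to schemas involved (abbreviated, defaults omitted). In the dtype overload the target ScalarType sits at input index 1, while the device overload shifts it to index 2, which is why the branch on DeviceObjType is needed.

    aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking, bool copy, MemoryFormat? memory_format) -> Tensor
    aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking, bool copy, MemoryFormat? memory_format) -> Tensor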
+
+    }
+    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, g, value_map); };
+    cast_node = g->createClone(cast_node, env);
+    // auto cast_node = g->prependNode(g->createClone(cast_node, env));
+  } else {
+    // if there is no explicit cast aten::to operation, we need to create a node
+    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_zero = g->insertConstant(0);
+    const_zero->setType(torch::jit::BoolType::get());
+    auto none_val = g->insertNode(g->createNone())->output();
+    cast_node = g->create(torch::jit::aten::to, {g->inputs()[index], const_type, const_zero, const_zero, none_val});
Review comment: Add an if/else here to use …
+
+    // auto cast_node = g->prependNode(g->create(torch::jit::aten::to, {g->inputs()[i], const_type, const_zero,
+    // const_zero, none_val})); seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node,
+    // cast_node->outputs()[0]); LOG_DEBUG(seg_block << " in shape analysis");
Review comment: Consider removing commented code if not needed.
+
+  }
+  if (is_input) {
+    g->prependNode(cast_node);
+  } else {
+    g->appendNode(cast_node);
+  }
+  return cast_node;
+}
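Note: in IR terms, the fallback branch above emits roughly the following (a sketch; %input stands for g->inputs()[index], and the ScalarType constant is 4, i.e. at::kLong, when casting a segment input, or 3, i.e. at::kInt, when casting a segment output):

    %dtype : int = prim::Constant[value=4]()
    %false : bool = prim::Constant[value=0]()
    %none : NoneType = prim::Constant()
    %cast : Tensor = aten::to(%input, %dtype, %false, %false, %none)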
+
 void getSegmentsOutputByRunning(
     SegmentedBlock& seg_block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
@@ -142,16 +199,43 @@ void getSegmentsOutputByRunning(
     ivalues_maps[output] = jit_results[idx++];
   }
+
+  // auto int64 <=> int32 conversion
+  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
+    // First, check if there is an Int64 input
+    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong) {
+          // we add a cast operation to cast the type to Int64
+          auto cast_node = createCastNode(seg_block, i, true);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        }
+      }
+    }
+    for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong) {
+          auto cast_node = createCastNode(seg_block, i, false);
+          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
+        }
+      }
+    }
+  }
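Note: the net effect, assuming the usual Torch-TensorRT behavior where TensorRT engines produce Int32 for values the original graph typed as Int64, is that each Torch fallback segment gets bracketed by casts: inputs are widened back to Int64 before the fallback ops run, and Int64 outputs are narrowed to Int32 for downstream segments. A hypothetical sketch of a transformed fallback segment (names illustrative):

    graph(%x : Tensor):                         # arrives as Int32 from a TensorRT segment
      %c4 : int = prim::Constant[value=4]()
      %x64 : Tensor = aten::to(%x, %c4, ...)    # prepended cast to at::kLong
      %y : Tensor = aten::scatter(%x64, ...)    # fallback op sees the Int64 it was traced with
      %c3 : int = prim::Constant[value=3]()
      %y32 : Tensor = aten::to(%y, %c3, ...)    # appended cast to at::kInt
      return (%y32)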
+
   // set input shape for each segmented block so we will use it in conversion process
   std::vector<ir::Input> input_shapes;
   std::vector<at::ScalarType> input_types;
-  for (auto& i : seg_block.raw_inputs()) {
-    if (ivalues_maps[i].isTensor()) {
+  for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
       // set the input_shape and data_type
       // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
       // shape inference
-      auto cur_ivalue = ivalues_maps[i];
+      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();

       if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
         TORCHTRT_THROW_ERROR(
             "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
@@ -0,0 +1,109 @@
+#include <string>
+#include "core/partitioning/partitioning.h"
+#include "core/util/trt_util.h"
+#include "gtest/gtest.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/script.h"
+
+bool checkInsertedCastNodeNumber(torch_tensorrt::core::partitioning::SegmentedBlock& seg_block, int target_count) {
+  int64_t cnt = 0;
+  for (auto node : seg_block.nodes()) {
+    if (node->kind().toQualString() == std::string("aten::to")) {
+      cnt++;
+    }
+  }
+  std::cout << "Found " << cnt << " inserted aten::to nodes (looking for " << target_count << " aten::to nodes)"
+            << std::endl;
+
+  return target_count == cnt;
+}
+
+TEST(Partitioning, ExplicitNodeAutoConversionCorrectly) {
+  const auto graph = R"IR(
+        graph(%0 : Tensor,
+              %1 : Tensor):
+          %2 : int = prim::Constant[value=4]()
+          %3 : bool = prim::Constant[value=0]()
+          %4 : NoneType = prim::Constant()
+          %5 : int = prim::Constant[value=1]()
+          %7 : Tensor = aten::to(%1, %2, %3, %3, %4)
+          %8 : Tensor = aten::mul(%0, %0)
+          %9 : Tensor = aten::scatter(%8, %5, %7, %5)
+          %10 : Tensor = aten::scatter(%7, %5, %7, %5)
+          %12 : Tensor = aten::add(%10, %10, %5)
+          return (%9, %12))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.forced_fallback_operators = {"aten::scatter"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+  inputs_map.insert({g->inputs()[1], {inputs[1]}});
+  input_types.insert({g->inputs()[1], {{at::kInt}}});
+
+  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
+}
+
+TEST(Partitioning, ImplicitAutoConversionCorrectly) {
+  const auto graph = R"IR(
+        graph(%0 : Tensor):
+          %2 : int = prim::Constant[value=0]()
+          %4 : int = aten::size(%0, %2)
+          %6 : Tensor = prim::NumToTensor(%4)
+          %5 : int = prim::Constant[value=5]()
+          %7 : int[] = prim::ListConstruct(%5, %5)
+          %8 : bool = prim::Constant[value=0]()
+          %9 : Tensor = aten::expand(%6, %7, %8)
+
+          %10 : Tensor = aten::mul(%9, %9)
+          return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.forced_fallback_operators = {"aten::expand"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+
+  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
+}