feat: support int64 <=> int32 auto conversion #1407

Merged · 5 commits · Nov 14, 2022

1 change: 1 addition & 0 deletions core/partitioning/partitioning.cpp
@@ -453,6 +453,7 @@ void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
  registerSegmentsOutputs(ctx, block);

  // run shape analysis on each segmented block
  LOG_DEBUG("Running shape analysis for segmented graphs");
  runShapeAnalysis(ctx, block, example_tensor_map);
}
}
91 changes: 88 additions & 3 deletions core/partitioning/shape_analysis.cpp
@@ -1,3 +1,4 @@
#include <queue>
#include "ATen/ATen.h"
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
@@ -57,6 +58,61 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
  return ivalue_map;
}

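// Breadth-first search upstream from `val` through producer nodes for an aten::to node that carries
// an explicit dtype argument (its second or third input is integer-typed). prim::Constant producers
// are not expanded and already-visited nodes are skipped, so the walk terminates. Returns nullptr
// when no such cast exists upstream.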
torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
  std::queue<torch::jit::Value*> q;
  q.push(val);
  std::unordered_set<torch::jit::Node*> visited;
  while (!q.empty()) {
    auto cur_val = q.front();
    q.pop();
    auto node = cur_val->node();
    if ((node->kind().toQualString() == std::string("aten::to")) &&
        ((node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType) ||
         (node->inputs()[2]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType))) {
      return node;
    }
    if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
      visited.insert(node);
      for (auto input : node->inputs()) {
        q.push(input);
      }
    }
  }
  return nullptr;
}

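// Builds an aten::to node that casts the segment boundary value at `index` (an input when is_input
// is true, otherwise an output). If an upstream aten::to with a dtype argument exists it is cloned,
// with the dtype rewritten to Int32 (ScalarType 3) for outputs; otherwise a fresh aten::to is
// created that casts inputs to Int64 (ScalarType 4) and outputs to Int32. The caller is responsible
// for inserting the returned node into the segment's graph.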
torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input) {
  auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
  auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
  torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
  auto g = seg_block.g();
  // if we can find an upstream aten::to node, we use its parameters to create the new cast node
  if (cast_node) {
    std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
    value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
    if (!is_input) {
      // if this value is an output, we need to cast it to int32 (ScalarType constant 3 == at::kInt)
      auto const_val = g->insertConstant(3);
      if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
        value_map.insert({cast_node->inputs()[2], const_val});
      } else {
        value_map.insert({cast_node->inputs()[1], const_val});
      }
    }
    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, g, value_map); };
    cast_node = g->createClone(cast_node, env);
    // auto cast_node = g->prependNode(g->createClone(cast_node, env));
  } else {
    // if there is no explicit aten::to cast upstream, we need to create a new node
    // (dtype constant 4 == at::kLong for inputs, 3 == at::kInt for outputs)
    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
    auto const_zero = g->insertConstant(0);
    const_zero->setType(torch::jit::BoolType::get());
    auto none_val = g->insertNode(g->createNone())->output();
    cast_node = g->create(torch::jit::aten::to, {cast_subgraph_value, const_type, const_zero, const_zero, none_val});
  }
  return cast_node;
}

void getSegmentsOutputByRunning(
    SegmentedBlock& seg_block,
    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
@@ -142,16 +198,45 @@ void getSegmentsOutputByRunning(
    ivalues_maps[output] = jit_results[idx++];
  }

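  // With truncate_long_and_double enabled the TensorRT segments only traffic in Int32, so a Torch
  // fallback segment can receive Int32 tensors where its ops expect Int64, and can hand Int64
  // tensors to downstream segments that expect Int32. The two loops below patch the boundaries of
  // such segments with aten::to casts built by createCastNode().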
  // auto int64 <=> int32 conversion
  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
    // First, check if there are any Int64 inputs
    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
        if (t == at::kLong) {
          // this value is Int64 in the original graph, so add a cast to Int64 at the head of the segment
          auto cast_node = createCastNode(seg_block, i, true);
          seg_block.g()->prependNode(cast_node);
          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
        }
      }
    }
    // Then, check for Int64 outputs and cast them back down to Int32 at the tail of the segment
    for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
      if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
        auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
        if (t == at::kLong) {
          auto cast_node = createCastNode(seg_block, i, false);
          seg_block.g()->appendNode(cast_node);
          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
        }
      }
    }
  }

  // set input shape for each segmented block so we will use it in conversion process
  std::vector<ir::Input> input_shapes;
  std::vector<at::ScalarType> input_types;
-  for (auto& i : seg_block.raw_inputs()) {
-    if (ivalues_maps[i].isTensor()) {
+  for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
      // set the input_shape and data_type
      // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
      // shape inference
-      auto cur_ivalue = ivalues_maps[i];
+      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
      at::ScalarType t = cur_ivalue.toTensor().scalar_type();

      if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
        TORCHTRT_THROW_ERROR(
            "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
5 changes: 5 additions & 0 deletions tests/core/partitioning/BUILD
@@ -40,6 +40,10 @@ partitioning_test(
    name = "test_resolve_nontensor_inputs",
)

partitioning_test(
    name = "test_type_auto_conversion",
)

cc_test(
    name = "test_loading_model",
    srcs = ["test_loading_model.cpp"],
@@ -112,5 +116,6 @@ test_suite(
        ":test_shape_analysis",
        ":test_stitched_graph",
        ":test_tensorrt_conversion",
        ":test_type_auto_conversion",
    ],
)
106 changes: 106 additions & 0 deletions tests/core/partitioning/test_type_auto_conversion.cpp
@@ -0,0 +1,106 @@
#include <string>
#include "core/partitioning/partitioning.h"
#include "core/util/trt_util.h"
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/script.h"

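// Counts the aten::to nodes in a segmented block and compares the count against the number of casts
// the partitioner is expected to have inserted.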
bool checkInsertedCastNodeNumber(torch_tensorrt::core::partitioning::SegmentedBlock& seg_block, int target_count) {
  int64_t cnt = 0;
  for (auto node : seg_block.nodes()) {
    if (node->kind().toQualString() == std::string("aten::to")) {
      cnt++;
    }
  }
  std::cout << "Found count of " << cnt << " inserted aten::to nodes, (looking for " << target_count
            << " aten::to nodes)" << std::endl;

  return target_count == cnt;
}

TEST(Partitioning, ExplicitNodeAutoConversionCorrectly) {
  const auto graph = R"IR(
        graph(%0 : Tensor,
              %1 : Tensor):
          %2 : int = prim::Constant[value=4]()
          %3 : bool = prim::Constant[value=0]()
          %4 : NoneType = prim::Constant()
          %5 : int = prim::Constant[value=1]()
          %7 : Tensor = aten::to(%1, %2, %3, %3, %4)
          %8 : Tensor = aten::mul(%0, %0)
          %9 : Tensor = aten::scatter(%8, %5, %7, %5)
          %10 : Tensor = aten::scatter(%7, %5, %7, %5)
          %12 : Tensor = aten::add(%10, %10, %5)
          return (%9, %12))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get(), true);

  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
  partitioning_info.enabled = true;
  partitioning_info.forced_fallback_operators = {"aten::scatter"};
  partitioning_info.truncate_long_and_double = true;
  std::vector<torch_tensorrt::core::ir::Input> inputs;
  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));

  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
  inputs_map.insert({g->inputs()[0], {inputs[0]}});
  input_types.insert({g->inputs()[0], {{at::kFloat}}});
  inputs_map.insert({g->inputs()[1], {inputs[1]}});
  input_types.insert({g->inputs()[1], {{at::kInt}}});

  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);

  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;

  for (auto& seg_block : segmented_blocks) {
    LOG_DEBUG(seg_block << " cur seg block");
  }
  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
}

TEST(Partitioning, ImplicitAutoConversionCorrectly) {
  const auto graph = R"IR(
        graph(%0 : Tensor):
          %2 : int = prim::Constant[value=0]()
          %4 : int = aten::size(%0, %2)
          %6 : Tensor = prim::NumToTensor(%4)
          %2 : int = prim::Constant[value=5]()
          %7 : int[] = prim::ListConstruct(%2, %2)
          %8 : bool = prim::Constant[value=0]()
          %9 : Tensor = aten::expand(%6, %7, %8)

          %10 : Tensor = aten::mul(%9, %9)
          return (%10))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get(), true);

  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
  partitioning_info.enabled = true;
  partitioning_info.forced_fallback_operators = {"aten::expand"};
  partitioning_info.truncate_long_and_double = true;
  std::vector<torch_tensorrt::core::ir::Input> inputs;

  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));

  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
  inputs_map.insert({g->inputs()[0], {inputs[0]}});
  input_types.insert({g->inputs()[0], {{at::kFloat}}});

  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);

  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;

  for (auto& seg_block : segmented_blocks) {
    LOG_DEBUG(seg_block << " cur seg block");
  }
  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
}
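
For reference, the inserted casts only come into play when truncation is requested at compile time. Below is a minimal sketch of how a user would opt in through the TorchScript C++ frontend; it assumes the torch_tensorrt::ts C++ API of this release, and "module.ts" plus the {5, 5} input shape are placeholder values, not part of this PR.

#include "torch/script.h"
#include "torch_tensorrt/torch_tensorrt.h"

int main() {
  // a scripted module whose Torch fallback segments consume or produce int64 tensors
  auto mod = torch::jit::load("module.ts");

  torch_tensorrt::ts::CompileSpec spec({torch_tensorrt::Input(std::vector<int64_t>{5, 5})});
  // enabling truncation also enables the automatic int64 <=> int32 casts added by this PR
  spec.truncate_long_and_double = true;

  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
  trt_mod.save("module_trt.ts");
  return 0;
}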