feat: support int64 <=> int32 auto conversion #1407


Merged · 5 commits · Nov 14, 2022
Changes from 3 commits
1 change: 1 addition & 0 deletions core/partitioning/partitioning.cpp
@@ -453,6 +453,7 @@ void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
    registerSegmentsOutputs(ctx, block);

    // run shape analysis on each segmented block
    LOG_DEBUG("Running shape analysis for segmented graphs");
    runShapeAnalysis(ctx, block, example_tensor_map);
  }
}
90 changes: 87 additions & 3 deletions core/partitioning/shape_analysis.cpp
@@ -1,3 +1,4 @@
#include <queue>
#include "ATen/ATen.h"
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
@@ -57,6 +58,62 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomInputs(
  return ivalue_map;
}

torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
  std::queue<torch::jit::Value*> q;
  q.push(val);
  std::unordered_set<torch::jit::Node*> visited;
  while (!q.empty()) {
    auto cur_val = q.front();
    q.pop();
    auto node = cur_val->node();
    if (node->kind().toQualString() == std::string("aten::to")) {
Collaborator:

May need an additional check to ensure that the aten::to schema is valid for dtype insertion, as some of these schemas do not take an integer dtype at all, for example:

  • aten::to(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(b|a)
  • aten::to(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
  • aten::to(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)

A check could be something like an additional && with

(node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType) ||
(node->inputs()[2]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType)

Collaborator (author):

Hey @gs-olive, any reproducer for this?
What I'm not sure about: in getUpstreamCastNode(), when we pass in an int32 value, will the first cast node found be the one that casts this value to int64? If that's the case, then we don't need this check.
In other words, is it possible that the first cast node involving the passed value casts some other value? And if the first cast node is not the one that casts to int64, will the second cast node be what we want?

Collaborator (gs-olive, Nov 8, 2022):

Hi @bowang007 - as an update, while this is no longer throwing an error on my end, my thought was that we do need this check you have, but maybe it should be more stringent - something like:

    if ((node->kind().toQualString() == std::string("aten::to")) &&
        ((node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType) ||
        (node->inputs()[2]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType))) {

This is because, when the aten::to node matches the second schema in my comment above, inserting a constant like 3 will cause the model to fail, since that schema needs a ScalarType and not an int. I don't have a specific model that reproduces an error, and I don't think I encountered one while testing; I just think it is generally safer to be strict about the kind of upstream cast node we repurpose for the Int32 recast. Specifically, if we are unsure whether a node's schema is valid for repurposing, we should choose the safer option and manually insert an Int32 cast node, as you do in createCastNode.

Collaborator (gs-olive, Nov 8, 2022):

@bowang007 Please let me know what you think about the comment in the thread above:
#1407 (comment)

Collaborator (author):

@gs-olive I got your point now, let me update this part.

      return node;
    }
    if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
      visited.insert(node);
      for (auto input : node->inputs()) {
        q.push(input);
      }
    }
  }
  return nullptr;
}
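The traversal above can be sketched on a toy graph, independent of the TorchScript IR. The following is a simplified model (hypothetical ToyNode type, and a breadth-first search over producing nodes rather than values) of how getUpstreamCastNode walks upstream from a value until it hits an aten::to node or exhausts the graph:

```cpp
#include <cassert>
#include <queue>
#include <string>
#include <unordered_set>
#include <vector>

// Toy stand-in for a torch::jit::Node: just a kind string and input edges.
// This models only what the upstream search needs from the real IR.
struct ToyNode {
  std::string kind;
  std::vector<ToyNode*> inputs;
};

// Breadth-first search from a node toward graph inputs, returning the first
// "aten::to" node found, or nullptr if none exists. Constants are treated as
// dead ends, mirroring the prim::Constant check in the real function.
ToyNode* findUpstreamCast(ToyNode* start) {
  std::queue<ToyNode*> q;
  std::unordered_set<ToyNode*> visited;
  q.push(start);
  while (!q.empty()) {
    ToyNode* node = q.front();
    q.pop();
    if (node->kind == "aten::to") {
      return node;
    }
    if (node->kind != "prim::Constant" && !visited.count(node)) {
      visited.insert(node);
      for (ToyNode* in : node->inputs) {
        q.push(in);
      }
    }
  }
  return nullptr;
}
```

Because the search is breadth-first, the nearest upstream cast wins; the review thread below discusses whether "nearest" is always the cast we actually want to reuse.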

torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input) {
  torch::jit::Node* cast_node = getUpstreamCastNode(seg_block.raw_inputs()[index]);
Collaborator:

Add an if/else here to use raw_inputs() if is_input = true otherwise use raw_outputs()

  auto g = seg_block.g();
  // if we can find an upstream aten::to node, we use its parameters to create the new cast node
  if (cast_node) {
    std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
    value_map.insert({cast_node->inputs()[0], g->inputs()[index]});
Collaborator:

May need an if/else here to check if insert should be g->inputs()[index] or g->outputs()[index].

    if (!is_input) {
      // if this value is an output, we need to cast it to int32
      auto const_val = g->insertConstant(3);
      value_map.insert({cast_node->inputs()[1], const_val});
Collaborator:

Throws an error when the upstream aten::to node does not have dtype as its second argument. For example, the schema aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(b|a) has Device as its second value, and this insertion transforms it into an invalid schema. We need to differentiate between schemas to ensure the dtype is placed in the right position. It seems that valid schemas for aten::to have dtype as either the second or third argument, or not at all. I believe there should be a check in getUpstreamCastNode to see if dtype is among the arguments, and then a second check here to see whether it is the second or third argument in the schema.

The check here could be something like an if/else on the type of the second input, as in:

if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
  value_map.insert({cast_node->inputs()[2], const_val});
} else {
  value_map.insert({cast_node->inputs()[1], const_val});
}
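Taken together, the schema observations in this thread amount to a small dispatch on where the dtype constant may go. The following is a self-contained sketch under those assumptions (dtypePosition is a hypothetical helper, not part of the PR, and TypeKind is a toy stand-in for torch::jit::TypeKind):

```cpp
#include <cassert>

// Toy subset of the type kinds the review comments inspect.
enum class TypeKind { IntType, DeviceObjType, BoolType, TensorType, Other };

// Returns the input index where a dtype constant could be substituted in an
// aten::to node, or -1 if this overload takes no integer dtype at all and
// therefore must not be reused (a fresh cast node should be created instead).
int dtypePosition(TypeKind second_arg, TypeKind third_arg) {
  if (second_arg == TypeKind::IntType) {
    return 1;  // common case: aten::to(self, dtype, ...)
  }
  if (third_arg == TypeKind::IntType) {
    return 2;  // e.g. dtype follows a Device argument
  }
  return -1;  // no dtype argument in this schema
}
```

This is only one way to consolidate the checks; the PR discussion settles on validating the schema in getUpstreamCastNode and falling back to a manually created cast node when in doubt.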

    }
    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, g, value_map); };
    cast_node = g->createClone(cast_node, env);
    // auto cast_node = g->prependNode(g->createClone(cast_node, env));
  } else {
    // if there is no explicit aten::to cast operation, we need to create a node
    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
    auto const_zero = g->insertConstant(0);
    const_zero->setType(torch::jit::BoolType::get());
    auto none_val = g->insertNode(g->createNone())->output();
    cast_node = g->create(torch::jit::aten::to, {g->inputs()[index], const_type, const_zero, const_zero, none_val});
Collaborator:

Add an if/else here to use g->inputs() if is_input = true otherwise use g->outputs()


    // auto cast_node = g->prependNode(g->create(torch::jit::aten::to, {g->inputs()[i], const_type, const_zero,
    // const_zero, none_val})); seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node,
    // cast_node->outputs()[0]); LOG_DEBUG(seg_block << " in shape analysis");
Collaborator:

Consider removing commented code if not needed.

  }
  if (is_input) {
    g->prependNode(cast_node);
  } else {
    g->appendNode(cast_node);
  }
  return cast_node;
}
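The magic numbers 3 and 4 passed to insertConstant above are ATen dtype codes. A minimal sketch of the relevant mapping (the enum values are assumed to match ATen's c10::ScalarType ordering, and castTargetDtype is a hypothetical helper summarizing the is_input branch):

```cpp
#include <cassert>

// Assumed subset of ATen's c10::ScalarType ordering; these are the integers
// that insertConstant embeds as the dtype argument of aten::to.
enum class ScalarType : int {
  Byte = 0,
  Char = 1,
  Short = 2,
  Int = 3,   // at::kInt  -> g->insertConstant(3) casts to int32
  Long = 4,  // at::kLong -> g->insertConstant(4) casts to int64
};

// createCastNode's target dtype: block inputs are widened back to int64
// (the Torch fallback segment expects the original type), while block
// outputs are narrowed to int32 so downstream TensorRT segments can
// consume them.
inline int castTargetDtype(bool is_input) {
  return is_input ? static_cast<int>(ScalarType::Long)
                  : static_cast<int>(ScalarType::Int);
}
```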

void getSegmentsOutputByRunning(
    SegmentedBlock& seg_block,
    std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
@@ -142,16 +199,43 @@ void getSegmentsOutputByRunning(
    ivalues_maps[output] = jit_results[idx++];
  }

  // auto int64 <=> int32 conversion
  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
    // First, check if there is an Int64 input
    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
        if (t == at::kLong) {
          // we add a cast operation to cast the type back to Int64
          auto cast_node = createCastNode(seg_block, i, true);
          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
        }
      }
    }
    for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
      if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
        auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
        if (t == at::kLong) {
          auto cast_node = createCastNode(seg_block, i, false);
          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
        }
      }
    }
  }
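The insertion policy of the two loops above can be sketched in isolation: only Torch-fallback blocks, with truncate_long_and_double enabled, get a cast node at each int64 tensor boundary. This is a toy simplification (hypothetical DType enum and indicesNeedingCast helper; the real code keys off at::kLong and SegmentedBlock::kTorch):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy dtype tags standing in for the at::ScalarType values checked above.
enum class DType { Int32, Int64, Float32 };

// Returns the boundary indices (inputs or outputs of a segmented block)
// at which a cast node would be inserted. Non-Torch blocks and runs
// without truncate_long_and_double are left untouched.
std::vector<size_t> indicesNeedingCast(
    const std::vector<DType>& boundary_dtypes,
    bool is_torch_block,
    bool truncate_long_and_double) {
  std::vector<size_t> result;
  if (!is_torch_block || !truncate_long_and_double) {
    return result;
  }
  for (size_t i = 0; i < boundary_dtypes.size(); ++i) {
    if (boundary_dtypes[i] == DType::Int64) {
      result.push_back(i);
    }
  }
  return result;
}
```

The same predicate runs once over the block's inputs (casting back to int64) and once over its outputs (casting down to int32), which is why the tests below expect two inserted aten::to nodes.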

  // set the input shape for each segmented block so we will use it in the conversion process
  std::vector<ir::Input> input_shapes;
  std::vector<at::ScalarType> input_types;
  for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
      // set the input_shape and data_type
      // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map
      // for shape inference
      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
      at::ScalarType t = cur_ivalue.toTensor().scalar_type();

      if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
        TORCHTRT_THROW_ERROR(
            "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
5 changes: 5 additions & 0 deletions tests/core/partitioning/BUILD
@@ -40,6 +40,10 @@ partitioning_test(
    name = "test_resolve_nontensor_inputs",
)

partitioning_test(
    name = "test_type_auto_conversion",
)

cc_test(
    name = "test_loading_model",
    srcs = ["test_loading_model.cpp"],
@@ -112,5 +116,6 @@ test_suite(
        ":test_shape_analysis",
        ":test_stitched_graph",
        ":test_tensorrt_conversion",
        ":test_type_auto_conversion",
    ],
)
109 changes: 109 additions & 0 deletions tests/core/partitioning/test_type_auto_conversion.cpp
@@ -0,0 +1,109 @@
#include <string>
#include "core/partitioning/partitioning.h"
#include "core/util/trt_util.h"
#include "gtest/gtest.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/script.h"

bool checkInsertedCastNodeNumber(torch_tensorrt::core::partitioning::SegmentedBlock& seg_block, int target_count) {
  int64_t cnt = 0;
  for (auto node : seg_block.nodes()) {
    if (node->kind().toQualString() == std::string("aten::to")) {
      cnt++;
    }
  }
  std::cout << "Found " << cnt << " inserted aten::to nodes (expected " << target_count << ")" << std::endl;

  return target_count == cnt;
}

TEST(Partitioning, ExplicitNodeAutoConversionCorrectly) {
  const auto graph = R"IR(
        graph(%0 : Tensor,
              %1 : Tensor):
          %2 : int = prim::Constant[value=4]()
          %3 : bool = prim::Constant[value=0]()
          %4 : NoneType = prim::Constant()
          %5 : int = prim::Constant[value=1]()
          %7 : Tensor = aten::to(%1, %2, %3, %3, %4)
          %8 : Tensor = aten::mul(%0, %0)
          %9 : Tensor = aten::scatter(%8, %5, %7, %5)
          %10 : Tensor = aten::scatter(%7, %5, %7, %5)
          %12 : Tensor = aten::add(%10, %10, %5)
          return (%9, %12))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get(), true);

  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
  partitioning_info.enabled = true;
  partitioning_info.forced_fallback_operators = {"aten::scatter"};
  partitioning_info.truncate_long_and_double = true;
  std::vector<torch_tensorrt::core::ir::Input> inputs;
  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));

  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
  inputs_map.insert({g->inputs()[0], {inputs[0]}});
  input_types.insert({g->inputs()[0], {{at::kFloat}}});
  inputs_map.insert({g->inputs()[1], {inputs[1]}});
  input_types.insert({g->inputs()[1], {{at::kInt}}});

  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);

  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;

  for (auto& seg_block : segmented_blocks) {
    LOG_DEBUG(seg_block << " cur seg block");
  }
  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
}

TEST(Partitioning, ImplicitAutoConversionCorrectly) {
  const auto graph = R"IR(
        graph(%0 : Tensor):
          %2 : int = prim::Constant[value=0]()
          %4 : int = aten::size(%0, %2)
          %6 : Tensor = prim::NumToTensor(%4)
          %5 : int = prim::Constant[value=5]()
          %7 : int[] = prim::ListConstruct(%5, %5)
          %8 : bool = prim::Constant[value=0]()
          %9 : Tensor = aten::expand(%6, %7, %8)

          %10 : Tensor = aten::mul(%9, %9)
          return (%10))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get(), true);

  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
  partitioning_info.enabled = true;
  partitioning_info.forced_fallback_operators = {"aten::expand"};
  partitioning_info.truncate_long_and_double = true;
  std::vector<torch_tensorrt::core::ir::Input> inputs;

  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));

  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
  inputs_map.insert({g->inputs()[0], {inputs[0]}});
  input_types.insert({g->inputs()[0], {{at::kFloat}}});

  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);

  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;

  for (auto& seg_block : segmented_blocks) {
    LOG_DEBUG(seg_block << " cur seg block");
  }
  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
}