
feat: refactoring segmentation in partitioning #1067


Closed
wants to merge 10 commits

Conversation

bowang007
Collaborator

@bowang007 bowang007 commented May 14, 2022

Signed-off-by: Bo Wang [email protected]

Description

Fixes #1031
resolveNonTensorInput and resolveTensorListInput have become messy: many different issues trace back to these two functions, and each fix adds more complexity to our code base. This PR aims to solve these related issues cleanly and to simplify the code base at the same time.

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes

@facebook-github-bot
Contributor

Hi @bowang007!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!


@github-actions github-actions bot added the component: core Issues re: The core compiler label Jun 2, 2022
@@ -75,7 +75,8 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
q.pop();
// for every node that produces this fallback node's NonTensor input, they should fallback too
for (auto input : cur_node->inputs()) {
if (!isTensor(input) && fallback_nodes.insert({input->node(), 4}).second) {
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
Contributor


I think this input->node() traversal misses cases where an op modifies an input.

Ex.
%0 : ListConstruct()
%1 : aten::append(%0, %val)
%2 : aten::append(%1, %val)
%3 : aten::cat(%0)

Looking at just the input->node() of the cat skips over the appends, which could result in a change in model behavior.

Collaborator Author


Hi @mfeliz-cruise did you hit this kind of issue in your model?
I built a model locally with this graph:

graph(%x.1 : Tensor,
      %y.1 : Tensor):
  %2 : int = prim::Constant[value=0]() 
  %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
  %z.1 : Tensor = aten::__getitem__(%mod_list.1, %2) 
  %5 : Tensor[] = aten::append(%mod_list.1, %x.1) 
  %6 : Tensor[] = aten::append(%mod_list.1, %y.1) 
  %7 : Tensor[] = aten::append(%mod_list.1, %z.1) 
  %8 : Tensor = aten::cat(%mod_list.1, %2) 
  return (%8)

With forced fallback on aten::cat, the model runs fine. It's true that the input->node() of aten::cat could skip over the aten::append calls. However, since aten::cat's dependency node is prim::ListConstruct, and prim::ListConstruct's output is used later by aten::append, aten::append will also fall back.

I think there are several cases, and they are all covered:

%0: nodeA
%1: nodeB_modifyinput(%0)
%2: nodeC(%1)

1: nodeC falls back; then there are 2 scenarios:
1.1: %1 is a Tensor. In this case nodeA should not fall back, and nodeB, which modifies the input, should not fall back either. We cannot make them fall back by traversing.
1.2: %1 is not a Tensor; it could be a TensorList, like what happens here. In this case nodeA will fall back, and since nodeA falls back, nodeB, which modifies %0, will eventually fall back as well, because nodeB is in nodeA's output list.

2: nodeB falls back; 2 cases too:
2.1: if the input that nodeB is modifying is a Tensor, then nodeA and nodeC will not fall back.
2.2: if the input that nodeB is modifying is not a Tensor, then nodeA will fall back since it produces nodeB's input, and since nodeA falls back, nodeC will also fall back since it is in nodeA's output list. I also did a local test with forced fallback on aten::append, which works fine as well.

Did I miss any case in this analysis?
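The two-direction propagation described above can be sketched with toy types (illustrative stand-ins, not the actual Torch-TensorRT structures): producers of a fallback node's non-Tensor inputs are pulled into the fallback set, and so are all consumers of a fallback node's outputs, which is what eventually catches nodeB in cases 1.2 and 2.2.

```cpp
#include <queue>
#include <set>
#include <vector>

// Toy stand-in for a graph node; names are hypothetical.
struct ToyNode {
  std::vector<ToyNode*> nontensor_inputs;  // producers of this node's non-Tensor inputs
  std::vector<ToyNode*> output_users;      // consumers of this node's outputs
};

// BFS from a seed fallback node, propagating in both directions.
std::set<ToyNode*> propagate_fallback(ToyNode* seed) {
  std::set<ToyNode*> fallback{seed};
  std::queue<ToyNode*> q;
  q.push(seed);
  while (!q.empty()) {
    ToyNode* cur = q.front();
    q.pop();
    // Producers of non-Tensor inputs must live in the same fallback segment.
    for (ToyNode* prod : cur->nontensor_inputs)
      if (fallback.insert(prod).second) q.push(prod);
    // Consumers of a fallback node's outputs must fall back too.
    for (ToyNode* user : cur->output_users)
      if (fallback.insert(user).second) q.push(user);
  }
  return fallback;
}
```

With nodeA -> nodeB (modifies %0) -> nodeC wired as in case 1.2, seeding from nodeC reaches nodeA through the non-Tensor input edge, and then nodeB through nodeA's output users.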

Contributor


I think you're right that these cases should be handled by the forward fallback propagation from ListConstruct. The cases where I've seen issues are caused by getDependencyNodes, where the ListConstruct node is identified as a dependency and copied into the segment of the cat without the appends, changing the behavior of the model.

Collaborator Author


Hi @mfeliz-cruise, thanks for your feedback.
I updated getDependencyNodes, and it should now find all the modifying nodes and take them as dependency nodes as well.
Please note that since the list of modifying nodes is not exhaustive, we may need to keep updating this list https://github.com/bowang007/TRTorch/blob/1df7cbb4f703d399a97901779fbc59034a8c8932/core/partitioning/partitioning.cpp#L43 as well.
Please check whether it works for you. Thanks!
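A toy sketch of the dependency walk being discussed (names are illustrative, not the actual Torch-TensorRT code): for each unresolved input, collect its producer and also any user of that value whose kind appears in a static modifying-ops list, so a copied ListConstruct brings its aten::append calls along.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical toy IR types.
struct ToyValue;
struct ToyOp {
  std::string kind;
  std::vector<ToyValue*> inputs;
};
struct ToyValue {
  ToyOp* producer = nullptr;
  std::vector<ToyOp*> users;
};

// Stand-in for the static (non-exhaustive) modifying-ops list in partitioning.cpp.
const std::set<std::string> kModifyingOps = {"aten::append"};

// Collect producers plus in-place writers for each unresolved input.
std::set<ToyOp*> get_dependency_ops(const std::vector<ToyValue*>& inputs_to_resolve) {
  std::set<ToyOp*> deps;
  for (ToyValue* v : inputs_to_resolve) {
    if (v->producer) deps.insert(v->producer);
    for (ToyOp* user : v->users)
      if (kModifyingOps.count(user->kind)) deps.insert(user);  // writer must travel with the value
  }
  return deps;
}
```

On the earlier example graph, resolving the list input of aten::cat would now return the prim::ListConstruct and both aten::append ops, instead of the ListConstruct alone.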

Contributor


Thanks Bo, could we use the AliasInfo embedded in the node->schema to find these ops on the fly rather than looking them up from a static list? That seems like it could be more robust.

I'm looking at what's done in AliasDB here to get AliasInfo for an input: https://github.com/pytorch/pytorch/blob/d26c575ff581e4df0e9c72b339a25999c6cae59e/torch/csrc/jit/ir/alias_analysis.cpp#L816

And then check if it's a write:
https://github.com/pytorch/pytorch/blob/d26c575ff581e4df0e9c72b339a25999c6cae59e/torch/csrc/jit/ir/alias_analysis.cpp#L847

https://caffe2.ai/doxygen-c/html/classc10_1_1_alias_info.html
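A minimal sketch of this schema-driven check, using toy stand-ins for the libtorch types (real code would read node->schema().arguments()[i].alias_info() and call c10::AliasInfo::isWrite(), as in the linked alias_analysis.cpp): an op writes to its i-th input exactly when that argument carries a write alias annotation, e.g. the "(a!)" on aten::append's self argument.

```cpp
#include <cstddef>
#include <vector>

// Toy stand-ins for c10::AliasInfo / c10::Argument / c10::FunctionSchema.
struct ToyAliasInfo { bool is_write; };
struct ToyArgument { const ToyAliasInfo* alias_info; };  // nullptr when unannotated
struct ToySchema { std::vector<ToyArgument> arguments; };

// True if the schema annotates the i-th input as written, i.e. the op
// mutates that input in place.
bool input_is_written(const ToySchema& schema, std::size_t i) {
  if (i >= schema.arguments.size()) return false;
  const ToyAliasInfo* info = schema.arguments[i].alias_info;
  return info != nullptr && info->is_write;
}
```

This answers the per-input question directly: only the annotated argument counts as modified, rather than every value that happens to be used by a modifying op.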

Contributor


This would also let you check which inputs are modified by an op, rather than assuming that any value used by a modifying op is modified.

Collaborator Author


Thanks @mfeliz-cruise! This is really helpful!
The method you proposed is much cleaner than the current solution.
Thanks for sharing the AliasInfo APIs; I'm going to take a look and try to use them to find the modifying nodes.

if (segmented_blocks[i].contain_raw_value(use.first)) {
use.second.produce_id = i;
if (!inputs_to_resolve.empty()) {
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve);
Contributor

@mfeliz-cruise mfeliz-cruise Jun 2, 2022


getDependencyNodes will miss any modifying dependency ops as it is currently written.

Ex.
%0 : ListConstruct()
%1 : aten::append(%0, %val)
%2 : aten::append(%1, %val)
%3 : aten::cat(%0)

getDependencyNodes for the aten::cat will only return the ListConstruct.

Collaborator Author


Thanks @mfeliz-cruise!
Let me check whether there are any APIs we can use in TorchScript to find all the modifying ops as well.

@github-actions github-actions bot added the component: tests Issues re: Tests label Jun 16, 2022
@peri044
Collaborator

peri044 commented Jun 22, 2022

Closing in favor of #1140


Successfully merging this pull request may close these issues.

✨[Feature] Refactoring the Graph Segmentation in Partitioning