feat: refactoring segmentation in partitioning #1067
Conversation
Signed-off-by: Bo Wang <[email protected]>
Hi @bowang007! Thank you for your pull request and welcome to our community.
Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.
Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged accordingly. If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Signed-off-by: Bo Wang <[email protected]>
@@ -75,7 +75,8 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
     q.pop();
     // for every node that produces this fallback node's NonTensor input, they should fallback too
     for (auto input : cur_node->inputs()) {
-      if (!isTensor(input) && fallback_nodes.insert({input->node(), 4}).second) {
+      if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
I think this input->node() traversal misses cases where an op modifies an input.
Ex.
%0 : ListConstruct()
%1 : aten::append(%0, %val)
%2 : aten::append(%1, %val)
%3 : aten::cat(%0)
Looking at just the input->node() of the cat skips over the appends, which could result in a change in model behavior.
Hi @mfeliz-cruise, did you hit this kind of issue in your model?
I built a model locally with this graph:
graph(%x.1 : Tensor,
%y.1 : Tensor):
%2 : int = prim::Constant[value=0]()
%mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
%z.1 : Tensor = aten::__getitem__(%mod_list.1, %2)
%5 : Tensor[] = aten::append(%mod_list.1, %x.1)
%6 : Tensor[] = aten::append(%mod_list.1, %y.1)
%7 : Tensor[] = aten::append(%mod_list.1, %z.1)
%8 : Tensor = aten::cat(%mod_list.1, %2)
return (%8)
With forced fallback on aten::cat, the model runs fine. It's true that the input->node() of aten::cat could skip over the aten::append nodes; however, since aten::cat's dependency node is prim::ListConstruct, and prim::ListConstruct's output is later used by aten::append, aten::append will also fall back.
I think there are several cases, and they can all be covered:
%0: nodeA
%1: nodeB_modifyinput(%0)
%2: nodeC(%1)
1: nodeC falls back; then there are 2 scenarios:
1.1: %1 is a Tensor. In this case nodeA should not fall back, and nodeB, which modifies the input, should not fall back either. We cannot make them fall back by traversing.
1.2: %1 is not a Tensor; it could be a TensorList, like what happens here. In this case nodeA will fall back, and since nodeA falls back, nodeB, which modifies %0, will finally also fall back, because nodeB is in nodeA's output list.
2: nodeB falls back; 2 cases too:
2.1: if the input that nodeB is modifying is a Tensor, then nodeA and nodeC will not fall back.
2.2: if the input that nodeB is modifying is not a Tensor, then nodeA will fall back since nodeA is nodeB's input; and since nodeA falls back, nodeC will also fall back since it's in nodeA's outputs. I also did a local test with forced fallback on aten::append, which also works fine.
Did I miss any case in this analysis?
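(For context, the propagation described in these cases roughly amounts to a breadth-first walk in both directions across non-Tensor values. A minimal sketch, assuming an isTensor helper and using a plain set where the PR tracks fallback reasons in a map:)
// Sketch only, not the PR's code: propagate fallback backward through
// non-Tensor inputs (case 1.2) and forward through uses of non-Tensor
// outputs (case 2.2).
#include <queue>
#include <unordered_set>
#include <torch/csrc/jit/ir/ir.h>

static bool isTensor(torch::jit::Value* val) {
  return val->type()->isSubtypeOf(c10::TensorType::get());
}

void propagate_fallback(std::unordered_set<torch::jit::Node*>& fallback_nodes) {
  std::queue<torch::jit::Node*> q;
  for (auto* n : fallback_nodes) {
    q.push(n);
  }
  while (!q.empty()) {
    auto* cur = q.front();
    q.pop();
    // Backward: a non-Tensor input (e.g. a TensorList) forces its producer to fall back.
    for (auto* input : cur->inputs()) {
      if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
          fallback_nodes.insert(input->node()).second) {
        q.push(input->node());
      }
    }
    // Forward: every consumer of a non-Tensor output falls back with it.
    for (auto* output : cur->outputs()) {
      if (!isTensor(output)) {
        for (auto& use : output->uses()) {
          if (fallback_nodes.insert(use.user).second) {
            q.push(use.user);
          }
        }
      }
    }
  }
}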
I think you're right that these cases should be handled by the forward fallback propagation from ListConstruct. The cases where I've seen issues are caused by GetDependencyNodes, where the ListConstruct node is identified as a dependency and copied into the segment of the cat without the appends, changing the behavior of the model.
Hi @mfeliz-cruise, thanks for your feedback.
I updated getDependencyNodes, and it should now find all the modifying nodes and take them as dependency nodes as well.
Please note that since the list of modifying nodes is not exhaustive, we may also need to keep updating this list: https://github.com/bowang007/TRTorch/blob/1df7cbb4f703d399a97901779fbc59034a8c8932/core/partitioning/partitioning.cpp#L43
Please check whether it works for you. Thanks!
Thanks Bo, could we use the AliasInfo embedded in the node->schema to find these ops on the fly rather than looking them up from a static list? That seems like it could be more robust.
I'm looking at what's done in AliasDB here to get AliasInfo for an input: https://github.com/pytorch/pytorch/blob/d26c575ff581e4df0e9c72b339a25999c6cae59e/torch/csrc/jit/ir/alias_analysis.cpp#L816
And then check if it's a write:
https://github.com/pytorch/pytorch/blob/d26c575ff581e4df0e9c72b339a25999c6cae59e/torch/csrc/jit/ir/alias_analysis.cpp#L847
https://caffe2.ai/doxygen-c/html/classc10_1_1_alias_info.html
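(For illustration, the write check could look roughly like this. This is a sketch under the assumption that Argument::alias_info() returns a const AliasInfo*, as in the linked alias_analysis.cpp; the accessor's exact shape varies across PyTorch versions:)
// Sketch: use the AliasInfo from the node's schema to decide whether this
// node writes to a given input in place (e.g. aten::append marks its list
// argument as written, roughly "(a!)" in the schema).
#include <torch/csrc/jit/ir/ir.h>

bool is_modifying_node_of_value(torch::jit::Node* node, torch::jit::Value* val) {
  const c10::FunctionSchema* schema = node->maybeSchema();
  if (!schema) {
    return false; // e.g. prim::ListConstruct has no schema
  }
  auto num_formals = schema->arguments().size();
  for (size_t i = 0; i < node->inputs().size() && i < num_formals; ++i) {
    if (node->inputs()[i] == val) {
      const at::AliasInfo* formal = schema->arguments()[i].alias_info();
      if (formal && formal->isWrite()) {
        return true;
      }
    }
  }
  return false;
}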
This would also let you check which inputs are modified by an op, rather than assuming that any value used by a modifying op will be modified.
Thanks @mfeliz-cruise! This is really helpful!
The method you proposed is much cleaner than the current solution.
Thanks for sharing these AliasInfo APIs; I'm going to take a look and try to use them to find the modifying nodes.
core/partitioning/partitioning.cpp
Outdated
if (segmented_blocks[i].contain_raw_value(use.first)) {
  use.second.produce_id = i;
if (!inputs_to_resolve.empty()) {
  std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve);
getDependencyNodes will miss any modifying dependency ops as it is currently written.
Ex.
%0 : ListConstruct()
%1 : aten::append(%0, %val)
%2 : aten::append(%1, %val)
%3 : aten::cat(%0)
getDependencyNodes for the aten::cat will only return the ListConstruct.
Thanks @mfeliz-cruise!
Let me check whether there are any APIs we can use in TorchScript to find all the modifying ops as well.
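(One possible shape for the fix Bo describes, sketched under the same assumptions as above — the helpers isTensor and is_modifying_node_of_value come from the earlier sketches, not from the PR itself:)
// Sketch: when collecting dependency nodes for segment inputs, also pull in
// any node that writes to a visited value in place, so a ListConstruct is
// never copied into a segment without its aten::append writers. A real
// implementation would still need to re-sort the result in graph order.
#include <algorithm>
#include <queue>
#include <unordered_set>
#include <vector>
#include <torch/csrc/jit/ir/ir.h>

std::vector<torch::jit::Node*> get_dependency_nodes(std::vector<torch::jit::Value*>& vals) {
  std::queue<torch::jit::Value*> q;
  for (auto* v : vals) {
    q.push(v);
  }
  std::unordered_set<torch::jit::Node*> visited;
  std::vector<torch::jit::Node*> deps;
  while (!q.empty()) {
    auto* cur_val = q.front();
    q.pop();
    auto* node = cur_val->node();
    if (node->kind() != torch::jit::prim::Constant && visited.insert(node).second) {
      deps.push_back(node);
      for (auto* input : node->inputs()) {
        if (!isTensor(input)) {
          q.push(input);
        }
      }
    }
    // New part: in-place writers of this value are dependencies too.
    for (auto& use : cur_val->uses()) {
      if (is_modifying_node_of_value(use.user, cur_val) && visited.insert(use.user).second) {
        deps.push_back(use.user);
      }
    }
  }
  std::reverse(deps.begin(), deps.end());
  return deps;
}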
Signed-off-by: Bo Wang <[email protected]>
Signed-off-by: Bo Wang <[email protected]>
…r value that's produced in outer block Signed-off-by: Bo Wang <[email protected]>
Signed-off-by: Bo Wang <[email protected]>
Signed-off-by: Bo Wang <[email protected]>
Closing in favor of #1140
Signed-off-by: Bo Wang [email protected]
Description
Fixes #1031
resolveNonTensorInput and resolveTensorListInput have become a mess: many different issues trace back to these two functions, and repeatedly patching them keeps adding complexity to our code base. We hope to solve these related issues in a cleaner way that also simplifies the code base.
Type of change
Checklist: