✨[Feature] Perform constant folding of operations in the graph if nodes can be evaluated at compile time #1266
Comments
Shouldn't OpSupported report back that the node can be evaluated? (See TensorRT/core/conversion/conversion.cpp, line 22 at 5fa38f4.)
#1263: This PR aims to log every state change in the target executor for nodes, which might help clear up what is happening.
Yes. But the request here is to remove these nodes altogether from the TorchScript graph before the partitioning phase (i.e. turn them into a constant list that is passed directly to the reshape node), so that ListConstruct won't obstruct partitioning and create additional segments.
Is this because of the input signature changes?
My thinking here is that when an op is supported with an Evaluator, it is resolvable to a constant in TRT. If these ops are resolvable to constants in TRT, they should also be resolvable to constants in TorchScript. If we implemented an evaluator pass which ran before partitioning and evaluated any evaluatable node (e.g. aten::size with a static-shaped input), we could replace those nodes with constants in the graph. We could then rerun the JIT's constant propagation, potentially simplifying the graph further. I think this has a few benefits.
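A rough sketch of what such a pre-partitioning pass could look like, written in Python against PyTorch's torch._C JIT bindings purely for illustration (in Torch-TensorRT itself this would presumably be a C++ lowering pass). The binding names used below (findAllNodes, insertConstant, moveBefore, toIValue) are assumptions about the unstable torch._C API and may differ across versions:

```python
import torch

def fold_static_sizes(graph: torch._C.Graph) -> None:
    """Sketch: replace aten::size calls whose input tensor has a fully static
    shape with prim::Constant values, then let TorchScript constant
    propagation fold whatever becomes computable downstream
    (prim::ListConstruct, aten::mul.int, ...).

    NOTE: torch._C graph-surgery bindings are not a stable API; the exact
    method names here are assumptions.
    """
    for node in graph.findAllNodes("aten::size"):
        inputs = list(node.inputs())
        tensor_ty = inputs[0].type()
        sizes = tensor_ty.sizes() if isinstance(tensor_ty, torch._C.TensorType) else None
        if sizes is None or any(s is None for s in sizes):
            continue  # dynamic or unknown shape: leave the node alone

        if len(inputs) == 2:
            # aten::size(Tensor, int dim) -> int
            if inputs[1].node().kind() != "prim::Constant":
                continue
            dim = inputs[1].toIValue()
            const_val = graph.insertConstant(int(sizes[dim]))
        else:
            # aten::size(Tensor) -> int[]
            const_val = graph.insertConstant([int(s) for s in sizes])

        const_val.node().moveBefore(node)        # keep def-before-use ordering
        node.output().replaceAllUsesWith(const_val)
        node.destroy()

    # Fold anything that is now computable from constants.
    torch._C._jit_pass_constant_propagation(graph)
```

In this sketch the static shapes would first have to be attached to the graph's input values (e.g. from the Torch-TensorRT input specs), since a scripted graph by itself carries no concrete shape information.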
This is true in some cases but not all; for example, prim::ListConstruct and prim::ListUnpack have evaluators but do not fit the resolve-to-constant pattern. Therefore, such a preprocessing system would need to be able to distinguish constant nodes using some characteristic other than node kind. All that is guaranteed by the TorchScript IR is that once a value is produced it will never change, but that is conditioned on one or more preceding ops being executed; those ops may be dependencies of this value, and not all of them are guaranteed to be evaluatable. I would think you could potentially evaluate to constants any node that traces back to solely constants, but I think that's what constant-prop/pool does. There is the slight detail that we have shape information, which could fold a few more ops, but then I would think we should seek to fork the constant-prop/pool passes to take this into account rather than pursue other approaches.
Thanks @narendasan. Are there more examples beyond prim::ListConstruct and prim::ListUnpack? I agree these do not fit the resolve-to-constant pattern I'm trying to address here. Could these be considered as converters?
This addresses one of my concerns. I'm worried about cases where evaluators may incorrectly resolve a value to a constant, or error out because the input is not constant. E.g. an aten::mul.int of a value from an aten::size with dynamic input shapes should not be resolved to a constant. I think the case where we can resolve an aten::mul.int to a constant that is still in the Torch-TensorRT input after constant propagation is actually fairly rare, and would only come up because of constants that we can identify but TorchScript cannot (e.g. size with static input shapes). If we removed these kinds of evaluators and instead relied on constant propagation after resolving constants known to torch-trt (size with static input shapes, dtype), then I think they would be automatically removed by constant propagation in all the cases where it would be valid for torch-trt to do the same. In the remaining cases the aten::mul would still be in the graph and could fall back correctly, or be handled by a future converter that could correctly capture the non-constant input. Based on your input, I think this is what I'm proposing:
This would allow us to explicitly call out what torch-trt is assuming is constant in the initial JIT pass, and give us a clear TorchScript dump after constant propagation of what is still in the graph and will need to be converted or otherwise handled in partitioning. It removes the risk of evaluators incorrectly resolving a value to a constant, creates a path to preserve these ops in fallback regions without erroring out, and allows for the possibility of adding new converter support for nodes that remain in the graph with non-constant inputs (e.g. aten::size outputting a shape tensor that could be consumed by a dynamic reshape).
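To make the dynamic-shape point concrete, here is a small illustrative check (not from the original thread) showing that TorchScript's own constant propagation cannot fold aten::size or the downstream aten::mul on its own, since a scripted graph carries no shape information; with dynamic shapes these ops would therefore correctly remain in the graph for fallback or a future converter:

```python
import torch

@torch.jit.script
def scale_batch(x: torch.Tensor) -> int:
    # Without shape information this cannot be folded to a constant.
    return x.size(0) * 2

torch._C._jit_pass_constant_propagation(scale_batch.graph)
print(scale_batch.graph)
# The printed graph still contains aten::size and aten::mul; they only become
# foldable once static shape information (known to Torch-TensorRT from the
# input specs) is injected into the graph.
```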
Re. 1: I think this sounds reasonable, although I'm not sure how many cases it will hit; as you point out, as soon as there are dynamic shapes we are on equal footing with TorchScript lowering information-wise (give or take a few dimensions which are static across the input range). Regardless, for the effort required I don't see any harm in attempting to implement such a pass. The MVP I have in my head is really just replacing directly known information (i.e. what we get from the input specs) in place.

Re. 2: I don't really see why we need to remove evaluators; the combination of injecting more shape information into the graph and then running constant pooling won't necessarily resolve all ops that would be handled by evaluators. My understanding of the fundamental issue here, after talking to @peri044, is that data like shapes, which Torch-TensorRT knows but TorchScript lowering doesn't, is not taken into account. It seems like 1. solves that on its own. If evaluators just don't get called anymore, I guess that's a plus, but I don't really expect that to be the case. Also, from a while back we were chatting with the TS guys when we were trying to start implementing shape analysis, and they said they were looking at baking shape info into the graph themselves; I don't know what came of that or whether it was moved to a new project like functorch or fx, but it would be worth taking a look at what they have these days that we can leverage.
Seems like tracing now includes shape information.
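For reference, a quick illustration (not from the original comment) of the shape annotations a traced graph carries; the exact formatting varies by PyTorch version:

```python
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

traced = torch.jit.trace(Net().eval(), torch.randn(1, 3, 32, 32))
print(traced.graph)
# Values in a traced graph are annotated with concrete tensor types, roughly:
#   %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)
# so static shapes are visible in the IR, unlike in a scripted graph.
```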
Would we hit issues with dynamic shapes if we rely on the traced shape values?
Potentially; also, I think this would force people to use trace. We need to investigate this further.
Is your feature request related to a problem? Please describe.
In the snippet below we have a reshape node %bar which is forced to fall back by min_block_size. Because %bar is forced to fall back, its input %37 = prim::ListConstruct is also forced to fall back, along with %28 : int = aten::size. In both cases these ops are supported by evaluators and (at least with static shapes) should be resolved to constants during conversion. Currently these ops create unnecessary breaks between TRT regions that will impact performance.
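The original snippet is not reproduced here, but a minimal module exhibiting the same pattern (illustrative only; names and values differ from the graph in the issue) looks like this:

```python
import torch

class Example(torch.nn.Module):
    def forward(self, x):
        # Scripted, this lowers to roughly:
        #   %size : int    = aten::size(%x, %zero)
        #   %dims : int[]  = prim::ListConstruct(%size, %neg_one)
        #   %out  : Tensor = aten::reshape(%x, %dims)
        # If the reshape is forced to fall back (e.g. by min_block_size), the
        # aten::size and prim::ListConstruct nodes fall back with it, even
        # though evaluators could resolve them to constants for static shapes.
        return torch.reshape(x, [x.size(0), -1])

scripted = torch.jit.script(Example())
print(scripted.graph)
```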
Describe the solution you'd like
Could evaluatable nodes be resolved to constants in the torchscript graph before partitioning to avoid this impact on the partition?
Describe alternatives you've considered
Could fallback nodes with no dependencies on active TRT nodes be consolidated to avoid breaking TRT regions? (In this case %37 is not used anywhere other than in the reshape node %bar and could be moved to just before its use)
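A rough sketch of that consolidation idea (again using the unstable torch._C bindings; uses(), user, and moveBefore are assumed to be available in this form):

```python
import torch

def hoist_single_use_nodes(graph: torch._C.Graph,
                           kinds=("prim::ListConstruct",)) -> None:
    """Sketch: move nodes of the given kinds that have exactly one use to just
    before their user, so a fallback-only producer does not sit between, and
    therefore split, two otherwise contiguous TensorRT regions.

    NOTE: the torch._C graph-surgery bindings are not a stable API; the method
    names here are assumptions.
    """
    for kind in kinds:
        for node in graph.findAllNodes(kind):
            uses = node.output().uses()
            if len(uses) == 1:
                # Moving the node later is safe: all of its inputs are already
                # defined before its original position, hence before the user.
                node.moveBefore(uses[0].user)
```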