[RFC] Adding support for non-constant dims for aten.view #1131
The output shape looks like [12, 7, 64] in your snippet and not [-1, -1, 64]. Can you show the actual IR snippet you are dealing with?
In the actual processing of …
Can you show the IR before the pass?
(For future reference, it's usually important to show a reduced, fully valid IR example with any bug reports like this.)
Here's the IR after failure: https://gist.github.com/gpetters94/af96b032acb0e6c6274af9aff62ec5e3. The relevant part is:
(IR snippet omitted; see the gist)
Here's the distilled version:
(IR snippet omitted; see the gist)
It looks like we have already done all the shape math statically, because the result shape is inferred as a static shape. You should be able to extend #935 for torch.aten.div.Tensor_mode to do more folding here if that is useful as well.
So should I just rewrite aten.view to use the statically-inferred output shape when the current logic fails?
That would make sense to me. Actually, I would add a canonicalization that replaces the view sizes operand with a constant list if the result shape is static (and the operand is not already a constant list).
Sure, I can do that. Where are canonicalizations added?
TorchOps.cpp -- you need to add a canonicalizer declaration for the op (in MLIR ODS, `let hasCanonicalizer = 1;`) and then implement the patterns there.
See here for more info: https://mlir.llvm.org/docs/Canonicalization/
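For illustration, here is a rough sketch of what such a canonicalization could look like. The names follow torch-mlir's C++ conventions of the time (`AtenViewOp`, `ValueTensorType`, `m_TorchListOfConstantInts`, `kUnknownSize`), but treat the details as assumptions rather than the actual implementation, which landed later in #1337:

```cpp
// Sketch only: when aten.view's result shape is fully static but its size
// operand is not a constant list, rewrite the size operand to a constant
// list built from the static result shape.
void AtenViewOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add(+[](AtenViewOp op, PatternRewriter &rewriter) {
    // Only fire when the result type carries a fully static shape.
    auto resultTy = op.getType().dyn_cast<ValueTensorType>();
    if (!resultTy || !resultTy.hasSizes() ||
        llvm::is_contained(resultTy.getSizes(), kUnknownSize))
      return failure();
    // Nothing to do if the size operand is already a constant list.
    SmallVector<int64_t> existing;
    if (matchPattern(op.size(), m_TorchListOfConstantInts(existing)))
      return failure();
    // Materialize one torch.constant.int per static dim and rebuild the list.
    Location loc = op.getLoc();
    SmallVector<Value> dims;
    for (int64_t size : resultTy.getSizes())
      dims.push_back(rewriter.create<ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(size)));
    Value newSizeList = rewriter.create<PrimListConstructOp>(
        loc, op.size().getType(), dims);
    rewriter.updateRootInPlace(op,
                               [&] { op.sizeMutable().assign(newSizeList); });
    return success();
  });
}
```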
This introduces a new pass LowerToBackendContract (better name very welcome) which performs the bulk of the simplifications that we do, such as:
- shape refinement
- dtype refinement
- maximizing value semantics
- inlining global slots
- decomposing complex ops
The key difference from before is that it iterates the set of transformations, which can help to break a number of "catch-22" issues where one simplification depends on another, the latest example being here: #1131
This also exposed that RefineTypes was sometimes crashing/asserting for certain inputs. This commit hardens it a bit.
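To make the "iterates the set of transformations" point concrete, here is a minimal sketch of a fixed-point pass loop using only generic upstream MLIR APIs; the commented-out pass names are placeholders, and the real LowerToBackendContract implementation is more involved:

```cpp
// Sketch: rerun a pipeline of simplifications until the IR stops changing
// (detected by fingerprinting the module) or an iteration budget runs out.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;

static LogicalResult runUntilFixedPoint(ModuleOp module, int maxIterations) {
  for (int i = 0; i < maxIterations; ++i) {
    // Fingerprint the module so we can detect whether anything changed.
    OperationFingerPrint before(module);

    PassManager pm(module.getContext());
    pm.addPass(createCanonicalizerPass());
    // Placeholders standing in for shape/dtype refinement, maximizing
    // value semantics, decomposition, etc.:
    // pm.addPass(createRefineShapesPass());        // assumption
    // pm.addPass(createDecomposeComplexOpsPass()); // assumption
    if (failed(pm.run(module)))
      return failure();

    if (OperationFingerPrint(module) == before)
      return success(); // reached a fixed point
  }
  return failure(); // did not converge within the iteration budget
}
```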
Implemented this in #1337
I'm working on lowering OPT, and I'm running into the following:
(error output omitted; not recoverable from the page)
Inspecting the lowering of aten.view, it looks like the output shape is [-1, -1, 64] because the first two input dims aren't constants. The solution I'd like to write is to recursively follow the dims up the tree, verifying that all the ops are either constants, no-ops (e.g. NumToTensor), or math ops (e.g. multiplication, addition), and then performing the math statically to determine the output shape. Does this sound like how we want to implement this? A sketch of that walk follows below.
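A minimal sketch of the recursive walk proposed above, assuming torch-mlir-style matchers and op names (`m_TorchConstantInt`, `AtenIntTensorOp`, `PrimNumToTensorScalarOp`, `AtenMulIntOp`, `AtenAddIntOp`); the handled-op set is deliberately tiny and the names should be read as assumptions, not as the implementation that eventually landed:

```cpp
// Sketch only: recursively try to evaluate a size Value to a constant by
// walking its defining ops through constants, no-op wrappers, and simple
// integer math. Assumes the usual torch-mlir includes and
// `using namespace mlir; using namespace mlir::torch::Torch;`.
#include <optional>

static std::optional<int64_t> tryStaticallyEvaluate(Value size) {
  // Base case: a torch.constant.int.
  int64_t value;
  if (matchPattern(size, m_TorchConstantInt(&value)))
    return value;
  Operation *def = size.getDefiningOp();
  if (!def)
    return std::nullopt; // block argument: nothing to walk
  // No-op wrappers: aten.Int.Tensor(prim.NumToTensor.Scalar(x)) is just x.
  if (auto intTensor = dyn_cast<AtenIntTensorOp>(def))
    if (auto numToTensor =
            intTensor.a().getDefiningOp<PrimNumToTensorScalarOp>())
      return tryStaticallyEvaluate(numToTensor.a());
  // Simple math ops: evaluate both operands, then fold the result.
  if (auto mul = dyn_cast<AtenMulIntOp>(def)) {
    auto lhs = tryStaticallyEvaluate(mul.a());
    auto rhs = tryStaticallyEvaluate(mul.b());
    if (lhs && rhs)
      return *lhs * *rhs;
  }
  if (auto add = dyn_cast<AtenAddIntOp>(def)) {
    auto lhs = tryStaticallyEvaluate(add.a());
    auto rhs = tryStaticallyEvaluate(add.b());
    if (lhs && rhs)
      return *lhs + *rhs;
  }
  return std::nullopt; // anything else: leave the dim dynamic
}
```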