[RFC] Adding support for non-constant dims for aten.view #1131

Closed
gpetters94 opened this issue Aug 1, 2022 · 13 comments

@gpetters94
Collaborator

I'm working on lowering OPT, and I'm running into the following:

error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
note: see current operation: %932 = "torch.aten.view"(%898, %931) : (!torch.vtensor<[1,12,7,64],f32>, !torch.list<int>) -> !torch.vtensor<[12,7,64],f32>

Inspecting the lowering of aten.view, it looks like the output shape is [-1, -1, 64] because the first two input dims aren't constants. The solution I'd like to write is to recursively follow the dims up the tree, verifying that all the ops are either constants, no-ops (e.g. NumToTensor), or math ops (e.g. multiplication, addition), and then performing the math statically to determine the output shape. Does this sound like how we want to implement this?
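
Roughly, I'm imagining a helper along these lines (a hypothetical sketch only; tryFoldSizeToConstant is an invented name and the op/API spellings are from memory, not existing torch-mlir code):

#include <optional>
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

// Hypothetical sketch of the proposal: walk the defining ops of a size value
// and fold it to a constant when everything feeding it is statically known.
static std::optional<int64_t> tryFoldSizeToConstant(mlir::Value size) {
  namespace Torch = mlir::torch::Torch;
  int64_t c;
  // Base case: the value is already a torch.constant.int.
  if (mlir::matchPattern(size, Torch::m_TorchConstantInt(&c)))
    return c;
  mlir::Operation *def = size.getDefiningOp();
  if (!def)
    return std::nullopt;
  // No-ops: look through scalar<->tensor conversions.
  if (mlir::isa<Torch::PrimNumToTensorScalarOp, Torch::AtenIntTensorOp>(def))
    return tryFoldSizeToConstant(def->getOperand(0));
  // Math ops: fold e.g. multiplication when both inputs fold.
  if (mlir::isa<Torch::AtenMulTensorOp>(def)) {
    auto lhs = tryFoldSizeToConstant(def->getOperand(0));
    auto rhs = tryFoldSizeToConstant(def->getOperand(1));
    if (lhs && rhs)
      return *lhs * *rhs;
  }
  return std::nullopt;
}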

@gpetters94 changed the title from "[RFC] Adding support for multiple unknown dims for aten.view" to "[RFC] Adding support for non-constant dims for aten.view" on Aug 1, 2022
@silvasean
Contributor

The output shape looks like [12, 7, 64] in your snippet and not [-1, -1, 64]. Can you show the actual IR snippet you are dealing with?

@gpetters94
Collaborator Author

In the actual processing of aten.view, it checks whether each input dim is a constant; if not, it assigns kUnknownDim to it, and in this case the first two inputs are not constants. The code is here.
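
Paraphrasing, the check is shaped roughly like this (a sketch, not the exact code; getStaticSizeOrUnknown is an invented name and kUnknownDim here is just a stand-in for the lowering's sentinel):

#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

constexpr int64_t kUnknownDim = -1; // stand-in for the lowering's sentinel

// Paraphrased sketch of the per-element check in the aten.view lowering:
// only sizes that come from a torch.constant.int become static output dims.
static int64_t getStaticSizeOrUnknown(mlir::Value sizeValue) {
  int64_t size;
  if (mlir::matchPattern(sizeValue,
                         mlir::torch::Torch::m_TorchConstantInt(&size)))
    return size;
  return kUnknownDim; // the first two sizes in the op above take this path
}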

@silvasean
Contributor

Can you show the IR before the pass?

@silvasean
Contributor

(for future reference, it's usually important to show a reduced, fully valid IR example with any bug reports like this)

@gpetters94
Collaborator Author

Here's the IR after failure: https://gist.github.com/gpetters94/af96b032acb0e6c6274af9aff62ec5e3

The relevant part is:

  %136 = torch.aten.mul.Tensor %123, %71 : !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64>
  %137 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
  %138 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
  %139 = torch.aten.Int.Tensor %136 : !torch.vtensor<[],si64> -> !torch.int
  %140 = torch.prim.ListConstruct %int1, %int7, %int12, %int64 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %141 = torch.aten.view %126, %140 : !torch.vtensor<[1,7,768],f32>, !torch.list<int> -> !torch.vtensor<[1,7,12,64],f32>
  %142 = torch.aten.transpose.int %141, %int1, %int2 : !torch.vtensor<[1,7,12,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,12,7,64],f32>
  %143 = torch.aten.contiguous %142, %int0 : !torch.vtensor<[1,12,7,64],f32>, !torch.int -> !torch.vtensor<[1,12,7,64],f32>
  %144 = torch.aten.numel %143 : !torch.vtensor<[1,12,7,64],f32> -> !torch.int
  %145 = torch.prim.NumToTensor.Scalar %144 : !torch.int -> !torch.vtensor<[],si64>
  %146 = torch.aten.div.Tensor_mode %145, %136, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %147 = torch.aten.div.Tensor_mode %146, %70, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %148 = torch.aten.Int.Tensor %147 : !torch.vtensor<[],si64> -> !torch.int
  %149 = torch.prim.ListConstruct %139, %148, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %150 = torch.aten.view %143, %149 : !torch.vtensor<[1,12,7,64],f32>, !torch.list<int> -> !torch.vtensor<[12,7,64],f32>

@gpetters94
Collaborator Author

Here's the distilled version:

func.func @forward(%arg0: !torch.vtensor<[1,12,7,64],f32>) -> !torch.vtensor<[12,7,64],f32> {
  %str = torch.constant.str "floor"
  %int7 = torch.constant.int 7
  %int12 = torch.constant.int 12
  %int64 = torch.constant.int 64
  %144 = torch.aten.numel %arg0 : !torch.vtensor<[1,12,7,64],f32> -> !torch.int
  %145 = torch.prim.NumToTensor.Scalar %144 : !torch.int -> !torch.vtensor<[],si64>
  %tensor7 = torch.prim.NumToTensor.Scalar %int7 : !torch.int -> !torch.vtensor<[],si64>
  %tensor64 = torch.prim.NumToTensor.Scalar %int64 : !torch.int -> !torch.vtensor<[],si64>
  %146 = torch.aten.div.Tensor_mode %145, %tensor7, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %147 = torch.aten.div.Tensor_mode %146, %tensor64, %str : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.str -> !torch.vtensor<[],si64>
  %148 = torch.aten.Int.Tensor %147 : !torch.vtensor<[],si64> -> !torch.int
  %149 = torch.prim.ListConstruct %int12, %148, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %150 = torch.aten.view %arg0, %149 : !torch.vtensor<[1,12,7,64],f32>, !torch.list<int> -> !torch.vtensor<[12,7,64],f32>
  return %150 : !torch.vtensor<[12,7,64],f32>
}

@silvasean
Contributor

silvasean commented Aug 2, 2022

It looks like we have already done all the shape math statically, because the result shape is inferred as !torch.vtensor<[12,7,64],f32>. So I don't want to do any special local logic here for that.

You should be able to extend #935 for torch.aten.div.Tensor_mode to do more folding here if that is useful as well.

@gpetters94
Collaborator Author

So should I just rewrite aten.view to use the statically-inferred output shape when the current logic fails?

@silvasean
Contributor

> So should I just rewrite aten.view to use the statically-inferred output shape when the current logic fails?

That would make sense to me. Actually, I would add a canonicalization that replaces the view sizes operand with a constant list if the result shape is static (and the operand is not already a constant list).
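
A rough sketch of what such a pattern could look like (sketch only, not the code that eventually landed; accessor names like getSize()/getSelf() and the builder signatures are from memory and may differ across versions):

#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch::Torch;

namespace {
// Sketch: if shape inference already produced a fully static result type for
// aten.view, replace its size operand with a constant int list.
struct FoldViewSizesToConstants : public OpRewritePattern<AtenViewOp> {
  using OpRewritePattern<AtenViewOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenViewOp op,
                                PatternRewriter &rewriter) const override {
    // Nothing to do if the size operand is already a constant int list.
    SmallVector<int64_t> alreadyConstant;
    if (matchPattern(op.getSize(), m_TorchListOfConstantInts(alreadyConstant)))
      return failure();
    // Only fires when the inferred result shape has no unknown dims.
    auto resultType = op.getType().dyn_cast<ValueTensorType>();
    if (!resultType || !resultType.hasSizes())
      return failure();
    ArrayRef<int64_t> sizes = resultType.getSizes();
    if (llvm::any_of(sizes, [](int64_t s) { return s == kUnknownSize; }))
      return failure();
    // Materialize the inferred shape as torch.constant.int values and swap in
    // a new size list built from them.
    Location loc = op.getLoc();
    SmallVector<Value> sizeValues;
    for (int64_t s : sizes)
      sizeValues.push_back(
          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(s)));
    Value newSizeList = rewriter.create<PrimListConstructOp>(
        loc, ListType::get(IntType::get(op.getContext())), sizeValues);
    rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.getSelf(),
                                            newSizeList);
    return success();
  }
};
} // namespace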

@gpetters94
Collaborator Author

Sure, I can do that. Where are canonicalizations added?

@silvasean
Contributor

TorchOps.cpp -- you need to add let hasCanonicalizer = 1 to the ODS definition.
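
Concretely, once hasCanonicalizer is set, the hook in TorchOps.cpp would be along these lines (sketch; FoldViewSizesToConstants is the hypothetical pattern name from the sketch above, not an existing class):

// Definition of the declaration that hasCanonicalizer = 1 generates for the op.
void AtenViewOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<FoldViewSizesToConstants>(context);
}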

@silvasean
Contributor

See here for more info: https://mlir.llvm.org/docs/Canonicalization/

silvasean added a commit that referenced this issue Aug 17, 2022
This introduces a new pass LowerToBackendContract (better name very
welcome) which performs the bulk of the simplifications that we do,
such as
- shape refinement
- dtype refinement
- maximizing value semantics
- inlining global slots
- decomposing complex ops

The key difference from before is that it iterates the set of
transformations, which can help to break a number of "catch-22" issues
where one simplification depends on another, the latest example being
here:
#1131

This also exposed that RefineTypes was sometimes crashing/asserting for
certain inputs. This commit hardens it a bit.
@gpetters94
Collaborator Author

Implemented this in #1337

qedawkins pushed a commit to nod-ai/torch-mlir that referenced this issue Oct 3, 2022