[ET-VK][ez] Make squeeze insertion requirements more strict (#9950)
## Context
Refactor the `SqueezeUnsqueezeInputs` pass to be clearer about its intent.
For Llama models, inputs to 4-bit linear will often have the shape `[1, seq_len, dim]`; under the current implementation of the pass, the input would be squeezed to `[seq_len, dim]` even though the squeeze is not necessary.
The original intention of this pass was to squeeze inputs with shape `[batch_size, 1, dim]` to `[batch_size, dim]` before calling the 4-bit linear operator.
## Changes
To avoid inserting unnecessary squeeze/unsqueeze ops, be more specific about the conditions under which they should be added.
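To illustrate the stricter condition, here is a minimal sketch (not the actual pass code; the helper name `should_squeeze` and the exact predicate are assumptions based on the shapes described above):

```python
import torch

# Assumed helper: only squeeze 3-D inputs whose middle dimension is 1
# (i.e. [batch_size, 1, dim]), leaving shapes like [1, seq_len, dim] alone.
def should_squeeze(input_shape: torch.Size) -> bool:
    return len(input_shape) == 3 and input_shape[1] == 1

# The two cases discussed above (dims are illustrative):
assert should_squeeze(torch.Size([8, 1, 4096]))        # [batch_size, 1, dim] -> squeeze
assert not should_squeeze(torch.Size([1, 128, 4096]))  # [1, seq_len, dim]    -> leave as-is
```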
I would like to consider refactoring this pass in the future, since the logic is currently a bit unintuitive. Squeeze/unsqueeze is also inserted for gelu and relu, but this is to create a chain of unsqueeze/squeeze that will be eliminated by a later pass (see #8601 / D69673068). Eventually it would be good to rewrite the pass so that shape management is explicit and self-contained within the pass, rather than inserting ops that are expected to be removed later on.
Differential Revision: [D72480178](https://our.internmc.facebook.com/intern/diff/D72480178/)