Commit b2a91e9

[ET-VK][ez] Make squeeze insertion requirements more strict
## Context

Refactor the `SqueezeUnsqueezeInputs` pass to be clearer about its intention.

For Llama models, inputs to the 4-bit linear operator often have the shape `[1, seq_len, dim]`; under the current implementation of the pass, such an input would be squeezed to `[seq_len, dim]` even though the squeeze is not necessary. The original intention of this pass was to squeeze inputs with shape `[batch_size, 1, dim]` to `[batch_size, dim]` before calling the 4-bit linear operator.

## Changes

To avoid inserting unnecessary squeeze/unsqueeze ops, be more specific about when squeeze/unsqueeze should be added.

I would like to consider refactoring this pass in the future, since the logic is currently a bit unintuitive. Squeeze/unsqueeze is also inserted for gelu and relu, but this is to create a chain of unsqueeze/squeeze that will be eliminated by a later pass (see #8601 / D69673068). Eventually it would be good to rewrite the pass so that shape management is explicit and self-contained within the pass, rather than inserting ops that are expected to be removed later on.

Differential Revision: [D72480178](https://our.internmc.facebook.com/intern/diff/D72480178/)

[ghstack-poisoned]
1 parent bd4455c commit b2a91e9

1 file changed: +12 -5 lines changed

Diff for: backends/vulkan/_passes/squeeze_unsqueeze_inputs.py

@@ -27,25 +27,32 @@ class SqueezeUnsqueezeInputs(ExportPass):
         exir_ops.edge.aten.gelu.default,
     }
 
+    def should_squeeze(self, op, shape: List[int]) -> bool:
+        if len(shape) == 3:
+            return shape[1] == 1 and shape[0] > 1
+
+        # Prefer not to introduce additional orchestration ops by default
+        return False
+
     def call_operator(
         self,
         op,  # pyre-ignore
         args: Tuple[Argument, ...],
         kwargs: Dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        def _squeezable(shape: List[int]) -> bool:
-            return len(shape) > 2 and 1 in shape
-
         if op not in self._squeezable_ops:
             return super().call_operator(op, args, kwargs, meta)
-
         # pyre-ignore[16]: `None` has no attribute `node`
         input_shape = args[0].node.meta["val"].shape
         output_shape = meta["val"].shape
-        if not _squeezable(input_shape):
+
+        if not self.should_squeeze(op, input_shape):
             return super().call_operator(op, args, kwargs, meta)
 
+        def _squeezable(shape: List[int]) -> bool:
+            return len(shape) > 2 and 1 in shape
+
         # squeeze input tensor
         squeeze_shape = list(input_shape)
         while _squeezable(squeeze_shape):