
Commit c20465a

SS-JIA authored and kirklandsign committed
[ET-VK][ez] Make squeeze insertion requirements more strict (#9950)
## Context

Refactor the `SqueezeUnsqueezeInputs` pass to be clearer about its intention.

For Llama models, inputs to 4-bit linear often have the shape `[1, seq_len, dim]`; under the current implementation of the pass, such an input is squeezed to `[seq_len, dim]` even though the squeeze is not necessary.

The original intention of this pass was to squeeze inputs with shape `[batch_size, 1, dim]` to `[batch_size, dim]` before calling the 4-bit linear operator.

## Changes

To avoid inserting unnecessary squeeze/unsqueeze ops, be more specific about when squeeze/unsqueeze should be added.

I would like to consider refactoring this pass in the future, since the logic is currently a bit unintuitive. Squeeze/unsqueeze is also inserted for gelu and relu, but this is to create a chain of unsqueeze/squeeze that will be eliminated by a later pass (see #8601 / D69673068). I think eventually it will be good to rewrite the pass to make shape management more explicit and self-contained within the pass, rather than inserting ops which are expected to be removed later on.

Differential Revision: [D72480178](https://our.internmc.facebook.com/intern/diff/D72480178/)
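To make the motivation concrete, here is a minimal sketch comparing the pre-commit predicate with the 3D branch of the new check on the two shapes discussed above. The function names `old_squeezable` and `new_should_squeeze_3d` and the concrete sizes are illustrative assumptions, not part of the pass.

```python
from typing import List


def old_squeezable(shape: List[int]) -> bool:
    # Predicate used before this commit: any shape of rank > 2 containing a size-1 dim.
    return len(shape) > 2 and 1 in shape


def new_should_squeeze_3d(shape: List[int]) -> bool:
    # 3D branch of the new check: only [batch_size, 1, dim] with batch_size > 1.
    return len(shape) == 3 and shape[1] == 1 and shape[0] > 1


# Illustrative sizes (assumptions, not taken from the commit).
llama_input = [1, 128, 4096]   # [1, seq_len, dim]
batched_input = [8, 1, 4096]   # [batch_size, 1, dim]

print(old_squeezable(llama_input), new_should_squeeze_3d(llama_input))      # True False
print(old_squeezable(batched_input), new_should_squeeze_3d(batched_input))  # True True
```

Under these assumptions, the old predicate flags both shapes, while the stricter check flags only the `[batch_size, 1, dim]` case the pass was originally written for.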
1 parent c16f3c5 commit c20465a

1 file changed, +18 -5 lines changed

Diff for: backends/vulkan/_passes/squeeze_unsqueeze_inputs.py

@@ -27,25 +27,38 @@ class SqueezeUnsqueezeInputs(ExportPass):
         exir_ops.edge.aten.gelu.default,
     }
 
+    def should_squeeze(self, op, shape: List[int]) -> bool:  # pyre-ignore
+        if len(shape) == 3:
+            return shape[1] == 1 and shape[0] > 1
+        if len(shape) == 4:
+            # No need to squeeze if all dims are 1 except the width dim
+            if all(dim == 1 for dim in shape[:-1]):
+                return False
+            # Otherwise, check for squeezable dim
+            return 1 in shape[:-1]
+
+        # Prefer not to introduce additional orchestration ops by default
+        return False
+
     def call_operator(
         self,
         op,  # pyre-ignore
         args: Tuple[Argument, ...],
         kwargs: Dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        def _squeezable(shape: List[int]) -> bool:
-            return len(shape) > 2 and 1 in shape
-
         if op not in self._squeezable_ops:
             return super().call_operator(op, args, kwargs, meta)
-
         # pyre-ignore[16]: `None` has no attribute `node`
         input_shape = args[0].node.meta["val"].shape
         output_shape = meta["val"].shape
-        if not _squeezable(input_shape):
+
+        if not self.should_squeeze(op, input_shape):
             return super().call_operator(op, args, kwargs, meta)
 
+        def _squeezable(shape: List[int]) -> bool:
+            return len(shape) > 2 and 1 in shape
+
         # squeeze input tensor
         squeeze_shape = list(input_shape)
         while _squeezable(squeeze_shape):
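
The 4D branch and the default case of the `should_squeeze` check added in the hunk above can be exercised in isolation. The following is a sketch only: it copies the method's logic into a free function (dropping `self` and the unused `op` argument) and uses plain lists in place of tensor shapes; the example shapes are assumptions chosen for illustration.

```python
from typing import List


def should_squeeze(shape: List[int]) -> bool:
    # Standalone copy of the method's logic; `self` and the unused `op` are dropped.
    if len(shape) == 3:
        return shape[1] == 1 and shape[0] > 1
    if len(shape) == 4:
        # All leading dims are 1: nothing to gain from squeezing.
        if all(dim == 1 for dim in shape[:-1]):
            return False
        # Otherwise squeeze only if a leading (non-width) dim is 1.
        return 1 in shape[:-1]
    # Any other rank: do not insert squeeze/unsqueeze.
    return False


# Illustrative shapes (assumptions, not taken from the commit).
for shape in ([1, 1, 1, 64], [1, 4, 1, 64], [2, 3, 4, 1], [1, 64], [1, 1, 1, 1, 64]):
    print(shape, should_squeeze(shape))
```

Note that `[2, 3, 4, 1]` is not squeezed by this check because the size-1 dim is the last one, which `shape[:-1]` excludes; the previous `_squeezable` gate would have accepted it.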
