[ET-VK][ez] Update requirements for partitioning to_dim_order_copy #7949

Merged
merged 1 commit on Jan 24, 2025
33 changes: 31 additions & 2 deletions backends/vulkan/op_registry.py
@@ -228,8 +228,6 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
-        # dim order copy operator will be removed; memory layout is handled internally
-        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
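The `@update_features` decorator pattern used throughout this file can be illustrated with a minimal, hypothetical sketch. The names below mirror those in the diff (`update_features`, `OpFeatures`), but the definitions are simplified stand-ins assumed for illustration, not the actual ExecuTorch implementation.

```python
# Hypothetical sketch of a decorator-based op registry, mirroring the
# @update_features pattern in op_registry.py. These definitions are
# illustrative assumptions, not the real ExecuTorch code.
from typing import Callable, Dict, List, Union


class OpFeatures:
    def __init__(self) -> None:
        self.buffer_impl = False
        self.resize_fn = False
        self.check_node_fn = None


# Maps an op key to its populated feature set.
OP_FEATURES: Dict[str, OpFeatures] = {}


def update_features(ops: Union[str, List[str]]):
    """Register a feature-populating function for one or more op keys."""
    op_list = [ops] if isinstance(ops, str) else ops

    def decorator(fn: Callable[[OpFeatures], OpFeatures]):
        # Run the decorated function once per op key to fill in features.
        for op in op_list:
            OP_FEATURES[op] = fn(OpFeatures())
        return fn

    return decorator


@update_features(["aten.bmm.default", "aten.mm.default"])
def register_matmul_ops(features: OpFeatures) -> OpFeatures:
    features.buffer_impl = True
    return features
```

Registering a list of ops with one function, as the diff does for the quantize/dequantize family, avoids repeating identical feature declarations per op.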
@@ -322,6 +320,37 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
     return features
 
 
+@update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
+def register_to_copy_dim_order_op(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+
+    # Currently there is no "real" implementation for to_dim_order_copy, but the
+    # node can be removed from the graph as long as it does not change the dtype,
+    # i.e. the call modifies the dim order only. Therefore, check that the input
+    # and output dtypes are the same; if so, the operator is safe to remove.
+    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
+        in_arg = node.args[0]
+        if not isinstance(in_arg, torch.fx.Node):
+            return False
+
+        in_tensor = in_arg.meta.get("val", None)
+        out_tensor = node.meta.get("val", None)
+
+        if in_tensor.dtype != out_tensor.dtype:
+            return False
+
+        return True
+
+    features.check_node_fn = check_dim_order_copy_node
+
+    return features
+
+
 @update_features(
     [
         exir_ops.edge.aten.bmm.default,
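The dtype check that gates partitioning of `_to_dim_order_copy` can be exercised with a torch-free sketch. `FakeNode` and `FakeTensor` below are stand-ins for `torch.fx.Node` and its `meta["val"]` tensor, and the explicit `None` guard is an added safety assumption for this sketch (real fx nodes normally carry `"val"` metadata), not part of the PR.

```python
# Torch-free sketch of the check added in this PR: a _to_dim_order_copy
# node is removable iff it leaves the dtype unchanged. FakeNode/FakeTensor
# are stand-ins for torch.fx.Node and its shape/dtype metadata.
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass
class FakeNode:
    args: Tuple[Any, ...] = ()
    meta: Dict[str, Any] = field(default_factory=dict)


@dataclass
class FakeTensor:
    dtype: str


def check_dim_order_copy_node(node: FakeNode) -> bool:
    in_arg = node.args[0]
    if not isinstance(in_arg, FakeNode):
        return False

    in_tensor = in_arg.meta.get("val", None)
    out_tensor = node.meta.get("val", None)

    # Guard added for this sketch; the PR assumes "val" metadata is present.
    if in_tensor is None or out_tensor is None:
        return False

    # The op only rearranges dim order, so it is safe to remove
    # exactly when the dtype is unchanged.
    return in_tensor.dtype == out_tensor.dtype


# A dim-order-only copy (same dtype) passes; a dtype-changing copy does not.
src = FakeNode(meta={"val": FakeTensor("float32")})
same_dtype = FakeNode(args=(src,), meta={"val": FakeTensor("float32")})
casts_dtype = FakeNode(args=(src,), meta={"val": FakeTensor("float16")})
```

This mirrors why the op could be dropped from the ephemeral-op list: when the check passes, the node carries no computation the Vulkan backend needs to execute, since memory layout is handled internally.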