diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index eb831e352c..d70cf93b88 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -228,8 +228,6 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
-        # dim order copy operator will be removed; memory layout is handled internally
-        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
@@ -322,6 +320,37 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
     return features
 
 
+@update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
+def register_to_copy_dim_order_op(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        uses_axis_map=True,
+        valid_packed_dims=all_packed_dims,
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+
+    # Currently there is no "real" implementation for to_dim_order_copy, but it can be
+    # removed as long as the operator is not changing the dtype, i.e. the operator call
+    # is only modifying the dim order. Therefore, check that the input and output dtypes
+    # are the same; if so, the operator is safe to remove.
+    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
+        in_arg = node.args[0]
+        if not isinstance(in_arg, torch.fx.Node):
+            return False
+
+        in_tensor = in_arg.meta.get("val", None)
+        out_tensor = node.meta.get("val", None)
+
+        if in_tensor.dtype != out_tensor.dtype:
+            return False
+
+        return True
+
+    features.check_node_fn = check_dim_order_copy_node
+
+    return features
+
+
 @update_features(
     [
         exir_ops.edge.aten.bmm.default,
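
Illustration only, not part of the patch: the patch moves _to_dim_order_copy out of the blanket register_ephemeral_op list and gives it its own registration, so removal is now guarded by the dtype check above. The sketch below re-states check_dim_order_copy_node as a standalone function and exercises it against a hand-built torch.fx graph whose nodes carry example "val" metadata, mimicking the fake-tensor values the export pipeline records. torch.ops.aten.clone.default is used purely as a stand-in target, since constructing a real _to_dim_order_copy node would require running the full export flow.

    import torch
    import torch.fx as fx


    # Standalone copy of check_dim_order_copy_node so the example is self-contained;
    # it is not imported from backends/vulkan/op_registry.py.
    def check_dim_order_copy_node(node: fx.Node) -> bool:
        in_arg = node.args[0]
        if not isinstance(in_arg, fx.Node):
            return False

        in_tensor = in_arg.meta.get("val", None)
        out_tensor = node.meta.get("val", None)

        if in_tensor.dtype != out_tensor.dtype:
            return False

        return True


    # Build a tiny FX graph by hand and attach example "val" metadata to each node.
    graph = fx.Graph()
    x = graph.placeholder("x")
    x.meta["val"] = torch.empty(2, 3, 4, 5, dtype=torch.float32)

    # aten.clone is only a stand-in for the dim order copy node.
    copy_node = graph.call_function(torch.ops.aten.clone.default, (x,))

    copy_node.meta["val"] = torch.empty(2, 3, 4, 5, dtype=torch.float32)
    print(check_dim_order_copy_node(copy_node))  # True: dtype unchanged, node is removable

    copy_node.meta["val"] = torch.empty(2, 3, 4, 5, dtype=torch.float16)
    print(check_dim_order_copy_node(copy_node))  # False: dtype changes, keep the node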