[ET-VK][ez] Allow logit linear layer to be lowered to Vulkan #9918

Merged: 4 commits, Apr 7, 2025

14 changes: 1 addition & 13 deletions backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -118,9 +118,6 @@ def _vk_replace_linear_int4(
     # Use custom vulkan linear layer as default
     linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
     copy_weights: bool = False,
-    # Serves the same purpose as `tensor_dim_limit` in
-    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-    feature_limit: int = 16384,
 ):
     for name, child in module.named_children():
         if isinstance(child, torch.nn.Linear) and (
@@ -131,8 +128,6 @@ def _vk_replace_linear_int4(
             if (
                 _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                 or padding_allowed
-            ) and (
-                child.out_features < feature_limit and child.in_features < feature_limit
             ):
                 new_linear = linear_class(
                     child.in_features,
@@ -175,7 +170,6 @@ def __init__(
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cpu"), # noqa
         precision: torch.dtype = torch.float32,
-        feature_limit: int = 16384,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -186,9 +180,6 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
         self.precision: torch.dtype = precision
-        # Serves the same purpose as `tensor_dim_limit` in
-        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-        self.feature_limit = feature_limit

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -197,10 +188,7 @@ def _create_quantized_state_dict(
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
             # Add additional check to make sure features do not exceed feature limit
-            if isinstance(mod, torch.nn.Linear) and (
-                mod.out_features < self.feature_limit
-                and mod.in_features < self.feature_limit
-            ):
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
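
With the `feature_limit` gate removed, every `torch.nn.Linear` in the model, including the large logit (output) projection that motivated this PR, becomes eligible for int4 weight-only quantization. A minimal usage sketch, assuming the quantizer class is importable from the module touched above; the layer shape and group size below are illustrative, not values taken from this diff:

```python
import torch

# Assumed import path, based on the file modified in this PR; the class name
# matches the one referenced in examples/models/llama/source_transformation/quantize.py.
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
    VkInt4WeightOnlyQuantizer,
)

# Hypothetical logit projection: out_features (vocab size) is well above the old
# 16384 feature_limit, so this layer used to be skipped by the quantizer.
model = torch.nn.Sequential(torch.nn.Linear(4096, 32000, bias=False))

# With the limit gone, the layer is quantized like any other linear.
# groupsize=256 mirrors the default used for this path in quantize.py.
model = VkInt4WeightOnlyQuantizer(groupsize=256).quantize(model)
```
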
2 changes: 2 additions & 0 deletions backends/vulkan/op_registry.py
@@ -392,6 +392,7 @@ def register_int8_mm_op(features: OpFeatures):

 @update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
 def register_int4_mm_op(features: OpFeatures):
+    features.buffer_impl = True
     features.texture_impl = TextureImplFeatures(
         uses_axis_map=False,
         valid_packed_dims={PackedDim.WIDTH},
@@ -400,6 +401,7 @@ def register_int4_mm_op(features: OpFeatures):
     features.optimal_storage = VkStorageType.TEXTURE_3D
     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
+    features.skip_limits_check = {1}
     return features
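
In the op registry, `buffer_impl = True` records that the int4 linear op also has a buffer-storage implementation, and `skip_limits_check = {1}` presumably exempts argument index 1 (the packed int4 weight) from the partitioner's tensor size limits, which is what lets an arbitrarily large logit weight pass the operator support check. The sketch below is a hypothetical illustration of how such a check could consult that field; the real logic lives in the Vulkan partitioner and may differ:

```python
from typing import Sequence, Set


def args_within_limits(
    arg_shapes: Sequence[Sequence[int]],
    skip_limits_check: Set[int],
    tensor_dim_limit: int = 16384,  # same default the removed feature_limit used
) -> bool:
    """Hypothetical helper: reject nodes whose non-exempt args exceed the limit."""
    for i, shape in enumerate(arg_shapes):
        if i in skip_limits_check:
            # e.g. the packed int4 weight (arg 1) may exceed the limit
            continue
        if any(dim > tensor_dim_limit for dim in shape):
            return False
    return True


# The logit weight (arg index 1) is huge, but it is exempted from the check:
assert args_within_limits([[1, 4096], [32000, 4096]], skip_limits_check={1})
```
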
4 changes: 0 additions & 4 deletions examples/models/llama/export_llama_lib.py
@@ -793,10 +793,6 @@ def _to_edge_and_lower_llama( # noqa: C901
                 args.enable_dynamic_shape,
             )
         )
-        # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
-        )
         modelname = f"vulkan_{modelname}"

         # Need to remove asserts from the graph to prevent graph breaks

11 changes: 0 additions & 11 deletions examples/models/llama/source_transformation/quantize.py
@@ -206,17 +206,6 @@ def quantize( # noqa C901
         q_group_size = 256 if group_size is None else group_size
         model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)
-
-        # Apply additional quantizer for linear layers that aren't lowered to Vulkan
-        # at the moment
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=q_group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)

         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")