
Commit f3ce109

[ET-VK][ez] Allow logit linear layer to be lowered to Vulkan
Pull Request resolved: #9918

## Context

Due to the poor performance of Vulkan's int4 linear operator, the final logit layer of the transformer model was not delegated to Vulkan; instead it was quantized and executed with the XNNPACK delegate. However, with D72412950 / #9883, decent performance can now be achieved with Vulkan's int4 linear op, so the final logit layer can be lowered to Vulkan as well.

## Changes

* Remove the limit from `VkInt4WeightOnlyQuantizer` that was causing it to ignore the logit layer of the transformer
* Do not apply the XNNPACK partitioner and quantizer when lowering with Vulkan

Differential Revision: [D72480177](https://our.internmc.facebook.com/intern/diff/D72480177/)

ghstack-source-id: 276235672
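To make the context concrete: a minimal sketch, with hypothetical Llama-2-style dimensions, of why the logit projection tripped the old 16384 feature limit and was left to XNNPACK before this change:

```python
import torch

# Hypothetical Llama-2-style shapes: the logit projection maps the hidden
# dimension to the vocabulary size, which is typically far larger than 16384.
hidden_dim, vocab_size = 4096, 32000
logit_layer = torch.nn.Linear(hidden_dim, vocab_size, bias=False)

OLD_FEATURE_LIMIT = 16384  # the limit this commit removes
old_eligible = (
    logit_layer.in_features < OLD_FEATURE_LIMIT
    and logit_layer.out_features < OLD_FEATURE_LIMIT
)
print(old_eligible)  # False: out_features (32000) exceeds the limit, so the
                     # layer was previously skipped by VkInt4WeightOnlyQuantizer
```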
1 parent 346c2da commit f3ce109

File tree: 4 files changed, +3 -28 lines

backends/vulkan/_passes/int4_weight_only_quantizer.py (+1 -13)
```diff
@@ -118,9 +118,6 @@ def _vk_replace_linear_int4(
     # Use custom vulkan linear layer as default
     linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
     copy_weights: bool = False,
-    # Serves the same purpose as `tensor_dim_limit` in
-    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-    feature_limit: int = 16384,
 ):
     for name, child in module.named_children():
         if isinstance(child, torch.nn.Linear) and (
@@ -131,8 +128,6 @@ def _vk_replace_linear_int4(
             if (
                 _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                 or padding_allowed
-            ) and (
-                child.out_features < feature_limit and child.in_features < feature_limit
             ):
                 new_linear = linear_class(
                     child.in_features,
@@ -175,7 +170,6 @@ def __init__(
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cpu"),  # noqa
         precision: torch.dtype = torch.float32,
-        feature_limit: int = 16384,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -186,9 +180,6 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
         self.precision: torch.dtype = precision
-        # Serves the same purpose as `tensor_dim_limit` in
-        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-        self.feature_limit = feature_limit

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -197,10 +188,7 @@ def _create_quantized_state_dict(
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
             # Add additional check to make sure features do not exceed feature limit
-            if isinstance(mod, torch.nn.Linear) and (
-                mod.out_features < self.feature_limit
-                and mod.in_features < self.feature_limit
-            ):
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
```
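With the limit gone, eligibility reduces to the group-size divisibility check plus the `padding_allowed` escape hatch. A toy re-implementation of the post-commit gate, assuming the torchao-style `_check_linear_int4_k` (the reduction dim must divide evenly by the group size and by `inner_k_tiles * 16`):

```python
def check_linear_int4_k(in_features: int, groupsize: int, inner_k_tiles: int) -> bool:
    # Toy stand-in for _check_linear_int4_k: the K dim must tile evenly.
    return in_features % groupsize == 0 and in_features % (inner_k_tiles * 16) == 0


def eligible_after_commit(
    in_features: int,
    groupsize: int = 256,
    inner_k_tiles: int = 8,
    padding_allowed: bool = False,
) -> bool:
    # The former `and (out_features < feature_limit ...)` clause is gone.
    return check_linear_int4_k(in_features, groupsize, inner_k_tiles) or padding_allowed


# A 4096-wide logit layer now qualifies regardless of its 32000 output features.
print(eligible_after_commit(4096))  # True: 4096 % 256 == 0 and 4096 % 128 == 0
```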

backends/vulkan/op_registry.py (+2)
```diff
@@ -392,6 +392,7 @@ def register_int8_mm_op(features: OpFeatures):

 @update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
 def register_int4_mm_op(features: OpFeatures):
+    features.buffer_impl = True
     features.texture_impl = TextureImplFeatures(
         uses_axis_map=False,
         valid_packed_dims={PackedDim.WIDTH},
@@ -400,6 +401,7 @@ def register_int4_mm_op(features: OpFeatures):
     features.optimal_storage = VkStorageType.TEXTURE_3D
     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
+    features.skip_limits_check = {1}
     return features
```
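For context on the two new flags: `buffer_impl` advertises a buffer-backed implementation of the int4 linear op, and `skip_limits_check = {1}` presumably exempts argument index 1 (the prepacked int4 weight) from the partitioner's tensor-size limits. A hypothetical sketch of such a check; the names and structure are illustrative, not the actual ExecuTorch API:

```python
from typing import Iterable, List, Set


def args_within_limits(
    arg_sizes: Iterable[List[int]],
    limits: List[int],
    skip_limits_check: Set[int] = frozenset(),
) -> bool:
    """Hypothetical size check: args listed in skip_limits_check are exempt."""
    for i, sizes in enumerate(arg_sizes):
        if i in skip_limits_check:
            continue  # e.g. the op prepacks its weight, so texture limits don't apply
        if any(dim > limit for dim, limit in zip(sizes, limits)):
            return False
    return True


# The vocab-sized weight (arg 1) would fail a 16384-per-dim limit, but with
# skip_limits_check={1} the op can still be partitioned to Vulkan.
print(args_within_limits([[1, 4096], [32000, 4096]], [16384, 16384], {1}))  # True
```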

examples/models/llama/export_llama_lib.py (-4)
```diff
@@ -793,10 +793,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
                 args.enable_dynamic_shape,
             )
         )
-        # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
-        )
         modelname = f"vulkan_{modelname}"

     # Need to remove asserts from the graph to prevent graph breaks
```
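The effect of this deletion, as a simplified self-contained sketch (names are stand-ins for the real partitioners, not the actual API): when lowering with Vulkan, only the Vulkan partitioner is registered, and ops it does not delegate presumably fall back to the portable CPU kernels rather than being picked up by XNNPACK:

```python
def build_partitioners(use_vulkan: bool) -> list:
    # Simplified stand-in for the partitioner setup in _to_edge_and_lower_llama.
    partitioners = []
    if use_vulkan:
        partitioners.append("VulkanPartitioner")  # get_vulkan_partitioner(...)
        # Previously an XNNPACK fallback partitioner was appended here so that
        # undelegated ops (like the logit linear) ran via XNNPACK; this commit
        # removes it, since the logit layer is now delegated to Vulkan.
    return partitioners


print(build_partitioners(True))  # ['VulkanPartitioner']
```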

examples/models/llama/source_transformation/quantize.py (-11)
```diff
@@ -206,17 +206,6 @@ def quantize(  # noqa C901
         q_group_size = 256 if group_size is None else group_size
         model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)

-        # Apply additional quantizer for linear layers that aren't lowered to Vulkan
-        # at the moment
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=q_group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
-
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
```
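After this deletion, the Vulkan branch of `quantize()` applies only `VkInt4WeightOnlyQuantizer`. A minimal usage sketch; the import path is inferred from the file touched above and the toy model shapes are hypothetical, so treat both as assumptions:

```python
import torch

# Import path inferred from backends/vulkan/_passes/int4_weight_only_quantizer.py;
# not necessarily a documented public API.
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
    VkInt4WeightOnlyQuantizer,
)

# Toy stand-in for a transformer whose final Linear is a vocab-sized logit layer.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.Linear(4096, 32000, bias=False),  # logit layer, now also quantized
)

q_group_size = 256  # the default used when no group size is given, per the diff
model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)
# No torchao 8da4w fallback pass runs afterwards; every eligible Linear,
# including the logit layer, is quantized for the Vulkan delegate.
```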
