Commit 73f2c97

[ET-VK][ez] Allow logit linear layer to be lowered to Vulkan
Differential Revision: D72480177
Pull Request resolved: #9918
1 parent 5232a22

4 files changed: +3 -28 lines

Diff for: backends/vulkan/_passes/int4_weight_only_quantizer.py

+1 -13
@@ -118,9 +118,6 @@ def _vk_replace_linear_int4(
     # Use custom vulkan linear layer as default
     linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
     copy_weights: bool = False,
-    # Serves the same purpose as `tensor_dim_limit` in
-    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-    feature_limit: int = 16384,
 ):
     for name, child in module.named_children():
         if isinstance(child, torch.nn.Linear) and (
@@ -131,8 +128,6 @@ def _vk_replace_linear_int4(
             if (
                 _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                 or padding_allowed
-            ) and (
-                child.out_features < feature_limit and child.in_features < feature_limit
             ):
                 new_linear = linear_class(
                     child.in_features,
@@ -175,7 +170,6 @@ def __init__(
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cpu"), # noqa
         precision: torch.dtype = torch.float32,
-        feature_limit: int = 16384,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -186,9 +180,6 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
         self.precision: torch.dtype = precision
-        # Serves the same purpose as `tensor_dim_limit` in
-        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-        self.feature_limit = feature_limit
 
     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -197,10 +188,7 @@ def _create_quantized_state_dict(
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
             # Add additional check to make sure features do not exceed feature limit
-            if isinstance(mod, torch.nn.Linear) and (
-                mod.out_features < self.feature_limit
-                and mod.in_features < self.feature_limit
-            ):
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
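With the `feature_limit` gate gone, `_vk_replace_linear_int4` replaces every eligible `torch.nn.Linear`, including a vocab-sized output ("logit") projection that used to exceed the 16384 limit. Below is a minimal sketch of invoking the quantizer on such a model; the import path is assumed from this file's location, the toy module is illustrative, and an ExecuTorch build with the Vulkan backend available is assumed.

# Illustrative sketch; import path assumed from backends/vulkan/_passes/int4_weight_only_quantizer.py.
import torch
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
    VkInt4WeightOnlyQuantizer,
)

class TinyLM(torch.nn.Module):
    def __init__(self, dim: int = 2048, vocab: int = 32000) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim, bias=False)
        # out_features = 32000 exceeds the old feature_limit of 16384, so this
        # "logit" layer was previously skipped; it is now replaced as well.
        self.output = torch.nn.Linear(dim, vocab, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.output(self.proj(x))

model = TinyLM()
# Same call pattern as in examples/models/llama/source_transformation/quantize.py below.
model = VkInt4WeightOnlyQuantizer(groupsize=256).quantize(model)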

Diff for: backends/vulkan/op_registry.py

+2
@@ -392,6 +392,7 @@ def register_int8_mm_op(features: OpFeatures):
 
 @update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
 def register_int4_mm_op(features: OpFeatures):
+    features.buffer_impl = True
     features.texture_impl = TextureImplFeatures(
         uses_axis_map=False,
         valid_packed_dims={PackedDim.WIDTH},
@@ -400,6 +401,7 @@ def register_int4_mm_op(features: OpFeatures):
     features.optimal_storage = VkStorageType.TEXTURE_3D
     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
+    features.skip_limits_check = {1}
     return features
 
 
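Here `skip_limits_check = {1}` marks argument index 1 (presumably the packed int4 weight) as exempt from the partitioner's tensor size-limit check, which is what lets the vocab-sized logit weight be delegated. A purely hypothetical sketch of how such a set could be consumed follows; it is not the actual partitioner code, and `args_within_limits` / `within_limit` are made-up names.

# Hypothetical sketch only; not code from backends/vulkan.
from typing import Callable, Sequence, Set

def args_within_limits(
    arg_sizes: Sequence[int],             # per-argument tensor sizes (illustrative)
    within_limit: Callable[[int], bool],  # the backend's actual size check
    skip_limits_check: Set[int],          # e.g. {1} from the registration above
) -> bool:
    for i, size in enumerate(arg_sizes):
        if i in skip_limits_check:
            continue  # this argument is not size-checked
        if not within_limit(size):
            return False
    return True

# Arg 1 (the packed weight) is over the limit but skipped, so the op still passes.
print(args_within_limits([4096, 10**9], lambda s: s < 16384 * 16384, {1}))  # True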

Diff for: examples/models/llama/export_llama_lib.py

-4
@@ -793,10 +793,6 @@ def _to_edge_and_lower_llama( # noqa: C901
                 args.enable_dynamic_shape,
             )
         )
-        # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
-        )
         modelname = f"vulkan_{modelname}"
 
     # Need to remove asserts from the graph to prevent graph breaks
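With the XNNPACK fallback no longer appended, the Vulkan path hands a single partitioner to the lowering step. A hedged sketch of that shape is below; `VulkanPartitioner` and `to_edge_transform_and_lower` are existing ExecuTorch APIs, but the snippet is illustrative rather than this file's actual code, and `exported_program` is assumed to be the torch.export()-ed model.

# Illustrative sketch; not the actual export_llama_lib.py code.
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge_transform_and_lower

partitioners = [VulkanPartitioner()]  # XNNPACK is no longer appended after Vulkan

edge_manager = to_edge_transform_and_lower(
    exported_program,  # assumed: result of torch.export() on the llama model
    partitioner=partitioners,
)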

Diff for: examples/models/llama/source_transformation/quantize.py

-11
@@ -206,17 +206,6 @@ def quantize( # noqa C901
         q_group_size = 256 if group_size is None else group_size
         model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)
 
-        # Apply additional quantizer for linear layers that aren't lowered to Vulkan
-        # at the moment
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=q_group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
-
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
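Since the 8da4w fallback is removed, any `torch.nn.Linear` that `VkInt4WeightOnlyQuantizer` does not replace now simply stays in floating point. A small, illustrative helper (not part of this commit) to inspect which layers were converted after calling the quantizer:

# Illustrative helper, not part of this commit.
import torch

def report_linear_conversion(model: torch.nn.Module) -> None:
    for fqn, mod in model.named_modules():
        if type(mod).__name__ == "VkWeightOnlyInt4Linear":
            print(f"{fqn}: lowered to the Vulkan int4 linear module")
        elif isinstance(mod, torch.nn.Linear):
            # These would previously have been handled by the (now removed)
            # Int8DynActInt4WeightQuantizer fallback; now they remain floating point.
            print(f"{fqn}: left as torch.nn.Linear")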
