diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py
index 409cbb4b755..d0b73b8af0e 100644
--- a/backends/vulkan/_passes/int4_weight_only_quantizer.py
+++ b/backends/vulkan/_passes/int4_weight_only_quantizer.py
@@ -118,9 +118,6 @@ def _vk_replace_linear_int4(
     # Use custom vulkan linear layer as default
     linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
     copy_weights: bool = False,
-    # Serves the same purpose as `tensor_dim_limit` in
-    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-    feature_limit: int = 16384,
 ):
     for name, child in module.named_children():
         if isinstance(child, torch.nn.Linear) and (
@@ -131,8 +128,6 @@ def _vk_replace_linear_int4(
             if (
                 _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                 or padding_allowed
-            ) and (
-                child.out_features < feature_limit and child.in_features < feature_limit
             ):
                 new_linear = linear_class(
                     child.in_features,
@@ -175,7 +170,6 @@ def __init__(
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cpu"),  # noqa
         precision: torch.dtype = torch.float32,
-        feature_limit: int = 16384,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -186,9 +180,6 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
         self.precision: torch.dtype = precision
-        # Serves the same purpose as `tensor_dim_limit` in
-        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-        self.feature_limit = feature_limit

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -197,10 +188,7 @@ def _create_quantized_state_dict(
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
             # Add additional check to make sure features do not exceed feature limit
-            if isinstance(mod, torch.nn.Linear) and (
-                mod.out_features < self.feature_limit
-                and mod.in_features < self.feature_limit
-            ):
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b33430a6bca..5b6637039f5 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -392,6 +392,7 @@ def register_int8_mm_op(features: OpFeatures):

 @update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
 def register_int4_mm_op(features: OpFeatures):
+    features.buffer_impl = True
     features.texture_impl = TextureImplFeatures(
         uses_axis_map=False,
         valid_packed_dims={PackedDim.WIDTH},
@@ -400,6 +401,7 @@ def register_int4_mm_op(features: OpFeatures):
     features.optimal_storage = VkStorageType.TEXTURE_3D
     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
+    features.skip_limits_check = {1}
     return features


diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 8e6d4fefb0e..249a25f23c4 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -793,10 +793,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
                 args.enable_dynamic_shape,
             )
         )
-        # Apply XNNPACK after Vulkan so that undelegated ops can be accelerated by XNNPACK
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
-        )
         modelname = f"vulkan_{modelname}"

         # Need to remove asserts from the graph to prevent graph breaks
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 2ef016de097..d51d4378705 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -206,17 +206,6 @@ def quantize(  # noqa C901
         q_group_size = 256 if group_size is None else group_size
         model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)

-        # Apply additional quantizer for linear layers that aren't lowered to Vulkan
-        # at the moment
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-
-        # 1. Quantize in checkpoint dtype.
-        model = Int8DynActInt4WeightQuantizer(
-            precision=checkpoint_torch_dtype, groupsize=q_group_size
-        ).quantize(model)
-        # 2. Set the computation dtype (what weights/acts dequantize to).
-        model = set_8da4w_computation_dtype(model, computation_torch_dtype)
-
         return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")
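Note: with the changes above, VkInt4WeightOnlyQuantizer quantizes every nn.Linear it visits (the 16384 feature_limit gate is removed) and no XNNPACK or Int8DynActInt4Weight fallback is applied afterwards. A minimal usage sketch follows; the import path is assumed from the file touched in this diff, and the toy layer sizes are illustrative only.

# Sketch only: the import path below is assumed from
# backends/vulkan/_passes/int4_weight_only_quantizer.py.
import torch

from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
    VkInt4WeightOnlyQuantizer,
)

model = torch.nn.Sequential(
    # A layer wider than 16384 output features is no longer skipped now that
    # the feature_limit check has been removed.
    torch.nn.Linear(4096, 32000, bias=False),
)

# groupsize=256 mirrors the default q_group_size used in quantize.py above.
model = VkInt4WeightOnlyQuantizer(groupsize=256).quantize(model)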