diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 5c71fc4e0ae7..58c1d3613daf 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -228,8 +228,7 @@ def test_quantization(self):
             ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
             ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
             ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
-            ("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
-            ("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
         ]
 
         if TorchAoConfig._is_cuda_capability_atleast_8_9():
@@ -253,8 +252,8 @@ def test_quantization(self):
 
         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
             quant_kwargs = {}
-            if quantization_name in ["uint4wo", "uint_a16w7"]:
-                # The dummy flux model that we use requires us to impose some restrictions on group_size here
+            if quantization_name in ["uint4wo", "uint7wo"]:
+                # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
                 quant_kwargs.update({"group_size": 16})
             quantization_config = TorchAoConfig(
                 quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
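
For context, a minimal sketch of the pattern this diff exercises: building a `TorchAoConfig` with one of the tested quant types, passing `group_size=16` for the group-size-restricted types (`uint4wo`, `uint7wo`), and handing the config to a model load. The checkpoint name below is an illustrative assumption, not taken from the diff; the test itself uses a small dummy Flux model.

```python
# Sketch only; the checkpoint is an assumed stand-in for the test's dummy model.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_name = "uint7wo"

quant_kwargs = {}
if quantization_name in ["uint4wo", "uint7wo"]:
    # Small model dimensions restrict the allowed group_size, per the test comment.
    quant_kwargs.update({"group_size": 16})

quantization_config = TorchAoConfig(
    quant_type=quantization_name,
    modules_to_not_convert=["x_embedder"],
    **quant_kwargs,
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint for illustration
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```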