@@ -228,8 +228,7 @@ def test_quantization(self):
             ("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
             ("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
             ("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])),
-            ("int_a8w8", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
-            ("uint_a16w7", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
+            ("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
         ]

         if TorchAoConfig._is_cuda_capability_atleast_8_9():
@@ -253,8 +252,8 @@ def test_quantization(self):
         for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST:
             quant_kwargs = {}
-            if quantization_name in ["uint4wo", "uint_a16w7"]:
-                # The dummy flux model that we use requires us to impose some restrictions on group_size here
+            if quantization_name in ["uint4wo", "uint7wo"]:
+                # The dummy flux model that we use has smaller dimensions. This imposes some restrictions on group_size here
                 quant_kwargs.update({"group_size": 16})
             quantization_config = TorchAoConfig(
                 quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs
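For context, the group_size=16 override exists because group-wise weight-only quantization requires the group size to divide a layer's input dimension evenly, and the dummy Flux model used in these tests has small hidden sizes. The sketch below is not code from this PR or from torchao; the groupwise_quantize helper and the 32-wide layer are hypothetical and only illustrate why a small group_size is needed here.

# Minimal sketch (hypothetical helper, not from this PR or torchao):
# group-wise weight quantization slices each weight row into groups of
# `group_size` values that share one scale/offset, so `in_features`
# must be divisible by `group_size`.
import torch

def groupwise_quantize(weight: torch.Tensor, group_size: int, bits: int = 7):
    out_features, in_features = weight.shape
    assert in_features % group_size == 0, "group_size must divide in_features"
    qmax = 2**bits - 1  # e.g. 127 for a uint7 weight-only scheme
    groups = weight.reshape(out_features, in_features // group_size, group_size)
    mins = groups.amin(dim=-1, keepdim=True)
    scales = ((groups.amax(dim=-1, keepdim=True) - mins) / qmax).clamp_min(1e-8)
    q = torch.clamp(torch.round((groups - mins) / scales), 0, qmax)
    return q.reshape(out_features, in_features), scales, mins

# A dummy transformer with 32-wide linear layers (hypothetical size) only
# admits group sizes that divide 32, e.g. 16 -- hence the test's override.
q, scales, mins = groupwise_quantize(torch.rand(64, 32), group_size=16)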