File tree: torchao/quantization/prototype/qat
2 files changed: 0 additions, 13 deletions

File 1 of 2:

@@ -422,9 +422,6 @@ def test_qat_4w_primitives(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32")
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_linear(self):
         from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
@@ -453,9 +450,6 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32")
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
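Taken together, the two hunks above delete the unconditional skips that were added while the int4 path was broken (see the TODO's pull/517 link), leaving only the two real preconditions: torch >= 2.4 and an available CUDA device. Below is a minimal, hypothetical sketch of the gating that remains; the flag names are copied from the hunks, but their definitions and the test body here are stand-ins, not the real test module:

```python
import unittest

import torch

# Stand-ins for the flags used in the hunks. In torchao these come from the
# library's version helpers; the naive string comparison below is only for
# illustration (it mis-orders e.g. "2.10" vs "2.4").
TORCH_VERSION_AT_LEAST_2_4 = torch.__version__ >= "2.4"
_CUDA_IS_AVAILABLE = torch.cuda.is_available()


class TestQAT4w(unittest.TestCase):
    # After this change these two guards are the only ones left, so the test
    # actually runs on torch >= 2.4 with CUDA instead of always being skipped.
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
    def test_qat_4w_linear(self):
        self.skipTest("body elided; see the hunk above")


if __name__ == "__main__":
    unittest.main()
```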
File 2 of 2:

@@ -36,13 +36,6 @@ def forward(
         block_size: List[int],
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ) -> torch.Tensor:
-        # Note: for bf16 inputs, casting them to fp32 has the unexpected
-        # side effect of reducing memory footprint significantly, presumably
-        # because bf16 * fp32 kernels are not as memory efficient
-        assert input.dtype == torch.float32
-        assert scales.dtype == torch.float32
-        assert zero_points.dtype == torch.int32
-
         (fq, mask) = fake_quantize_affine_cachemask(
             input,
             block_size,
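For orientation, here is the shape of forward after this hunk, reconstructed from the context lines. Everything before block_size is elided in the hunk, so the ctx/input/scales/zero_points/quant_min/quant_max parameters below are assumptions (the tensor names do at least appear in the deleted asserts), the import location is assumed, and the trailing call arguments stay elided. This is a sketch, not the file's actual contents:

```python
from typing import List

import torch
from torchao.quantization.quant_primitives import (  # assumed import location
    ZeroPointDomain,
    fake_quantize_affine_cachemask,
)


def forward(
    ctx,                        # assumption: autograd.Function-style forward
    input: torch.Tensor,        # named in the deleted asserts
    scales: torch.Tensor,       # named in the deleted asserts
    zero_points: torch.Tensor,  # named in the deleted asserts
    quant_min: int,             # assumption: typical affine-quant parameters
    quant_max: int,             # assumption
    block_size: List[int],
    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
    # The fp32-only asserts are gone, so bf16 input/scales now flow straight
    # into the cachemask fake-quantize primitive instead of being rejected.
    (fq, mask) = fake_quantize_affine_cachemask(
        input,
        block_size,
        ...,  # remaining arguments are elided in the hunk, so elided here too
    )
    return fq  # assumption: fq is returned; mask feeds the backward pass
```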