
Commit b523f9f

Relax QAT dtype assertion (#692)
This assertion was originally added for perf reasons specific to 8da4w, but the autograd.Function has since been adapted for more general use, and a few users are now hitting the assertion error. More context: pytorch/torchtune#1333
Parent: f2c908b
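For context, here is a minimal sketch of the kind of change this commit makes: a straight-through fake-quantize autograd.Function whose forward only requires a floating-point input instead of asserting torch.float32, so bf16 tensors pass through. This is illustrative only, not torchao's actual implementation; the _ToyFakeQuantize name and the quantization parameters are made up for the example.

import torch

class _ToyFakeQuantize(torch.autograd.Function):
    """Straight-through fake quantize; illustrative only, not torchao's code."""

    @staticmethod
    def forward(ctx, input, scale, zero_point, quant_min=-128, quant_max=127):
        # Relaxed check: any floating dtype (fp32, bf16, fp16) is accepted,
        # rather than asserting input.dtype == torch.float32.
        assert input.is_floating_point()
        q = torch.round(input / scale) + zero_point
        q_clamped = torch.clamp(q, quant_min, quant_max)
        # Straight-through estimator: gradients flow only where no clamping happened.
        mask = (q >= quant_min) & (q <= quant_max)
        ctx.save_for_backward(mask)
        return (q_clamped - zero_point) * scale

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        return grad_output * mask, None, None, None, None

# bf16 inputs no longer trip an fp32-only assertion:
x = torch.randn(4, 8, dtype=torch.bfloat16, requires_grad=True)
scale = torch.tensor(0.1, dtype=torch.bfloat16)
zero_point = torch.tensor(0.0, dtype=torch.bfloat16)
y = _ToyFakeQuantize.apply(x, scale, zero_point)
y.sum().backward()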

2 files changed: +0 additions, -13 deletions

test/quantization/test_qat.py (-6 lines)

@@ -422,9 +422,6 @@ def test_qat_4w_primitives(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32")
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_linear(self):
         from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
@@ -453,9 +450,6 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
-    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_4, "assert input.dtype == torch.float32")
-    @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
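The user-facing effect, tied to the torchtune issue above, is that QAT on a bf16 model should no longer hit the fp32-only assertion. The following is a rough usage sketch, not code from this commit: the toy model, shapes, and the prepare()/convert() flow follow the torchao QAT prototype docs for the 8da4w quantizer, but exact constructor arguments and behavior may differ across torchao versions.

import torch
import torch.nn as nn
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Toy bf16 model; in_features chosen to be divisible by the quantizer's group size.
model = nn.Sequential(nn.Linear(256, 256)).to(torch.bfloat16)

quantizer = Int8DynActInt4WeightQATQuantizer()
model = quantizer.prepare(model)  # swap in fake-quantized linears for QAT

# bf16 activations no longer trip the removed dtype assertion in forward().
out = model(torch.randn(2, 256, dtype=torch.bfloat16))
out.sum().backward()

# After fine-tuning, quantizer.convert(model) produces the actual quantized model.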

torchao/quantization/prototype/qat/utils.py (-7 lines)

@@ -36,13 +36,6 @@ def forward(
         block_size: List[int],
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     ) -> torch.Tensor:
-        # Note: for bf16 inputs, casting them to fp32 has the unexpected
-        # side effect of reducing memory footprint significantly, presumably
-        # because bf16 * fp32 kernels are not as memory efficient
-        assert input.dtype == torch.float32
-        assert scales.dtype == torch.float32
-        assert zero_points.dtype == torch.int32
-
         (fq, mask) = fake_quantize_affine_cachemask(
             input,
             block_size,
