fix QAT version dependency #1333

Closed · wants to merge 4 commits
**pyproject.toml** (4 changes: 2 additions & 2 deletions)
```diff
@@ -22,12 +22,12 @@ dependencies = [
     "blobfile>=2",

     # Miscellaneous
-    "numpy<=1.26.4", # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed
+    "numpy", # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed
     "tqdm",
     "omegaconf",

     # Quantization
-    "torchao==0.3.1",
+    "torchao==0.4.0",
 ]
 dynamic = ["version"]
```
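To sanity-check an environment against the new pins, something like the following works (a sketch using only the standard library; the exact-match assert mirrors the `torchao==0.4.0` pin above):

```python
# Sketch: confirm the installed packages satisfy the updated pins.
# Uses only the standard library (importlib.metadata).
from importlib.metadata import version

assert version("torchao") == "0.4.0"  # exact pin from pyproject.toml
print("numpy:", version("numpy"))     # the <=1.26.4 cap is gone; any release works
```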
**recipes/configs/llama2/7B_qat_full.yaml** (2 changes: 1 addition & 1 deletion)
```diff
@@ -65,7 +65,7 @@ enable_activation_checkpointing: True
 memory_efficient_fsdp_wrap: False

 # Reduced precision
-dtype: bf16
+dtype: fp32

 # Logging
 metric_logger:
```
**Contributor:** This doesn't look right. I will submit a PR to remove that assertion.

**Member:** Btw how come this wasn't caught in the tune nightly CI? @joecummings

**Contributor:** @msaroufim we don't actually test with our "prod" configs; instead, we define a set of test configs that we deem to be (pretty) representative of the configs we provide. Unfortunately, to do loss-parity checks we tend to set dtype=fp32 in the tests (see here for the QAT test), so this one slipped by.
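To illustrate the testing gap described above: loss-parity tests force full precision, so whatever `dtype` the prod config ships with is never exercised. Below is a minimal sketch of that override pattern using OmegaConf (already a dependency in the pyproject above); the config keys mirror the YAML in this PR, but the harness itself is hypothetical, not torchtune's actual CI code.

```python
# Sketch of how a test override can mask a bad prod-config default
# (hypothetical harness, not torchtune's actual CI code).
from omegaconf import OmegaConf

prod_cfg = OmegaConf.create({"dtype": "bf16", "device": "cuda"})  # prod default
# Loss-parity tests pin fp32 for reproducible comparisons, so the
# prod default above is silently bypassed:
test_cfg = OmegaConf.merge(prod_cfg, OmegaConf.from_dotlist(["dtype=fp32"]))
assert test_cfg.dtype == "fp32"
```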


**recipes/configs/llama3/8B_qat_full.yaml** (4 changes: 2 additions & 2 deletions)
```diff
@@ -65,8 +65,8 @@ device: cuda
 enable_activation_checkpointing: True
 memory_efficient_fsdp_wrap: True

-# Reduced precision
-dtype: bf16
+# Precision
+dtype: fp32

 # Logging
 metric_logger:
```
**torchtune/utils/quantization.py** (14 changes: 9 additions & 5 deletions)
```diff
@@ -9,7 +9,7 @@
 # importing TORCH_VERSION_AFTER_2_3 because `Int8DynActInt4WeightQuantizer`
 # is only available after 2.3 so we have to guard the pytorch versions to decide
 # the list of supported quantizers
-from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
+from torchao.utils import TORCH_VERSION_AFTER_2_2, TORCH_VERSION_AFTER_2_3

 __all__ = [
     "get_quantizer_mode",
@@ -20,15 +20,19 @@
 _quantizer_mode_to_disable_fake_quant = {}
 _quantizer_mode_to_enable_fake_quant = {}

-
-if TORCH_VERSION_AFTER_2_3:
+# TODO: bump this after tochao releases >0.4.0
+# Until 0.4.0, this did not include the version. Eg. AFTER_2_2, does not include 2.2.
+# This should be "TORCH_VERSION_AFTER_2_3" after the fix
+# More info here: https://github.com/pytorch/ao/pull/684
+if TORCH_VERSION_AFTER_2_2:
     from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

     __all__.append("Int8DynActInt4WeightQuantizer")
     _quantizer_to_mode[Int8DynActInt4WeightQuantizer] = "8da4w"

-
-if TORCH_VERSION_AFTER_2_4:
+# TODO: see comment above
+# This should be TORCH_VERSION_AFTER_2_4
+if TORCH_VERSION_AFTER_2_3:
     from torchao.quantization.prototype.qat import (
         disable_8da4w_fake_quant,
         enable_8da4w_fake_quant,
```
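For readers puzzling over the guard change: per the new comments in the diff, torchao's version flags were exclusive before the 0.4.0 fix, so `TORCH_VERSION_AFTER_2_3` was False on torch 2.3 itself and the quantizer imports were skipped. A minimal sketch of that off-by-one, assuming `packaging`-style version comparison (the helper is illustrative, not torchao's actual implementation):

```python
# Illustrative reproduction of the version-flag off-by-one; not torchao code.
from packaging.version import parse


def torch_version_after(current: str, target: str, inclusive: bool) -> bool:
    """Decide whether `current` counts as 'after' `target`."""
    if inclusive:
        return parse(current) >= parse(target)
    return parse(current) > parse(target)


# Pre-0.4.0 semantics: "AFTER_2_2 does not include 2.2", i.e. a strict
# comparison that is False on the boundary release itself.
assert torch_version_after("2.2.0", "2.2", inclusive=False) is False
# Post-fix semantics (https://github.com/pytorch/ao/pull/684): the boundary
# release is included, so the flag behaves like ">=".
assert torch_version_after("2.3.0", "2.3", inclusive=True) is True
```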