fix QAT version dependency #1333
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1333
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 04ccbf2 with merge base 6a7951f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc @msaroufim
Assuming CI is green, this is OK to merge.
@felipemello1 Would you mind also double-checking our version guards for AO? We'll need to be extra careful around this now that we're relaxing this pin. CI will only catch our tests, which are a subset of how our library is used.
Not sure what you mean. Can you add an example or a link showing what you would like me to do?
This change allows stable builds of torchtune to break in the future. If there is a stable package of torchtune that works fine with torchao, and torchao then releases a new stable package with BC-breaking changes, our existing stable packages would try to install the new torchao package and break. We need to keep torchao pinned and then use a tool like dependabot to keep the pinned version up to date. For CI, we should decide if we want to pin to the latest version of PyTorch and possibly have separate tests for PyTorch nightlies with unpinned PyTorch libraries. @ebsmothers
TLDR: just update to 0.4. In tune CI you should always be testing all your latest stable dependencies and all your latest nightly dependencies. We should never be catching BC issues at release time but at nightly CI time; that way upgrading a stable release can be a safe activity. Personally I wouldn't wait more than a few days after an official AO release to make an upgrade.
@@ -65,7 +65,7 @@ enable_activation_checkpointing: True
 memory_efficient_fsdp_wrap: False

 # Reduced precision
-dtype: bf16
+dtype: fp32
cc @andrewor14
This doesn't look right. I will submit a PR to remove that assertion
Btw how come this wasn't caught in the tune nightly CI? @joecummings
@msaroufim we don't actually test with our "prod" configs; instead we define a set of test configs that we deem to be (pretty) representative of the configs we provide. Unfortunately, to do loss parity checks we tend to set dtype=fp32 in the tests (see here for the QAT test), so this one slipped by.
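For illustration, a minimal sketch of the kind of fp32 loss-parity check described above, assuming a hypothetical `run_recipe()` helper that returns per-step losses (the expected values are placeholders, not torchtune's actual test):

```python
# Minimal sketch of an fp32 loss-parity check. run_recipe() and the expected
# values are hypothetical placeholders, not torchtune's actual test code.
import torch

def test_qat_loss_parity(run_recipe):
    # Run in full precision so bf16 rounding doesn't mask (or fake) regressions.
    observed = torch.tensor(run_recipe(dtype="fp32", max_steps=4))
    expected = torch.tensor([10.52, 10.50, 10.47, 10.45])  # placeholder values
    torch.testing.assert_close(observed, expected, atol=1e-4, rtol=1e-4)
```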
The version upgrade looks fine to me; probably want Andrew to also review the reduced precision change you have, since I'm not sure if that has perf implications.
@felipemello1 How urgent is upgrading torchao? I will submit a fix in torchao itself to remove that assertion, but I don't think we want to change the default precision in the QAT recipes. We have another release planned in early September, but if that's too late for you maybe we can do a 0.4.1 release with the fix?
This was added originally for perf reasons specific to 8da4w, but the autograd.Function has since been adapted for more general use. A few users are hitting this assertion error. More context: pytorch/torchtune#1333
Thanks for the fix @andrewor14! If making a release is not a huge effort, this would solve multiple problems: our regression error, the import version, and the dtype. It would be convenient. However, I don't think we have a huge number of users using QAT, so waiting for September wouldn't be terrible. In summary, if making the release is easy, that would be very neat. But if it's going to take you a considerable amount of time, we can wait 2 weeks.
Is this on hold until the next torchao release then? And if so, are we gonna just bump to 0.5.0? If so, let's make sure that our nightly CI is green before that release.
That's my understanding.
Context
What is the purpose of this PR?
We updated torchtune to use torchao 0.4.0. It breaks unless the user has PyTorch 2.4.0. In our scripts, we were using import guards:
https://github.com/felipemello1/torchtune/blob/04ccbf2601653e0e2cceb75e59394df5517d26e3/torchtune/utils/quantization.py#L12
However, "TORCH_VERSION_AFTER_2_4" actually didn't include 2.4. This was fixed in torchao here: pytorch/ao#684, but it won't be available to us until their next release.
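For illustration, a minimal sketch of an inclusive version guard of the sort discussed above (the helper name is hypothetical, not torchtune's actual utility):

```python
# Minimal sketch of an inclusive version check; the helper name is hypothetical
# and not torchtune's actual API.
import torch
from packaging.version import Version

def torch_version_at_least(min_version: str) -> bool:
    """True if the installed torch is >= min_version, inclusive of min_version."""
    # Drop any local suffix such as "+cu121" before comparing.
    return Version(torch.__version__.split("+")[0]) >= Version(min_version)

# Unlike the exclusive TORCH_VERSION_AFTER_2_4 check, this passes on 2.4.0 itself.
_SUPPORTS_QAT = torch_version_at_least("2.4.0")
```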
After updating TorchAO and the import guards, another error was raised:
This is because the QAT recipe now requires the model to be in float32. More context here: https://github.com/pytorch/ao/blob/0b66ff01ab6ba4094823b8cb134ab5b5a744d73a/torchao/quantization/prototype/qat/utils.py#L39
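For context, a hedged illustration of the kind of dtype assertion involved; this is not the actual torchao source, and the function name is made up:

```python
# Hedged illustration only: NOT the actual torchao source. A check of roughly
# this shape fails when the QAT recipe runs a bf16 model, hence the fp32 change.
import torch

def _fake_quantize_input_check(x: torch.Tensor) -> torch.Tensor:
    assert x.dtype == torch.float32, (
        f"QAT fake-quantize expected a float32 tensor, got {x.dtype}"
    )
    return x

# Passes for fp32; a bf16 tensor would raise an AssertionError.
_fake_quantize_input_check(torch.randn(4, 4, dtype=torch.float32))
```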
Changing the QAT recipes to use dtype: fp32 solved it.
Changelog
Test plan
I was able to run the code below, but I did not try to compare against the previous version.