fix QAT version dependency #1333

Closed · wants to merge 4 commits

Conversation

@felipemello1 (Contributor) commented Aug 14, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

We updated torchtune to use torchao 0.4.0, which breaks unless the user has PyTorch 2.4.0. In our scripts, we were using import guards:

https://github.com/felipemello1/torchtune/blob/04ccbf2601653e0e2cceb75e59394df5517d26e3/torchtune/utils/quantization.py#L12

However, "TORCH_VERSION_AFTER_2_4" actually didnt include 2_4. This was fixed in torchao here: pytorch/ao#684, but it wont be available to us until their next release.

After updating TorchAO and the import guards, another error was raised:

[rank0]:   File "/home/felipemello/.conda/envs/test_ao/lib/python3.10/site-packages/torchao/quantization/prototype/qat/utils.py", line 42, in forward
[rank0]:     assert input.dtype == torch.float32

This is because QAT recipe now requires the model to be in float32. More context here: https://github.com/pytorch/ao/blob/0b66ff01ab6ba4094823b8cb134ab5b5a744d73a/torchao/quantization/prototype/qat/utils.py#L39

Changing the QAT recipes to use dtype = fp32 solved it.

Changelog

  • Update torchao to 0.4.0
  • Remove the NumPy pin (unrelated to this PR, but it was something we needed to do, so it made sense to test everything together; see Unpin Numpy #1344)
  • Temporarily change the import guards. They MUST be updated with the next torchao release. (Should I add an assertion that checks torchao.__version__ <= 0.4.0? See the sketch after this list.)
  • QAT configs use fp32
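
A hypothetical sketch of the version assertion floated in the import-guard item above; the threshold, placement, and error message are all assumptions, and it presumes torchao exposes __version__ as referenced in that item:

```python
# Hypothetical guard: fail loudly if torchao moves past 0.4.x before the
# temporary import guards are revisited with the next torchao release.
import torchao
from packaging.version import Version

assert Version(torchao.__version__) < Version("0.5.0"), (
    "torchtune's import guards were written against torchao 0.4.x; "
    "update them before bumping torchao."
)
```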

Test plan

I was able to run the command below, but I did not compare results against the previous version.

tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full 

pytorch-bot (bot) commented Aug 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1333

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 04ccbf2 with merge base 6a7951f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Aug 14, 2024
@joecummings (Contributor)

cc @msaroufim

@msaroufim self-requested a review August 14, 2024 15:42
@msaroufim (Member) previously approved these changes Aug 14, 2024 and left a comment:

Assuming CI is green this is OK to merge

@joecummings (Contributor)

@felipemello1 Would you mind just also double checking our version guards for AO? We'll need to be extra careful around this now that we're relaxing this pin.

CI will only catch our tests, which are a subset of how our library is used.

@felipemello1 (Contributor, Author)

Would you mind just also double checking our version guards for AO?

I'm not sure what you mean. Could you add an example or a link showing what you would like me to do?

@pbontrager (Contributor) commented Aug 14, 2024

This change allows stable builds of torchtune to break in the future. If there is a stable torchtune package that works fine with torchao, and torchao then releases a new stable package with BC-breaking changes, our existing stable packages would try to install the new torchao package and break. We need to keep torchao pinned and then use a tool like dependabot to keep the pinned version up to date.

For CI, we should decide if we want to pin to the latest version of PyTorch and possibly have separate tests for PyTorch nightlies with unpinned PyTorch libraries. @ebsmothers
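
As a rough illustration of the pinning approach described above, a hypothetical setup.py-style sketch; torchtune's actual packaging metadata may live elsewhere and use different bounds:

```python
# Hypothetical packaging sketch: an exact pin means a stable torchtune release
# always resolves the torchao version it was tested against, even after torchao
# publishes a newer (possibly BC-breaking) release.
from setuptools import setup

setup(
    name="torchtune",
    install_requires=[
        "torch>=2.4.0",    # assumed floor matching the torchao 0.4.0 requirement
        "torchao==0.4.0",  # exact pin, bumped (e.g., via dependabot) on each torchao release
    ],
)
```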

@msaroufim dismissed their stale review August 14, 2024 18:34

still discussing

@msaroufim (Member) commented Aug 14, 2024

If there is a stable package for torchtune that works fine with torchao, and then torchao releases a new stable package with bc breaking changes, our existing stable packages

TLDR: just update to 0.4

When users pip install torchtune, the ao version should be pinned so the official release packages are guaranteed to work. If users then choose to upgrade AO, there is no guarantee things will work (same as with PyTorch), but we'll try our best not to break things without a good reason.

In tune CI you should always be testing against all your latest stable dependencies and all your latest nightly dependencies. We should never be catching BC issues at release time but at nightly CI time; that way, upgrading a stable release can be a safe activity. Personally, I wouldn't wait more than a few days after an official AO release to upgrade.

@felipemello1 changed the title from "remove torchao version pin" to "fix QAT version dependency" Aug 16, 2024
@felipemello1 requested a review from msaroufim August 16, 2024 15:04
@@ -65,7 +65,7 @@ enable_activation_checkpointing: True
 memory_efficient_fsdp_wrap: False
 
 # Reduced precision
-dtype: bf16
+dtype: fp32

Contributor

This doesn't look right. I will submit a PR to remove that assertion

Member

Btw how come this wasn't caught in the tune nightly CI? @joecummings

Contributor

@msaroufim we don't actually test with our "prod" configs; instead, we define a set of test configs that we deem to be (pretty) representative of the configs we provide. Unfortunately, to do loss-parity checks we tend to set dtype=fp32 in the tests (see here for the QAT test), so this one slipped by.

@msaroufim (Member) left a comment:

The version upgrade looks fine to me. You probably want Andrew to also review the reduced-precision change you have, since I'm not sure whether that has perf implications.

@msaroufim self-requested a review August 16, 2024 17:21
@msaroufim previously approved these changes Aug 16, 2024
@andrewor14 (Contributor)

@felipemello1 How urgent is upgrading torchao? I will submit a fix in torchao itself to remove that assertion, but I don't think we want to change the default precision in the QAT recipes. We have another release planned in early September, but if that's too late for you, maybe we can do a 0.4.1 release with the fix?

andrewor14 added a commit to pytorch/ao that referenced this pull request Aug 16, 2024
This was added originally for perf reasons specific to 8da4w,
but the autograd.Function has since been adapted for more general
use. A few users are hitting this assertion error.

More context: pytorch/torchtune#1333
@andrewor14 (Contributor)

pytorch/ao#692

@msaroufim self-requested a review August 16, 2024 17:36
@felipemello1 (Contributor, Author)

We have another release planned in early September, but if that's too late for you, maybe we can do a 0.4.1 release with the fix?

Thanks for the fix @andrewor14!

If making a release is not a huge effort, it would solve multiple problems: our regression error, the import version, and the dtype. It would be convenient. However, I don't think we have a huge number of users using QAT, so waiting for September wouldn't be terrible.

In summary, if making the release is easy, that would be very neat. But if it's going to take you a considerable amount of time, we can wait two weeks.

@ebsmothers (Contributor)

Is this on hold until the next torchao release, then? And if so, are we just going to bump to 0.5.0? If so, let's make sure our nightly CI is green before that release.

@felipemello1 (Contributor, Author)

Is this on hold until next torchao release then?

That's my understanding.
