replace adamW and pagedadam with 8bitpagedadam or torchao CPUOffloadOptimizer #1576

Closed
felipemello1 opened this issue Sep 13, 2024 · 16 comments

Comments

@felipemello1
Contributor

felipemello1 commented Sep 13, 2024

Apparently there is no reason to use paged adam instead of the 8bit version. We could replace it.

Also, the single-device full finetune recipe should use PagedAdamW instead of AdamW for better memory usage.

For single device, we have torchao's version, which is faster than the one from bitsandbytes: https://github.com/pytorch/ao/blob/8236a874479a9a9168e584c81dda8707f4c41006/torchao/prototype/low_bit_optim/cpu_offload.py#L9
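
For reference, here is a rough sketch of what wiring that up could look like. The module path follows the prototype file linked above, but the exact import location and signature are assumptions, so check the torchao README for the version you install:

import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer  # prototype API; may move or be renamed

model = torch.nn.Linear(4096, 4096, device="cuda")  # stand-in for the real finetuned model

# Keeps optimizer state (and optionally gradients) in CPU memory and runs the
# optimizer step on CPU, trading some speed for a large cut in VRAM usage.
optimizer = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,        # base optimizer being wrapped
    offload_gradients=True,   # also move gradients to CPU as they are produced
    fused=True,               # forwarded to AdamW for a faster CPU step
    lr=2e-5,                  # illustrative value, not a torchtune default
)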

@felipemello1 felipemello1 changed the title replaces pagedadam with 8bitpagedadam or torchao CPUOffloadOptimizer replace pagedadam with 8bitpagedadam or torchao CPUOffloadOptimizer Sep 13, 2024
@SalmanMohammadi
Collaborator

These are just for our full finetune low-memory configs, right? I almost wonder if we should re-benchmark this recipe with all the new memory optimizations that have been coming in.

@NeuralFlux

Apparently there is no reason to use paged adam instead of the 8bit version. We should replace it.

@felipemello1 can you please cite the source? I'm deciding between optimizers too.

@felipemello1
Contributor Author

felipemello1 commented Sep 13, 2024

@NeuralFlux I don't have it :/ But what I heard from some coworkers is that they didn't observe a change in the loss. Are you doing full finetuning? If so, you need PagedAdamW/8-bit to save memory.

But if you are using LoRA, you don't need PagedAdamW, since the gradients are not your bottleneck. You can just use AdamW with fused=True.
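
For concreteness, a minimal sketch of the two choices in plain PyTorch/bitsandbytes (the model and learning rates are placeholders, not torchtune defaults):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096, device="cuda")  # stand-in for the finetuned model

# Full finetune: 8-bit paged AdamW keeps optimizer state small and pages it to CPU under memory pressure.
full_ft_optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=2e-5)

# LoRA/QLoRA: the trainable parameters are tiny, so plain AdamW is fine.
# fused=True fuses the parameter updates into far fewer CUDA kernels: faster, but it saves no memory.
lora_optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)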

@NeuralFlux

No worries! I'm doing QLoRA but keep running into OOM because of large sequence lengths. I'm using PagedAdamW to save however much memory I can. I noticed torchtune is not compatible with torchao==0.5.0 yet, which has the CPU offloader. And what does fused=True do? I read the AdamW docs but could not make sense of it.

@felipemello1
Contributor Author

felipemello1 commented Sep 13, 2024

How big is the sequence length? Also, are you using torchtune nightlies? If you aren't, please try this:

pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 

And run your model with compile=True and enable_activation_checkpointing=True.

You should see a huge difference in memory usage and tokens per second.

@felipemello1
Contributor Author

felipemello1 commented Sep 13, 2024

Make sure your config is using the chunked cross-entropy loss (the loss= entry), like we have in our default configs.

@NeuralFlux

Sure! Also, I noticed the QLoRA config sets dtype to bfloat16. Is this the compute dtype or the storage dtype?

I tried installing, but pip still tells me:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchtune 0.2.1 requires torchao==0.3.1, but you have torchao 0.5.0 which is incompatible.

@felipemello1
Contributor Author

felipemello1 commented Sep 13, 2024

hmm, maybe try a fresh environment?

conda create -n your_env_name python=3.10
conda activate your_env_name
pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121
pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121

Is this compute dtype or storage dtype?

Compute is done in bf16. The quantized layers that are not being trained are stored in NF4, which takes an 8B Llama from ~15 GiB down to ~5 GiB.
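
Back-of-the-envelope math for those numbers (illustrative only; it ignores NF4's per-block scales and the layers that stay in bf16, which is why the real figure lands closer to 5 GiB):

params = 8e9                      # 8B-parameter model
bf16_gib = params * 2 / 2**30     # 2 bytes per parameter  -> ~14.9 GiB
nf4_gib = params * 0.5 / 2**30    # ~4 bits per parameter  -> ~3.7 GiB
print(f"bf16: {bf16_gib:.1f} GiB, NF4: {nf4_gib:.1f} GiB")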

@NeuralFlux

Gotcha, I will try that soon (weekend's about to start here haha). Have a good weekend!

@NeuralFlux

Hi @felipemello1. I launched a job over the weekend and it worked. Setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True and using bitsandbytes.optim.PagedAdamW did the job. I think the former was the main fix, because the memory reserved by PyTorch was in the GBs. Thanks a lot for your input!
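
For anyone landing here later, a sketch of that two-part fix (values are illustrative; the allocator flag has to be set before CUDA is initialized, so in practice it is usually exported in the shell before launching the run):

import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")  # reduces allocator fragmentation

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096, device="cuda")  # stand-in for the finetuned model
optimizer = bnb.optim.PagedAdamW(model.parameters(), lr=2e-5)  # pages optimizer state to CPU when VRAM runs low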

@felipemello1 felipemello1 changed the title replace pagedadam with 8bitpagedadam or torchao CPUOffloadOptimizer replace adamW and pagedadam with 8bitpagedadam or torchao CPUOffloadOptimizer Sep 23, 2024
@FurkanGozukara

PagedAdamW

Do you know the difference between PagedAdEMAMix8bit and bitsandbytes.optim.AdEMAMix8bit?

@ebsmothers
Contributor

Hi @FurkanGozukara I'm not familiar with PagedAdEMAMix8bit, can you share a link?

@FurkanGozukara

Hi @FurkanGozukara I'm not familiar with PagedAdEMAMix8bit, can you share a link?

kohya-ss/sd-scripts#1640

@ebsmothers
Contributor

@FurkanGozukara thanks for sharing the link. Personally I'm not too familiar with these optimizers as we mostly use ones that are available in bitsandbytes. Maybe best to ask on the sd-scripts repo directly?

@FurkanGozukara

@FurkanGozukara thanks for sharing the link. Personally I'm not too familiar with these optimizers as we mostly use ones that are available in bitsandbytes. Maybe best to ask on the sd-scripts repo directly?

Thanks. The paged optimizers seem to automatically use the entire VRAM without throwing out-of-VRAM errors. Really good :)

@ebsmothers
Contributor

Going to close this issue as it's more of a question for sd-scripts. @FurkanGozukara, feel free to reopen if you need any more assistance with optimizer usage in torchtune.
