
Commit 1334166

Authored by sayakpaul, CaptnSeraph, patrickvonplaten, and bghira; committed by Jimmy.
add: train to text image with sdxl script. (huggingface#4505)
* add: train to text image with sdxl script. (Co-authored-by: CaptnSeraph <[email protected]>)
* fix: partial func.
* fix: default value of output_dir.
* make style
* set num inference steps to 25.
* remove mentions of LoRA.
* up min version
* add: ema cli arg
* run device placement while running step.
* precompute vae encodings too.
* fix
* debug
* should work now.
* debug
* debug
* goes alright?
* style
* debugging
* debugging
* debugging
* debugging
* fix
* reinit scheduler if prediction_type was passed.
* always cast vae in float32
* better handling of snr. (Co-authored-by: bghira <[email protected]>)
* the vae should be also passed
* add: docs.
* add: sdxl t2i tests
* save the pipeline
* autocast.
* fix: save_model_card
* fix: save_model_card.

Co-authored-by: CaptnSeraph <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: bghira <[email protected]>
1 parent 35ec86c commit 1334166

File tree

7 files changed: +1298 −15 lines

docs/source/en/training/dreambooth.md

+1 −1
@@ -707,4 +707,4 @@ accelerate launch train_dreambooth.py \

## Stable Diffusion XL

-We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md).
+We support fine-tuning of the UNet and text encoders shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md).

docs/source/en/training/text2image.md

+6
@@ -275,3 +275,9 @@ image.save("yoda-pokemon.png")
```
</jax>
</frameworkcontent>

## Stable Diffusion XL

* We support fine-tuning the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) via the `train_text_to_image_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
* We also support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).

examples/test_examples.py

+24
@@ -757,6 +757,30 @@ def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_ch
            {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
        )

    def test_text_to_image_sdxl(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/text_to_image/train_text_to_image_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --center_crop
                --random_flip
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

    def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
        pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
        prompt = "a prompt"

examples/text_to_image/README.md

+2 −1
@@ -319,4 +319,5 @@ According to [this issue](https://github.com/huggingface/diffusers/issues/2234#i

## Stable Diffusion XL

-We support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_xl.py` script. Please refer to the docs [here](./README_sdxl.md).
+* We support fine-tuning the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) via the `train_text_to_image_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
+* We also support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).

examples/text_to_image/README_sdxl.md

+66 −12
@@ -1,17 +1,8 @@
-# LoRA training example for Stable Diffusion XL (SDXL)
+# Stable Diffusion XL text-to-image fine-tuning

-Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+The `train_text_to_image_sdxl.py` script shows how to fine-tune Stable Diffusion XL (SDXL) on your own dataset.

-In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
-
-- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
-- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable.
-- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter.
-
-[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
-
-With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset
-on consumer GPUs like Tesla T4, Tesla V100.
+🚨 This script is experimental. It fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best results on your dataset. 🚨

## Running locally with PyTorch

@@ -57,6 +48,69 @@ When running `accelerate config`, if we specify torch compile mode to True there

### Training

```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --pretrained_vae_model_name_or_path=$VAE \
  --dataset_name=$DATASET_NAME \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --center_crop --random_flip \
  --proportion_empty_prompts=0.2 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=10000 \
  --use_8bit_adam \
  --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --mixed_precision="fp16" \
  --report_to="wandb" \
  --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
  --checkpointing_steps=5000 \
  --output_dir="sdxl-pokemon-model" \
  --push_to_hub
```

**Notes**:

* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. For smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions) this might not be a problem, but it can definitely lead to memory problems when the script is used on a larger dataset. In that case, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process (see the sketch after this list). Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
* The training script is compute-intensive and may not run on a consumer GPU like the Tesla T4.
* The training command shown above performs intermediate quality validation in between training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.

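As a rough illustration of the serialize-then-load idea from the first note, here is a minimal editor's sketch (not part of this commit). It assumes the `madebyollin/sdxl-vae-fp16-fix` VAE and the Pokemon dataset used above; the `latents/` directory and per-example file layout are hypothetical choices, not what the script does.

```python
# Sketch: pre-compute VAE encodings once, write them to disk, and load them
# lazily during fine-tuning instead of keeping every latent in memory.
# The VAE and dataset match the training command above; `latents/` is made up.
import os

import torch
from datasets import load_dataset
from torchvision import transforms
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix").to("cuda")
vae.requires_grad_(False)

preprocess = transforms.Compose(
    [
        transforms.Resize(512),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

dataset = load_dataset("lambdalabs/pokemon-blip-captions", split="train")
os.makedirs("latents", exist_ok=True)

with torch.no_grad():
    for i, example in enumerate(dataset):
        pixels = preprocess(example["image"].convert("RGB")).unsqueeze(0).to("cuda")
        # Sample from the VAE posterior and apply the scaling used during training.
        latents = vae.encode(pixels).latent_dist.sample() * vae.config.scaling_factor
        torch.save(latents.squeeze(0).cpu(), f"latents/{i}.pt")

# A training Dataset would then torch.load(f"latents/{i}.pt") in __getitem__
# instead of re-encoding images, keeping peak memory roughly constant.
```
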
### Inference

```python
from diffusers import DiffusionPipeline
import torch

model_path = "your-model-id-goes-here"  # <-- change this
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

prompt = "A pokemon with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```

## LoRA training example for Stable Diffusion XL (SDXL)

Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:

- The previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers make it possible to control the extent to which the model is adapted toward new training images via a `scale` parameter.

A minimal sketch of this rank-decomposition idea is shown below.

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.

With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption pair dataset on consumer GPUs like the Tesla T4 or Tesla V100.

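To make the rank-decomposition idea concrete, here is a minimal, self-contained editor's sketch of a LoRA-wrapped linear layer. It illustrates the general technique only; the hypothetical `LoRALinear` class below is not how the script adapts SDXL's attention layers.

```python
# Sketch of LoRA: freeze the pretrained weight W and train only the low-rank
# factors A (down) and B (up); `scale` controls how strongly the adapter acts.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # pretrained weights stay frozen
        self.down = nn.Linear(base.in_features, rank, bias=False)  # A
        self.up = nn.Linear(rank, base.out_features, bias=False)   # B
        nn.init.normal_(self.down.weight, std=1.0 / rank)
        nn.init.zeros_(self.up.weight)  # adapter starts as an exact no-op
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + scale * B(A x); only A and B receive gradients.
        return self.base(x) + self.scale * self.up(self.down(x))

layer = LoRALinear(nn.Linear(768, 768), rank=4)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 6144, about 1% of the 589,824 parameters in the frozen base
```

Only the small `A`/`B` pairs need to be saved, which is why LoRA checkpoints stay portable.
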
### Training
First, you need to set up your development environment as explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).


**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**

examples/text_to_image/train_text_to_image_lora_sdxl.py

+1 −1
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
+"""Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""

import argparse
import itertools
