Skip to content

Commit 572d8e2

Browse files
Adding better way to define multiple concepts and also validation capabilities. (#3807)
* - Added validation parameters - Changed some parameter descriptions to better explain their use. - Fixed a few typos. - Added concept_list parameter for better management of multiple subjects - changed logic for image validation * - Fixed bad logic for class data root directories * Defaulting validation_steps to None for an easier logic * Fixed multiple validation prompts * Fixed bug on validation negative prompt * Changed validation logic for tracker. * Added uuid for validation image labeling * Fix error when comparing validation prompts and validation negative prompts * Improved error message when negative prompts for validation are more than the number of prompts * - Changed image tracking number from epoch to global_step - Added Typing for functions * Added some validations more when using concept_list parameter and the regular ones. * Fixed error message * Added more validations for validation parameters * Improved messaging for errors * Fixed validation error for parameters with default values * - Added train step to image name for validation - reformatted code * - Added train step to image's name for validation - reformatted code * Updated README.md file. * reverted back original script of train_dreambooth.py * reverted back original script of train_dreambooth.py * left one blank line at the eof * reverted back setup.py * reverted back setup.py * added same logic for when parameters for prior preservation are used without enabling the flag while using concept_list parameter. * Ran black formatter. * fixed a few strings * fixed import sort with isort and removed fstrings without placeholder * fixed import order with ruff (since with isort wasn't ok) --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2e8668f commit 572d8e2

File tree

2 files changed

+404
-53
lines changed

2 files changed

+404
-53
lines changed

examples/research_projects/multi_subject_dreambooth/README.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,53 @@ This example shows training for 2 subjects, but please note that the model can b
8686

8787
Note also that in this script, `sks` and `t@y` were used as tokens to learn the new subjects ([this thread](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/issues/71) inspired the use of `t@y` as our second identifier). However, there may be better rare tokens to experiment with, and results also seemed to be good when more intuitive words are used.
8888

89+
**Important**: New parameters are added to the script, making possible to validate the progress of the training by
90+
generating images at specified steps. Taking also into account that a comma separated list in a text field for a prompt
91+
it's never a good idea (simply because it is very common in prompts to have them as part of a regular text) we
92+
introduce the `concept_list` parameter: allowing to specify a json-like file where you can define the different
93+
configuration for each subject that you want to train.
94+
95+
An example of how to generate the file:
96+
```python
97+
import json
98+
99+
# here we are using parameters for prior-preservation and validation as well.
100+
concepts_list = [
101+
{
102+
"instance_prompt": "drawing of a t@y meme",
103+
"class_prompt": "drawing of a meme",
104+
"instance_data_dir": "/some_folder/meme_toy",
105+
"class_data_dir": "/data/meme",
106+
"validation_prompt": "drawing of a t@y meme about football in Uruguay",
107+
"validation_negative_prompt": "black and white"
108+
},
109+
{
110+
"instance_prompt": "drawing of a sks sir",
111+
"class_prompt": "drawing of a sir",
112+
"instance_data_dir": "/some_other_folder/sir_sks",
113+
"class_data_dir": "/data/sir",
114+
"validation_prompt": "drawing of a sks sir with the Uruguayan sun in his chest",
115+
"validation_negative_prompt": "an old man",
116+
"validation_guidance_scale": 20,
117+
"validation_number_images": 3,
118+
"validation_inference_steps": 10
119+
}
120+
]
121+
122+
with open("concepts_list.json", "w") as f:
123+
json.dump(concepts_list, f, indent=4)
124+
```
125+
And then just point to the file when executing the script:
126+
127+
```bash
128+
# exports...
129+
accelerate launch train_multi_subject_dreambooth.py \
130+
# more parameters...
131+
--concepts_list="concepts_list.json"
132+
```
133+
134+
You can use the helper from the script to get a better sense of each parameter.
135+
89136
### Inference
90137

91138
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.

0 commit comments

Comments
 (0)