-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add UniDiffuser model and pipeline #2963
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The documentation is not available anymore as the PR was closed or merged. |
Currently, the code in the PR isn't in a working state, and I haven't implemented tests or tested the code yet. I've opened the PR because I wanted to get some preliminary feedback on the design and code. In particular, I have the following questions: Design Questions:
Would it be better if I split
Questions about Tests:
[if there is a better place to move this discussion, please let me know :) ] |
def __init__( | ||
self, | ||
tokenizer: GPT2Tokenizer, | ||
text_decoder: GPT2LMHeadModel, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we try to seperate the tokenizer and text decoder here.
diffusers
should be able to load the tokenizer out of the box, you just have to define it in the pipeline, e.g. here: https://github.com/huggingface/diffusers/pull/2963/files#r1159526676
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also we cannot pass the text_decoder
here at init as this would prevent us to be able to use from_pretained(...)
of the model class. Could you maybe try to follow the design as done here:
class SpectrogramContEncoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): |
See how we import blocks from transformers
to design our own new model. I think here you could just do the following:
def __init__(
self,
num_layers=12,
...
):
config = GPT2Config(...take all the config params from init)
self.text_decoder = GPT2LMHeadModel(config)
We then design a new checkpoint architecture for the UniDiffusersTextDecoder
and upload pretrained weights for it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed the design of __init__
following the example: removed tokenizer
and text_decoder
args, added GPT2 config args.
eos = "<|EOS|>" | ||
special_tokens_dict = {"eos_token": eos} | ||
self.tokenizer = tokenizer | ||
self.tokenizer.add_special_tokens(special_tokens_dict) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that we can do this directly for the uploaded tokenizer. E.g. let's just upload a tokenizer that has EOS already added so that we don't have to do it every time we call the model at init
More than happy to help here later on!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed the tokenizer logic from __init__
, will work on uploading the appropriate tokenizer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've prepared some native diffusers checkpoints for the current implementation of the UniDiffuserPipeline
and its building blocks (e.g. UniDiffuserModel
, UniDiffuserTextDecoder
, etc.) [see the convert_to_ckpt.py
script]. How can I upload these up to the hub?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was able to upload some models to the hub (see e.g. small test models here), but I'm confused about how to save/push to hub a tokenizer with added special tokens. The documentation for PreTrainedTokenizerBase.from_pretrained
says that it won't save modifications to the tokenizer after initialization and I wasn't able to find any resources on how to do it after searching.
For reference, the code in the base unidiffuser library is something like
eos = '<|EOS|>'
special_tokens_dict = {'eos_token': eos}
base_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
base_tokenizer.add_special_tokens(special_tokens_dict)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding uploading the weights and the new tokenizer, you can can call push_to_hub()
directly on the model.
So, for example (considering UniDiffuserModel
is already populated with the pre-trained checkpoints):
unidiffusers = UniDiffuserModel(...)
unidiffusers.push_to_hub("your_hub_user_name/model_id")
Same applies for the rest of the models and the tokenizer.
self.transformer = text_decoder | ||
# TODO: need to set the eos_token_id correctly | ||
self.transformer.config.eos_token_id = self.tokenizer.eos_token_id | ||
self.transformer.resize_token_embeddings(len(self.tokenizer)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can also make sure that the GPT2Transformer has the correct number of word embeddings before loading it so that we don't have to always resize the embedding every time at init
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1. I would prefer to have the decoder with the rejigged embeddings on the Hub rather than rejigging on the fly.
return generated_captions | ||
|
||
@torch.no_grad() | ||
def generate_beam( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
works for me!
""" | ||
|
||
@register_to_config | ||
def __init__( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The design here looks good to me! Note that I think we can remove some redundant code that is not needed for this use case. I think you only need one of the three cases:
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should have removed most of the redundant code (kept only the code handling the patch input case, since that's what the original UniDiffuser implementation used).
Great first design! I left some comments directly in the code. In short I think the general design is very nice - the models should be defined under the pipeline folder just like you do and the pipeline also looks quite nice already. Answering your questions in line
I think since the purpose of UniDiffusers is exactly to bring all modes into the same distribution, one pipeline is nice here. So this design works for me. I'd maybe just not have a "mode" call input, but instead automatically decide the mode depending on what the user puts in. E.g. if the user just passes a "text" input, we're in text2img mode, if just a "image" input, we're in image to text mode => would this design work or are the inputs not enough to define which mode one is in? E.g. are muiltple modes possible for the same input combination?
Left comments mostly directly in the code. In short:
Not really. Some guides that could help:
Regarding tiny models, yeah we just create them ourselves. What you can do here is to just load tiny configs to create random tiny models and use those for faster testing :-)
Hope this helps a bit so that you can move forward, let me know if you need more help! |
Thanks for the review! With regards to this:
for the currently supported modes, there is some ambiguity when neither text nor image input is provided. In this case, we cannot be sure whether the user wants unconditional ("marginal") image generation, unconditional ("marginal") text generation, or joint image-text generation. The original code additionally supports image variation ("img2text2img") and text variation ("text2img2text") modes, whose inputs would be the same as the image-to-text (a conditioning image) and text-to-image (a conditioning prompt) modes, respectively. So supporting these modes would also cause some ambiguity. So perhaps we could infer the [Just as a side note, the image variation implementation is different between [Edit: pushed new commit with possible implementation as described above] |
I have referenced some codes of yours and combined with mine, and also submited an initial version PR PaddlePaddle/PaddleNLP#5487 , hope to learn from each other and contribute to the community |
Hi @patrickvonplaten and @baofff, In looking at the noise prediction model architecture, I'm using
In light of this, I have the following questions:
|
As a note, if you want to look at the code I used to calculate the |
Very cool! This looks like almost ready to be merged to me - thanks a lot for re-iterating on the design :-) |
@williamberman @sayakpaul when you have a moment, it'd be super cool if you could review |
stop_token: str = "<|EOS|>", | ||
): | ||
""" | ||
Generates text using the given tokenizer and text prompt or token embedding via beam search. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Help me understand this a bit. Why would there be a need to generate text from a given text prompt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I think I wrote this docstring in a confusing way. In the context of UniDiffuser sampling, we use this function to generate output text (when appropriate) from the text latents after we process the CLIP-embedded input prompt using the unet
(UniDiffuserModel
) model. The method accepts both prompt
and embed
arguments, for input tokens and embeddings respectively, but we only ever call it with input embeddings (as described above):
generated_captions.append(self.generate_beam(tokenizer, embed=feature, device=device)[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh okay. Yeah, then I guess we need to it make it a bit clearer from the code?
labels (`torch.Tensor`, *optional*): | ||
TODO | ||
""" | ||
embedding_text = self.transformer.transformer.wte(tokens) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why transformer.transformer
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took the forward(...)
method from the original code. Upon reviewing the code, I think the forward
method was probably intended to do the following: the tokens
argument is a sequence of input vocab token IDs for the GPT2LMHeadModel
, while the prefix
argument is the hidden state of another model (e.g. something like transformers.modeling_outputs.BaseModelOutputWithPooling.last_hidden_state
of a CLIPTextModel
). prefix
then gets converted to an intermediate representation via self.encode_prefix(...)
and then converted into the latent space of the GPT model via self.decode_prefix(...)
(if they are being used). We then combine the embedding of tokens
with the prefix
embedding and then do a forward pass of the internal GPT2LMHeadModel
.
I guess it's confusing currently because on lines 52-54 instead of using n_embd
as the input dimension to nn.Linear
we should instead have a new argument prefix_inner_dim
and use that, e.g.
self.encode_prefix = (
nn.Linear(prefix_inner_dim, self.prefix_hidden_dim) if self.prefix_hidden_dim is not None else nn.Identity()
)
Furthermore, prefix_hidden_dim
should probably always need to be supplied, since prefix_inner_dim
and n_embd
are in general not guaranteed to be the same.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for explaining!
I guess we will know better once we start testing the code.
Design looks quite clean and matured to me. My questions / comments are very minor ones and are probably already covered by @patrickvonplaten's comments.
+1 to this. |
I've uploaded a I've also opened a PR at |
This is great! Thanks so much for your efforts. I think now the TODOs are:
I have also merged your PR. So, hopefully, this unblocks you. @patrickvonplaten can help us with the repo transfers. |
Let me know once you need help with a model transfer |
* Fix a bug of pano when not doing CFG * enhance code quality * apply formatting. --------- Co-authored-by: Sayak Paul <[email protected]>
* fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward * fix tensor loading in test_text_to_video_zero.py * make style && make quality
* fix: norm group test for UNet3D. * chore: speed up the panorama tests (fast). * set default value of _test_inference_batch_single_identical. * fix: batch_sizes default value.
| Pipeline | Tasks | Demo | ||
|---|---|:---:| | ||
| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*, *Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| Pipeline | Tasks | Demo | |
|---|---|:---:| | |
| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*, *Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | | | |
| Pipeline | Tasks | Demo | Colab | | |
|---|---|:---:| | |
| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*, *Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | [🤗 Spaces](https://huggingface.co/spaces/thu-ml/unidiffuser) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/unidiffuser.ipynb) | |
For now, let's add a link to the original demo. @hysts is working on to change the demo to have diffusers
usage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Prepared a Colab Notebook from your awesome documentation: huggingface/notebooks#377
Also prepared this GIF to showcase the powerfulness of the pipeline:
import requests | ||
import torch | ||
from PIL import Image | ||
from io import BytesIO | ||
|
||
from diffusers import UniDiffuserPipeline | ||
|
||
device = "cuda" | ||
model_id_or_path = "thu-ml/unidiffuser-v1" | ||
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) | ||
pipe.to(device) | ||
|
||
# Image-to-text generation | ||
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" | ||
response = requests.get(image_url) | ||
init_image = Image.open(BytesIO(response.content)).convert("RGB") | ||
init_image = init_image.resize((512, 512)) | ||
|
||
sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0) | ||
i2t_text = sample.text[0] | ||
print(text) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import requests | |
import torch | |
from PIL import Image | |
from io import BytesIO | |
from diffusers import UniDiffuserPipeline | |
device = "cuda" | |
model_id_or_path = "thu-ml/unidiffuser-v1" | |
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) | |
pipe.to(device) | |
# Image-to-text generation | |
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" | |
response = requests.get(image_url) | |
init_image = Image.open(BytesIO(response.content)).convert("RGB") | |
init_image = init_image.resize((512, 512)) | |
sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0) | |
i2t_text = sample.text[0] | |
print(text) | |
from diffusers import UniDiffuserPipeline | |
from diffusers.utils import load_image | |
device = "cuda" | |
model_id_or_path = "thu-ml/unidiffuser-v1" | |
pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) | |
pipe.to(device) | |
# Image-to-text generation | |
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" | |
init_image = load_image(image_url).resize((512, 512)) | |
sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0) | |
i2t_text = sample.text[0] | |
print(i2t_text) |
Reduces the LoC :)
# Image variation can be performed with a image-to-text generation followed by a text-to-image generation: | ||
# 1. Image-to-text generation | ||
image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" | ||
response = requests.get(image_url) | ||
init_image = Image.open(BytesIO(response.content)).convert("RGB") | ||
init_image = init_image.resize((512, 512)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we can follow the same as https://github.com/huggingface/diffusers/pull/2963/files#r1205061596 for loading and resizing the image?
- all | ||
- __call__ | ||
|
||
## ImageTextPipelineOutput |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also wanted to know if there's any argument control the number of images / text I wanted to generate as a part of the variation mode.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can control the num_images_per_prompt
in the text-to-image mode, so that's settled. But what about text variation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For modes which generate only text (img2text
and text
), there's an analogous num_prompts_per_image
argument to __call__
. So when you perform the second img2text
generation for text variation you can specify num_prompts_per_image > 1
to get multiple text variation samples.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should go here:
https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/outputs.mdx
It feels more natural to me to have the documentation for ImageTextPipelineOutput
alongside ImagePipelineOutput
, which is at the Diffusion Pipeline
doc page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've gone ahead and moved the ImageTextPipelineOutput
documentation to /api/diffusion_pipeline.mdx
(alongside the ImagePipelineOutput
and AudioPipelineOutput
documentation). Let me know if it would be better somewhere else (for example, at /api/outputs.mdx
as originally suggested) :).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand. But we will soon update that too :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, so would it be better if I move it to /api/outputs
? Or is it fine to leave it at /api/diffusion_pipeline
for now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's keep it as is for now. Then we will bulk move things :)
|
||
sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0) | ||
i2t_text = sample.text[0] | ||
print(text) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: should be i2t_text
.
|
||
### Unconditional Image and Text Generation | ||
|
||
Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a `UniDiffuserPipeline` will produce a (image, text) pair: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a `UniDiffuserPipeline` will produce a (image, text) pair: | |
Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a [`UniDiffuserPipeline`] will produce a (image, text) pair: |
So, that the hyperlink is automatically rendered.
print(text) | ||
``` | ||
|
||
The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuser.set_image_to_text_mode`]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuser.set_image_to_text_mode`]. | |
The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuserPipeline.set_image_to_text_mode`]. |
…fuser.mdx to /api/diffusion_pipeline.mdx.
self.transformer = GPT2LMHeadModel(gpt_config) | ||
|
||
def forward( | ||
self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
return self.encode_prefix(prefix) | ||
|
||
@torch.no_grad() | ||
def generate_captions(self, features, eos_token_id, device): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
eos_token_id: Optional[int] = None, | ||
input_ids=None, | ||
input_embeds=None, | ||
device=None, | ||
beam_size: int = 5, | ||
entry_length: int = 67, | ||
temperature: float = 1.0, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eos_token_id: Optional[int] = None, | |
input_ids=None, | |
input_embeds=None, | |
device=None, | |
beam_size: int = 5, | |
entry_length: int = 67, | |
temperature: float = 1.0, | |
input_ids=None, | |
input_embeds=None, | |
device=None, | |
beam_size: int = 5, | |
entry_length: int = 67, | |
temperature: float = 1.0, | |
eos_token_id: Optional[int] = None, |
(nit) Let's change the order here maybe since the eos_token_id
should probably not be the first input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed :).
cross_attention_kwargs=None, | ||
class_labels=None, | ||
): | ||
# Pre-LayerNorm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great this PR looks good to go for me! Just one final tiny nit regarding the ordering of the generate input.
Apart from this, this is good to merge from my side :-) Incredible work here @dg845! This is really a difficult model with many components and the final implementation is super nice :-)
@dg845 once the conflicts are resolved and tests pass, we will merge :) Meanwhile, I will also correct the gif. Really amazing contribution. I hope the contribution experience was enjoyable for you. |
@patrickvonplaten a friendly ping for these transfers: |
Thanks! I really enjoyed working on this PR :). And thanks for all the advice and help along the way :). |
@sayakpaul feel free to merge whenever! All good from my side |
@dg845 thanks again for your amazing contribution. The pipeline and the components are now live at: https://huggingface.co/docs/diffusers/main/en/api/pipelines/unidiffuser |
* Fix a bug of pano when not doing CFG (#3030) * Fix a bug of pano when not doing CFG * enhance code quality * apply formatting. --------- Co-authored-by: Sayak Paul <[email protected]> * Text2video zero refinements (#3070) * fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward * fix tensor loading in test_text_to_video_zero.py * make style && make quality * Release: v0.15.0 * [Tests] Speed up panorama tests (#3067) * fix: norm group test for UNet3D. * chore: speed up the panorama tests (fast). * set default value of _test_inference_batch_single_identical. * fix: batch_sizes default value. * [Post release] v0.16.0dev (#3072) * Adds profiling flags, computes train metrics average. (#3053) * WIP controlnet training - bugfix --streaming - bugfix running report_to!='wandb' - adds memory profile before validation * Adds final logging statement. * Sets train epochs to 11. Looking at a longer ~16ep run, we see only good validation images after ~11ep: https://wandb.ai/andsteing/controlnet_fill50k/runs/3j2hx6n8 * Removes --logging_dir (it's not used). * Adds --profile flags. * Updates --output_dir=runs/fill-circle-{timestamp}. * Compute mean of `train_metrics`. Previously `train_metrics[-1]` was logged, resulting in very bumpy train metrics. * Improves logging a bit. - adds l2_grads gradient norm logging - adds steps_per_sec - sets walltime as x coordinate of train/step - logs controlnet_params config * Adds --ccache (doesn't really help though). * minor fix in controlnet flax example (#2986) * fix the error when push_to_hub but not log validation * contronet_from_pt & controlnet_revision * add intermediate checkpointing to the guide * Bugfix --profile_steps * Sets `RACKER_PROJECT_NAME='controlnet_fill50k'`. * Logs fractional epoch. * Adds relative `walltime` metric. * Adds `StepTraceAnnotation` and uses `global_step` insetad of `step`. * Applied `black`. * Streamlines commands in README a bit. * Removes `--ccache`. This makes only a very small difference (~1 min) with this model size, so removing the option introduced in cdb3cc. * Re-ran `black`. * Update examples/controlnet/README.md Co-authored-by: Sayak Paul <[email protected]> * Converts spaces to tab. * Removes repeated args. * Skips first step (compilation) in profiling * Updates README with profiling instructions. * Unifies tabs/spaces in README. * Re-ran style & quality. --------- Co-authored-by: Sayak Paul <[email protected]> * [Pipelines] Make sure that None functions are correctly not saved (#3080) * doc string example remove from_pt (#3083) * [Tests] parallelize (#3078) * [Tests] parallelize * finish folder structuring * Parallelize tests more * Correct saving of pipelines * make sure logging level is correct * try again * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> --------- Co-authored-by: Pedro Cuenca <[email protected]> * Throw deprecation warning for return_cached_folder (#3092) Throw deprecation warning * Allow SD attend and excite pipeline to work with any size output images (#2835) Allow stable diffusion attend and excite pipeline to work with any size output image. Re: #2476, #2603 * [docs] Update community pipeline docs (#2989) * update community pipeline docs * fix formatting * explain sharing workflows * Add to support Guess Mode for StableDiffusionControlnetPipleline (#2998) * add guess mode (WIP) * fix uncond/cond order * support guidance_scale=1.0 and batch != 1 * remove magic coeff * add docstring * add intergration test * add document to controlnet.mdx * made the comments a bit more explanatory * fix table * fix default value for attend-and-excite (#3099) * fix default * remvoe one line as requested by gc team (#3077) remvoe one line * ddpm custom timesteps (#3007) add custom timesteps test add custom timesteps descending order check docs timesteps -> custom_timesteps can only pass one of num_inference_steps and timesteps * Fix breaking change in `pipeline_stable_diffusion_controlnet.py` (#3118) fix breaking change * Add global pooling to controlnet (#3121) * [Bug fix] Fix img2img processor with safety checker (#3127) Fix img2img processor with safety checker * [Bug fix] Make sure correct timesteps are chosen for img2img (#3128) Make sure correct timesteps are chosen for img2img * Improve deprecation warnings (#3131) * Fix config deprecation (#3129) * Better deprecation message * Better deprecation message * Better doc string * Fixes * fix more * fix more * Improve __getattr__ * correct more * fix more * fix * Improve more * more improvements * fix more * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * make style * Fix all rest & add tests & remove old deprecation fns --------- Co-authored-by: Pedro Cuenca <[email protected]> * feat: verfication of multi-gpu support for select examples. (#3126) * feat: verfication of multi-gpu support for select examples. * add: multi-gpu training sections to the relvant doc pages. * speed up attend-and-excite fast tests (#3079) * Optimize log_validation in train_controlnet_flax (#3110) extract pipeline from log_validation * make style * Correct textual inversion readme (#3145) * Update README.md * Apply suggestions from code review * Add unet act fn to other model components (#3136) Adding act fn config to the unet timestep class embedding and conv activation. The custom activation defaults to silu which is the default activation function for both the conv act and the timestep class embeddings so default behavior is not changed. The only unet which use the custom activation is the stable diffusion latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json (I ran a script against the hub to confirm). The latent upscaler does not use the conv activation nor the timestep class embeddings so we don't change its behavior. * class labels timestep embeddings projection dtype cast (#3137) This mimics the dtype cast for the standard time embeddings * [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model (#2705) * [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model * Address review comment from PR * PyLint formatting * Some more pylint fixes, unrelated to our change * Another pylint fix * Styling fix * add from_ckpt method as Mixin (#2318) * add mixin class for pipeline from original sd ckpt * Improve * make style * merge main into * Improve more * fix more * up * Apply suggestions from code review * finish docs * rename * make style --------- Co-authored-by: Patrick von Platen <[email protected]> * Add TensorRT SD/txt2img Community Pipeline to diffusers along with TensorRT utils (#2974) * Add SD/txt2img Community Pipeline to diffusers along with TensorRT utils Signed-off-by: Asfiya Baig <[email protected]> * update installation command Signed-off-by: Asfiya Baig <[email protected]> * update tensorrt installation Signed-off-by: Asfiya Baig <[email protected]> * changes 1. Update setting of cache directory 2. Address comments: merge utils and pipeline code. 3. Address comments: Add section in README Signed-off-by: Asfiya Baig <[email protected]> * apply make style Signed-off-by: Asfiya Baig <[email protected]> --------- Signed-off-by: Asfiya Baig <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> * Correct `Transformer2DModel.forward` docstring (#3074) ⚙️chore(transformer_2d) update function signature for encoder_hidden_states * Update pipeline_stable_diffusion_inpaint_legacy.py (#2903) * Update pipeline_stable_diffusion_inpaint_legacy.py * fix preprocessing of Pil images with adequate batch size * revert map * add tests * reformat * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * next try to fix the style * wth is this * Update testing_utils.py * Update testing_utils.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py --------- Co-authored-by: Patrick von Platen <[email protected]> * Modified altdiffusion pipline to support altdiffusion-m18 (#2993) * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 --------- Co-authored-by: root <[email protected]> * controlnet training resize inputs to multiple of 8 (#3135) controlnet training center crop input images to multiple of 8 The pipeline code resizes inputs to multiples of 8. Not doing this resizing in the training script is causing the encoded image to have different height/width dimensions than the encoded conditioning image (which uses a separate encoder that's part of the controlnet model). We resize and center crop the inputs to make sure they're the same size (as well as all other images in the batch). We also check that the initial resolution is a multiple of 8. * adding custom diffusion training to diffusers examples (#3031) * diffusers==0.14.0 update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion * custom diffusion * custom diffusion * custom diffusion * custom diffusion * apply formatting and get rid of bare except. * refactor readme and other minor changes. * misc refactor. * fix: repo_id issue and loaders logging bug. * fix: save_model_card. * fix: save_model_card. * fix: save_model_card. * add: doc entry. * refactor doc,. * custom diffusion * custom diffusion * custom diffusion * apply style. * remove tralining whitespace. * fix: toctree entry. * remove unnecessary print. * custom diffusion * custom diffusion * custom diffusion test * custom diffusion xformer update * custom diffusion xformer update * custom diffusion xformer update --------- Co-authored-by: Nupur Kumari <[email protected]> Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Nupur Kumari <[email protected]> * make style * Update custom_diffusion.mdx (#3165) Add missing newlines for rendering the links correctly * Added distillation for quantization example on textual inversion. (#2760) * Added distillation for quantization example on textual inversion. Signed-off-by: Ye, Xinyu <[email protected]> * refined readme and code style. Signed-off-by: Ye, Xinyu <[email protected]> * Update text2images.py * refined code of model load and added compatibility check. Signed-off-by: Ye, Xinyu <[email protected]> * fixed code style. Signed-off-by: Ye, Xinyu <[email protected]> * fix C403 [*] Unnecessary `list` comprehension (rewrite as a `set` comprehension) Signed-off-by: Ye, Xinyu <[email protected]> --------- Signed-off-by: Ye, Xinyu <[email protected]> * Update Noise Autocorrelation Loss Function for Pix2PixZero Pipeline (#2942) * Update Pix2PixZero Auto-correlation Loss * Add fast inversion tests * Clarify purpose and mark as deprecated Fix inversion prompt broadcasting * Register modules set to `None` in config for `test_save_load_optional_components` * Update new tests to coordinate with #2953 * [DreamBooth] add text encoder LoRA support in the DreamBooth training script (#3130) * add: LoRA text encoder support for DreamBooth example. * fix initialization. * fix: modification call. * add: entry in the readme. * use dog dataset from hub. * fix: params to clip. * add entry to the LoRA doc. * add: tests for lora. * remove unnecessary list comprehension./ * Update Habana Gaudi documentation (#3169) * Update Habana Gaudi doc * Fix tables * Add model offload to x4 upscaler (#3187) * Add model offload to x4 upscaler * fix * [docs] Deterministic algorithms (#3172) deterministic algos * Update custom_diffusion.mdx to credit the author (#3163) * Update custom_diffusion.mdx * fix: unnecessary list comprehension. * Fix TensorRT community pipeline device set function (#3157) pass silence_dtype_warnings as kwarg Signed-off-by: Asfiya Baig <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> * make `from_flax` work for controlnet (#3161) fix from_flax Co-authored-by: Patrick von Platen <[email protected]> * [docs] Clarify training args (#3146) * clarify training arg * apply feedback * Multi Vector Textual Inversion (#3144) * Multi Vector * Improve * fix multi token * improve test * make style * Update examples/test_examples.py * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> * update * Finish * Apply suggestions from code review --------- Co-authored-by: Suraj Patil <[email protected]> * Add `Karras sigmas` to HeunDiscreteScheduler (#3160) * Add karras pattern to discrete heun scheduler * Add integration test * Fix failing CI on pytorch test on M1 (mps) --------- Co-authored-by: Patrick von Platen <[email protected]> * [AudioLDM] Fix dtype of returned waveform (#3189) * Fix bug in train_dreambooth_lora (#3183) * Update train_dreambooth_lora.py fix bug * Update train_dreambooth_lora.py * [Community Pipelines] Update lpw_stable_diffusion pipeline (#3197) * Update lpw_stable_diffusion.py * fix cpu offload * Make sure VAE attention works with Torch 2_0 (#3200) * Make sure attention works with Torch 2_0 * make style * Fix more * Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline" (#3201) Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline (#3197)" This reverts commit 9965cb50eac12e397473f01535aab43aae76b4ab. * [Bug fix] Fix batch size attention head size mismatch (#3214) * fix mixed precision training on train_dreambooth_inpaint_lora (#3138) cast to weight dtype * adding enable_vae_tiling and disable_vae_tiling functions (#3225) adding enable_vae_tiling and disable_val_tiling functions * Add ControlNet v1.1 docs (#3226) Add v1.1 docs * Fix issue in maybe_convert_prompt (#3188) When the token used for textual inversion does not have any special symbols (e.g. it is not surrounded by <>), the tokenizer does not properly split the replacement tokens. Adding a space for the padding tokens fixes this. * Sync cache version check from transformers (#3179) sync cache version check from transformers * Fix docs text inversion (#3166) * Fix docs text inversion * Apply suggestions from code review * add model (#3230) * add * clean * up * clean up more * fix more tests * Improve docs further * improve * more fixes docs * Improve docs more * Update src/diffusers/models/unet_2d_condition.py * fix * up * update doc links * make fix-copies * add safety checker and watermarker to stage 3 doc page code snippets * speed optimizations docs * memory optimization docs * make style * add watermarking snippets to doc string examples * make style * use pt_to_pil helper functions in doc strings * skip mps tests * Improve safety * make style * new logic * fix * fix bad onnx design * make new stable diffusion upscale pipeline model arguments optional * define has_nsfw_concept when non-pil output type * lowercase linked to notebook name --------- Co-authored-by: William Berman <[email protected]> * Allow return pt x4 (#3236) * Add all files * update * Allow fp16 attn for x4 upscaler (#3239) * Add all files * update * Make sure vae is memory efficient for PT 1 * make style * fix fast test (#3241) * Adds a document on token merging (#3208) * add document on token merging. * fix headline. * fix: headline. * add some samples for comparison. * [AudioLDM] Update docs to use updated ckpt (#3240) * [AudioLDM] Update docs to use updated ckpt * make style * Release: v0.16.0 * Post release for 0.16.0 (#3244) * Post release * fix more * [docs] only mention one stage (#3246) * [docs] only mention one stage * add blurb on auto accepting --------- Co-authored-by: William Berman <[email protected]> * Write model card in controlnet training script (#3229) Write model card in controlnet training script. * [2064]: Add stochastic sampler (sample_dpmpp_sde) (#3020) * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * Review comments * [Review comment]: Add is_torchsde_available() * [Review comment]: Test and docs * [Review comment] * [Review comment] * [Review comment] * [Review comment] * [Review comment] --------- Co-authored-by: njindal <[email protected]> * [Stochastic Sampler][Slow Test]: Cuda test fixes (#3257) [Slow Test]: Cuda test fixes Co-authored-by: njindal <[email protected]> * Remove required from tracker_project_name (#3260) Remove required from tracker_project_name. As observed by https://github.com/off99555 in https://github.com/huggingface/diffusers/issues/2695#issuecomment-1470755050, it already has a default value. * adding required parameters while calling the get_up_block and get_down_block (#3210) * removed unnecessary parameters from get_up_block and get_down_block functions * adding resnet_skip_time_act, resnet_out_scale_factor and cross_attention_norm to get_up_block and get_down_block functions --------- Co-authored-by: Sayak Paul <[email protected]> * [docs] Update interface in repaint.mdx (#3119) Update repaint.mdx accomodate to #1701 * Update IF name to XL (#3262) Co-authored-by: multimodalart <[email protected]> * fix typo in score sde pipeline (#3132) * Fix typo in textual inversion JAX training script (#3123) The pipeline is built as `pipe` but then used as `pipeline`. * AudioDiffusionPipeline - fix encode method after config changes (#3114) * config fixes * deprecate get_input_dims * Revert "Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline"" (#3265) Revert "Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline" (#3201)" This reverts commit 91a2a80eb2f98a9f64b9e287715add244dc6f2f3. * Fix community pipelines (#3266) * update notebook (#3259) Co-authored-by: yiyixuxu <[email protected]> * [docs] add notes for stateful model changes (#3252) * [docs] add notes for stateful model changes * Update docs/source/en/optimization/fp16.mdx Co-authored-by: Pedro Cuenca <[email protected]> * link to accelerate docs for discarding hooks --------- Co-authored-by: Pedro Cuenca <[email protected]> * [LoRA] quality of life improvements in the loading semantics and docs (#3180) * 👽 qol improvements for LoRA. * better function name? * fix: LoRA weight loading with the new format. * address Patrick's comments. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * change wording around encouraging the use of load_lora_weights(). * fix: function name. --------- Co-authored-by: Patrick von Platen <[email protected]> * [Community Pipelines] EDICT pipeline implementation (#3153) * EDICT pipeline initial commit - Starting point taking from https://github.com/Joqsan/edict-diffusion * refactor __init__() method * minor refactoring * refactor scheduler code - remove scheduler and move its methods to the EDICTPipeline class * make CFG optional - refactor encode_prompt(). - include optional generator for sampling with vae. - minor variable renaming * add EDICT pipeline description to README.md * replace preprocess() with VaeImageProcessor * run make style and make quality commands --------- Co-authored-by: Patrick von Platen <[email protected]> * [Docs]zh translated docs update (#3245) * zh translated docs update * update _toctree * Update logging.mdx (#2863) Fix typos * Add multiple conditions to StableDiffusionControlNetInpaintPipeline (#3125) * try multi controlnet inpaint * multi controlnet inpaint * multi controlnet inpaint * Let's make sure that dreambooth always uploads to the Hub (#3272) * Update Dreambooth README * Adapt all docs as well * automatically write model card * fix * make style * Diffedit Zero-Shot Inpainting Pipeline (#2837) * Update Pix2PixZero Auto-correlation Loss * Add Stable Diffusion DiffEdit pipeline * Add draft documentation and import code * Bugfixes and refactoring * Add option to not decode latents in the inversion process * Harmonize preprocessing * Revert "Update Pix2PixZero Auto-correlation Loss" This reverts commit b218062fed08d6cc164206d6cb852b2b7b00847a. * Update annotations * rename `compute_mask` to `generate_mask` * Update documentation * Update docs * Update Docs * Fix copy * Change shape of output latents to batch first * Update docs * Add first draft for tests * Bugfix and update tests * Add `cross_attention_kwargs` support for all pipeline methods * Fix Copies * Add support for PIL image latents Add support for mask broadcasting Update docs and tests Align `mask` argument to `mask_image` Remove height and width arguments * Enable MPS Tests * Move example docstrings * Fix test * Fix test * fix pipeline inheritance * Harmonize `prepare_image_latents` with StableDiffusionPix2PixZeroPipeline * Register modules set to `None` in config for `test_save_load_optional_components` * Move fixed logic to specific test class * Clean changes to other pipelines * Update new tests to coordinate with #2953 * Update slow tests for better results * Safety to avoid potential problems with torch.inference_mode * Add reference in SD Pipeline Overview * Fix tests again * Enforce determinism in noise for generate_mask * Fix copies * Widen test tolerance for fp16 based on `test_stable_diffusion_upscale_pipeline_fp16` * Add LoraLoaderMixin and update `prepare_image_latents` * clean up repeat and reg * bugfix * Remove invalid args from docs Suppress spurious warning by repeating image before latent to mask gen * add constant learning rate with custom rule (#3133) * add constant lr with rules * add constant with rules in TYPE_TO_SCHEDULER_FUNCTION * add constant lr rate with rule * hotfix code quality * fix doc style * change name constant_with_rules to piecewise constant * Allow disabling torch 2_0 attention (#3273) * Allow disabling torch 2_0 attention * make style * Update src/diffusers/models/attention.py * [doc] add link to training script (#3271) add link to training script Co-authored-by: yiyixuxu <[email protected]> * temp disable spectogram diffusion tests (#3278) The note-seq package throws an error on import because the default installed version of Ipython is not compatible with python 3.8 which we run in the CI. https://github.com/huggingface/diffusers/actions/runs/4830121056/jobs/8605954838#step:7:9 * Changed sample[0] to images[0] (#3304) A pipeline object stores the results in `images` not in `sample`. Current code blocks don't work. * Typo in tutorial (#3295) * Torch compile graph fix (#3286) * fix more * Fix more * fix more * Apply suggestions from code review * fix * make style * make fix-copies * fix * make sure torch compile * Clean * fix test * Postprocessing refactor img2img (#3268) * refactor img2img VaeImageProcessor.postprocess * remove copy from for init, run_safety_checker, decode_latents Co-authored-by: Sayak Paul <[email protected]> --------- Co-authored-by: yiyixuxu <[email protected]> Co-authored-by: Sayak Paul <[email protected]> * [Torch 2.0 compile] Fix more torch compile breaks (#3313) * Fix more torch compile breaks * add tests * Fix all * fix controlnet * fix more * Add Horace He as co-author. > > Co-authored-by: Horace He <[email protected]> * Add Horace He as co-author. Co-authored-by: Horace He <[email protected]> --------- Co-authored-by: Horace He <[email protected]> * fix: scale_lr and sync example readme and docs. (#3299) * fix: scale_lr and sync example readme and docs. * fix doc link. * Update stable_diffusion.mdx (#3310) fixed import statement * Fix missing variable assign in DeepFloyd-IF-II (#3315) Fix missing variable assign lol * Correct doc build for patch releases (#3316) Update build_documentation.yml * Add Stable Diffusion RePaint to community pipelines (#3320) * Add Stable Diffsuion RePaint to community pipelines - Adds Stable Diffsuion RePaint to community pipelines - Add Readme enty for pipeline * Fix: Remove wrong import - Remove wrong import - Minor change in comments * Fix: Code formatting of stable_diffusion_repaint * Fix: ruff errors in stable_diffusion_repaint * Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314) * fix multistep dpmsolver for cosine schedule (deepfloy-if) * fix a typo * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule * add test, fix style --------- Co-authored-by: Patrick von Platen <[email protected]> * [docs] Improve LoRA docs (#3311) * update docs * add to toctree * apply feedback * Added input pretubation (#3292) * Added input pretubation * Fixed spelling * Update write_own_pipeline.mdx (#3323) * update controlling generation doc with latest goodies. (#3321) * [Quality] Make style (#3341) * Fix config dpm (#3343) * Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344) * add SDE variant of DPM-Solver and DPM-Solver++ * add test * fix typo * fix typo * Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275) The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument. * Add UniDiffuser classes to __init__ files, modify transformer block to support pre- and post-LN, add fast default tests, fix some bugs. * Update fast tests to use test checkpoints stored on the hub and to better match the reference UniDiffuser implementation. * Fix code with make style. * Revert "Fix code style with make style." This reverts commit 10a174a12c82e6abd3d5a57665719a03dbb85ca7. * Add self.image_encoder, self.text_decoder to list of models to offload to CPU in the enable_sequential_cpu_offload(...)/enable_model_cpu_offload(...) methods to make test_cpu_offload_forward_pass pass. * Fix code quality with make style. * Support using a data type embedding for UniDiffuser-v1. * Add fast test for checking UniDiffuser-v1 sampling. * Make changes so that the repository consistency tests pass. * Add UniDiffuser dummy objects via make fix-copies. * Fix bugs and make improvements to the UniDiffuser pipeline: - Improve batch size inference and fix bugs when num_images_per_prompt or num_prompts_per_image > 1 - Add tests for num_images_per_prompt, num_prompts_per_image > 1 - Improve check_inputs, especially regarding checking supplied latents - Add reset_mode method so that mode inference can be re-enabled after mode is set manually - Fix some warnings related to accessing class members directly instead of through their config - Small amount of refactoring in pipeline_unidiffuser.py * Fix code style with make style. * Add/edit docstrings for added classes and public pipeline methods. Also do some light refactoring. * Add documentation for UniDiffuser and fix some typos/formatting in docstrings. * Fix code with make style. * Refactor and improve the UniDiffuser convert_from_ckpt.py script. * Move the UniDiffusers convert_from_ckpy.py script to diffusers/scripts/convert_unidiffuser_to_diffusers.py * Fix code quality via make style. * Improve UniDiffuser slow tests. * make style * Fix some typos in the UniDiffuser docs. * Remove outdated logic based on transformers version in UniDiffuser pipeline __init__.py * Remove dependency on einops by refactoring einops operations to pure torch operations. * make style * Add slow test on full checkpoint for joint mode and correct expected image slices/text prefixes. * make style * Fix mixed precision issue by wrapping the offending code with the torch.autocast context manager. * Revert "Fix mixed precision issue by wrapping the offending code with the torch.autocast context manager." This reverts commit 1a58958ab4f024dbc4c90a6404c2e66210db6d00. * Add fast test for CUDA/fp16 model behavior (currently failing). * Fix the mixed precision issue and add additional tests of the pipeline cuda/fp16 functionality. * make style * Use a CLIPVisionModelWithProjection instead of CLIPVisionModel for image_encoder to better match the original UniDiffuser implementation. * Make style and remove some testing code. * Fix shape errors for the 'joint' and 'img2text' modes. * Fix tests and remove some testing code. * Add option to use fixed latents for UniDiffuserPipelineSlowTests and fix issue in modeling_text_decoder.py. * Improve UniDiffuser docs, particularly the usage examples, and improve slow tests with new expected outputs. * make style * Fix examples to load model in float16. * In image-to-text mode, sample from the autoencoder moment distribution instead of always getting its mode. * make style * When encoding the image using the VAE, scale the image latents by the VAE's scaling factor. * make style * Clean up code and make slow tests pass. * make fix-copies * [docs] Fix docstring (#3334) fix docstring Co-authored-by: Patrick von Platen <[email protected]> * if dreambooth lora (#3360) * update IF stage I pipelines add fixed variance schedulers and lora loading * added kv lora attn processor * allow loading into alternative lora attn processor * make vae optional * throw away predicted variance * allow loading into added kv lora layer * allow load T5 * allow pre compute text embeddings * set new variance type in schedulers * fix copies * refactor all prompt embedding code class prompts are now included in pre-encoding code max tokenizer length is now configurable embedding attention mask is now configurable * fix for when variance type is not defined on scheduler * do not pre compute validation prompt if not present * add example test for if lora dreambooth * add check for train text encoder and pre compute text embeddings * Postprocessing refactor all others (#3337) * add text2img * fix-copies * add * add all other pipelines * add * add * add * add * add * make style * style + fix copies --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> * [docs] Improve safetensors docstring (#3368) * clarify safetensor docstring * fix typo * apply feedback * add: a warning message when using xformers in a PT 2.0 env. (#3365) * add: a warning message when using xformers in a PT 2.0 env. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]> * StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322) * StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy. * Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution * Added a resolution test to StableDiffusionInpaintPipelineSlowTests this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [docs] Adapt a model (#3326) * first draft * apply feedback * conv_in.weight thrown away * [docs] Load safetensors (#3333) * safetensors * apply feedback * apply feedback * Apply suggestions from code review --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [Docs] Fix stable_diffusion.mdx typo (#3398) Fix typo in last code block. Correct "prommpts" to "prompt" * Support ControlNet v1.1 shuffle properly (#3340) * add inferring_controlnet_cond_batch * Revert "add inferring_controlnet_cond_batch" This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9. * set guess_mode to True whenever global_pool_conditions is True Co-authored-by: Patrick von Platen <[email protected]> * nit * add integration test --------- Co-authored-by: Patrick von Platen <[email protected]> * [Tests] better determinism (#3374) * enable deterministic pytorch and cuda operations. * disable manual seeding. * make style && make quality for unet_2d tests. * enable determinism for the unet2dconditional model. * add CUBLAS_WORKSPACE_CONFIG for better reproducibility. * relax tolerance (very weird issue, though). * revert to torch manual_seed() where needed. * relax more tolerance. * better placement of the cuda variable and relax more tolerance. * enable determinism for 3d condition model. * relax tolerance. * add: determinism to alt_diffusion. * relax tolerance for alt diffusion. * dance diffusion. * dance diffusion is flaky. * test_dict_tuple_outputs_equivalent edit. * fix two more tests. * fix more ddim tests. * fix: argument. * change to diff in place of difference. * fix: test_save_load call. * test_save_load_float16 call. * fix: expected_max_diff * fix: paint by example. * relax tolerance. * add determinism to 1d unet model. * torch 2.0 regressions seem to be brutal * determinism to vae. * add reason to skipping. * up tolerance. * determinism to vq. * determinism to cuda. * determinism to the generic test pipeline file. * refactor general pipelines testing a bit. * determinism to alt diffusion i2i * up tolerance for alt diff i2i and audio diff * up tolerance. * determinism to audioldm * increase tolerance for audioldm lms. * increase tolerance for paint by paint. * increase tolerance for repaint. * determinism to cycle diffusion and sd 1. * relax tol for cycle diffusion 🚲 * relax tol for sd 1.0 * relax tol for controlnet. * determinism to img var. * relax tol for img variation. * tolerance to i2i sd * make style * determinism to inpaint. * relax tolerance for inpaiting. * determinism for inpainting legacy * relax tolerance. * determinism to instruct pix2pix * determinism to model editing. * model editing tolerance. * panorama determinism * determinism to pix2pix zero. * determinism to sag. * sd 2. determinism * sd. tolerance * disallow tf32 matmul. * relax tolerance is all you need. * make style and determinism to sd 2 depth * relax tolerance for depth. * tolerance to diffedit. * tolerance to sd 2 inpaint. * up tolerance. * determinism in upscaling. * tolerance in upscaler. * more tolerance relaxation. * determinism to v pred. * up tol for v_pred * unclip determinism * determinism to unclip img2img * determinism to text to video. * determinism to last set of tests * up tol. * vq cumsum doesn't have a deterministic kernel * relax tol * relax tol * [docs] Add transformers to install (#3388) add transformers to install * [deepspeed] partial ZeRO-3 support (#3076) * [deepspeed] partial ZeRO-3 support * cleanup * improve deepspeed fixes * Improve * make style --------- Co-authored-by: Patrick von Platen <[email protected]> * Add omegaconf for tests (#3400) Add omegaconfg * Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353) * Improve checkpointing lora * fix more * Improve doc string * Update src/diffusers/loaders.py * make stytle * Apply suggestions from code review * Update src/diffusers/loaders.py * Apply suggestions from code review * Apply suggestions from code review * better * Fix all * Fix multi-GPU dreambooth * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Fix all * make style * make style --------- Co-authored-by: Pedro Cuenca <[email protected]> * Fix docker file (#3402) * up * up * fix: deepseepd_plugin retrieval from accelerate state (#3410) * [Docs] Add `sigmoid` beta_scheduler to docstrings of relevant Schedulers (#3399) * Add `sigmoid` beta scheduler to `DDPMScheduler` docstring * Add `sigmoid` beta scheduler to `RePaintScheduler` docstring --------- Co-authored-by: Patrick von Platen <[email protected]> * Don't install accelerate and transformers from source (#3415) * Don't install transformers and accelerate from source (#3414) * Improve fast tests (#3416) Update pr_tests.yml * attention refactor: the trilogy (#3387) * Replace `AttentionBlock` with `Attention` * use _from_deprecated_attn_block check re: @patrickvonplaten * [Docs] update the PT 2.0 optimization doc with latest findings (#3370) * add: benchmarking stats for A100 and V100. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * address patrick's comments. * add: rtx 4090 stats * ⚔ benchmark reports done * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * 3313 pr link. * add: plots. Co-authored-by: Pedro <[email protected]> * fix formattimg * update number percent. --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * Fix style rendering (#3433) * Fix style rendering. * Fix typo * unCLIP scheduler do not use note (#3417) * Replace deprecated command with environment file (#3409) Co-authored-by: Patrick von Platen <[email protected]> * fix warning message pipeline loading (#3446) * add stable diffusion tensorrt img2img pipeline (#3419) * add stable diffusion tensorrt img2img pipeline Signed-off-by: Asfiya Baig <[email protected]> * update docstrings Signed-off-by: Asfiya Baig <[email protected]> --------- Signed-off-by: Asfiya Baig <[email protected]> * Refactor controlnet and add img2img and inpaint (#3386) * refactor controlnet and add img2img and inpaint * First draft to get pipelines to work * make style * Fix more * Fix more * More tests * Fix more * Make inpainting work * make style and more tests * Apply suggestions from code review * up * make style * Fix imports * Fix more * Fix more * Improve examples * add test * Make sure import is correctly deprecated * Make sure everything works in compile mode * make sure authorship is correctly attributed * [Scheduler] DPM-Solver (++) Inverse Scheduler (#3335) * Add DPM-Solver Multistep Inverse Scheduler * Add draft tests for DiffEdit * Add inverse sde-dpmsolver steps to tune image diversity from inverted latents * Fix tests --------- Co-authored-by: Patrick von Platen <[email protected]> * [Docs] Fix incomplete docstring for resnet.py (#3438) Fix incomplete docstrings for resnet.py * fix tiled vae blend extent range (#3384) fix tiled vae bleand extent range * Small update to "Next steps" section (#3443) Small update to "Next steps" section: - PyTorch 2 is recommended. - Updated improvement figures. * Allow arbitrary aspect ratio in IFSuperResolutionPipeline (#3298) * Update pipeline_if_superresolution.py Allow arbitrary aspect ratio in IFSuperResolutionPipeline by using the input image shape * IFSuperResolutionPipeline: allow the user to override the height and width through the arguments * update IFSuperResolutionPipeline width/height doc string to match StableDiffusionInpaintPipeline conventions --------- Co-authored-by: Patrick von Platen <[email protected]> * Adding 'strength' parameter to StableDiffusionInpaintingPipeline (#3424) * Added explanation of 'strength' parameter * Added get_timesteps function which relies on new strength parameter * Added `strength` parameter which defaults to 1. * Swapped ordering so `noise_timestep` can be calculated before masking the image this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1. * Added strength to check_inputs, throws error if out of range * Changed `prepare_latents` to initialise latents w.r.t strength inspired from the stable diffusion img2img pipeline, init latents are initialised by converting the init image into a VAE latent and adding noise (based upon the strength parameter passed in), e.g. random when strength = 1, or the init image at strength = 0. * WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline still need to add correct regression values * Created a is_strength_max to initialise from pure random noise * Updated unit tests w.r.t new strength parameter + fixed new strength unit test * renamed parameter to avoid confusion with variable of same name * Updated regression values for new strength test - now passes * removed 'copied from' comment as this method is now different and divergent from the cpy * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py Co-authored-by: Patrick von Platen <[email protected]> * Ensure backwards compatibility for prepare_mask_and_masked_image created a return_image boolean and initialised to false * Ensure backwards compatibility for prepare_latents * Fixed copy check typo * Fixes w.r.t backward compibility changes * make style * keep function argument ordering same for backwards compatibility in callees with copied from statements * make fix-copies --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: William Berman <[email protected]> * [WIP] Bugfix - Pipeline.from_pretrained is broken when the pipeline is partially downloaded (#3448) Added bugfix using f strings. * Fix gradient checkpointing bugs in freezing part of models (requires_grad=False) (#3404) * gradient checkpointing bug fix * bug fix; changes for reviews * reformat * reformat --------- Co-authored-by: Patrick von Platen <[email protected]> * Make dreambooth lora more robust to orig unet (#3462) * Make dreambooth lora more robust to orig unet * up * Reduce peak VRAM by releasing large attention tensors (as soon as they're unnecessary) (#3463) Release large tensors in attention (as soon as they're no longer required). Reduces peak VRAM by nearly 2 GB for 1024x1024 (even after slicing), and the savings scale up with image size. * Add min snr to text2img lora training script (#3459) add min snr to text2img lora training script * Add inpaint lora scale support (#3460) * add inpaint lora scale support * add inpaint lora scale test --------- Co-authored-by: yueyang.hyy <[email protected]> * [From ckpt] Fix from_ckpt (#3466) * Correct from_ckpt * make style * Update full dreambooth script to work with IF (#3425) * Add IF dreambooth docs (#3470) * parameterize pass single args through tuple (#3477) * attend and excite tests disable determinism on the class level (#3478) * dreambooth docs torch.compile note (#3471) * dreambooth docs torch.compile note * Update examples/dreambooth/README.md Co-authored-by: Sayak Paul <[email protected]> * Update examples/dreambooth/README.md Co-authored-by: Pedro Cuenca <[email protected]> --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * add: if entry in the dreambooth training docs. (#3472) * [docs] Textual inversion inference (#3473) * add textual inversion inference to docs * add to toctree --------- Co-authored-by: Sayak Paul <[email protected]> * [docs] Distributed inference (#3376) * distributed inference * move to inference section * apply feedback * update with split_between_processes * apply feedback * [{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (#3479) explicit view kernel size as number elements in flattened indices * mps & onnx tests rework (#3449) * Remove ONNX tests from PR. They are already a part of push_tests.yml. * Remove mps tests from PRs. They are already performed on push. * Fix workflow name for fast push tests. * Extract mps tests to a workflow. For better control/filtering. * Remove --extra-index-url from mps tests * Increase tolerance of mps test This test passes in my Mac (Ventura 13.3) but fails in the CI hardware (Ventura 13.2). I ran the local tests following the same steps that exist in the CI workflow. * Temporarily run mps tests on pr So we can test. * Revert "Temporarily run mps tests on pr" Tests passed, go back to running on push. * [Attention processor] Better warning message when shifting to `AttnProcessor2_0` (#3457) * add: debugging to enabling memory efficient processing * add: better warning message. * [Docs] add note on local directory path. (#3397) add note on local directory path. Co-authored-by: Patrick von Platen <[email protected]> * Refactor full determinism (#3485) * up * fix more * Apply suggestions from code review * fix more * fix more * Check it * Remove 16:8 * fix more * fix more * fix more * up * up * Test only stable diffusion * Test only two files * up * Try out spinning up processes that can be killed * up * Apply suggestions from code review * up * up * Fix DPM single (#3413) * Fix DPM single * add test * fix one more bug * Apply suggestions from code review Co-authored-by: StAlKeR7779 <[email protected]> --------- Co-authored-by: StAlKeR7779 <[email protected]> * Add `use_Karras_sigmas` to DPMSolverSinglestepScheduler (#3476) * add use_karras_sigmas * add karras test * add doc * Adds local_files_only bool to prevent forced online connection (#3486) * make style * [Docs] Korean translation (optimization, training) (#3488) * feat) optimization kr translation * fix) typo, italic setting * feat) dreambooth, text2image kr * feat) lora kr * fix) LoRA * fix) fp16 fix * fix) doc-builder style * fix) fp16 일부 단어 수정 * fix) fp16 style fix * fix) opt, training docs update * feat) toctree update * feat) toctree update --------- Co-authored-by: Chanran Kim <[email protected]> * DataLoader respecting EXIF data in Training Images (#3465) * DataLoader will now bake in any transforms or image manipulations contained in the EXIF Images may have rotations stored in EXIF. Training using such images will cause those transforms to be ignored while training and thus produce unexpected results * Fixed the Dataloading EXIF issue in main DreamBooth training as well * Run make style (black & isort) * make style * feat: allow disk offload for diffuser models (#3285) * allow disk offload for diffuser models * sort import * add max_memory argument * Changed sample[0] to images[0] (#3304) A pipeline object stores the results in `images` not in `sample`. Current code blocks don't work. * Typo in tutorial (#3295) * Torch compile graph fix (#3286) * fix more * Fix more * fix more * Apply suggestions from code review * fix * make style * make fix-copies * fix * make sure torch compile * Clean * fix test * Postprocessing refactor img2img (#3268) * refactor img2img VaeImageProcessor.postprocess * remove copy from for init, run_safety_checker, decode_latents Co-authored-by: Sayak Paul <[email protected]> --------- Co-authored-by: yiyixuxu <[email protected]> Co-authored-by: Sayak Paul <[email protected]> * [Torch 2.0 compile] Fix more torch compile breaks (#3313) * Fix more torch compile breaks * add tests * Fix all * fix controlnet * fix more * Add Horace He as co-author. > > Co-authored-by: Horace He <[email protected]> * Add Horace He as co-author. Co-authored-by: Horace He <[email protected]> --------- Co-authored-by: Horace He <[email protected]> * fix: scale_lr and sync example readme and docs. (#3299) * fix: scale_lr and sync example readme and docs. * fix doc link. * Update stable_diffusion.mdx (#3310) fixed import statement * Fix missing variable assign in DeepFloyd-IF-II (#3315) Fix missing variable assign lol * Correct doc build for patch releases (#3316) Update build_documentation.yml * Add Stable Diffusion RePaint to community pipelines (#3320) * Add Stable Diffsuion RePaint to community pipelines - Adds Stable Diffsuion RePaint to community pipelines - Add Readme enty for pipeline * Fix: Remove wrong import - Remove wrong import - Minor change in comments * Fix: Code formatting of stable_diffusion_repaint * Fix: ruff errors in stable_diffusion_repaint * Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314) * fix multistep dpmsolver for cosine schedule (deepfloy-if) * fix a typo * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule * add test, fix style --------- Co-authored-by: Patrick von Platen <[email protected]> * [docs] Improve LoRA docs (#3311) * update docs * add to toctree * apply feedback * Added input pretubation (#3292) * Added input pretubation * Fixed spelling * Update write_own_pipeline.mdx (#3323) * update controlling generation doc with latest goodies. (#3321) * [Quality] Make style (#3341) * Fix config dpm (#3343) * Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344) * add SDE variant of DPM-Solver and DPM-Solver++ * add test * fix typo * fix typo * Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275) The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument. * Rename --only_save_embeds to --save_as_full_pipeline (#3206) * Set --only_save_embeds to False by default Due to how the option is named, it makes more sense to behave like this. * Refactor only_save_embeds to save_as_full_pipeline * [AudioLDM] Generalise conversion script (#3328) Co-authored-by: Patrick von Platen <[email protected]> * Fix TypeError when using prompt_embeds and negative_prompt (#2982) * test: Added test case * fix: fixed type checking issue on _encode_prompt * fix: fixed copies consistency * fix: one copy was not sufficient * Fix pipeline class on README (#3345) Update README.md * Inpainting: typo in docs (#3331) Typo in docs Co-authored-by: Patrick von Platen <[email protected]> * Add `use_Karras_sigmas` to LMSDiscreteScheduler (#3351) * add karras sigma to lms discrete scheduler * add test for lms_scheduler karras * reformat test lms * Batched load of textual inversions (#3277) * Batched load of textual inversions - Only call resize_token_embeddings once per batch as it is the most expensive operation - Allow pretrained_model_name_or_path and token to be an optional list - Remove Dict from type annotation pretrained_model_name_or_path as it was not supported in this function - Add comment that single files (e.g. .pt/.safetensors) are supported - Add comment for token parameter - Convert token override log message from warning to info * Update src/diffusers/loaders.py Check for duplicate tokens Co-authored-by: Patrick von Platen <[email protected]> * Update condition for None tokens --------- Co-authored-by: Patrick von Platen <[email protected]> * make fix-copies * [docs] Fix docstring (#3334) fix docstring Co-authored-by: Patrick von Platen <[email protected]> * if dreambooth lora (#3360) * update IF stage I pipelines add fixed variance schedulers and lora loading * added kv lora attn processor * allow loading into alternative lora attn processor * make vae optional * throw away predicted variance * allow loading into added kv lora layer * allow load T5 * allow pre compute text embeddings * set new variance type in schedulers * fix copies * refactor all prompt embedding code class prompts are now included in pre-encoding code max tokenizer length is now configurable embedding attention mask is now configurable * fix for when variance type is not defined on scheduler * do not pre compute validation prompt if not present * add example test for if lora dreambooth * add check for train text encoder and pre compute text embeddings * Postprocessing refactor all others (#3337) * add text2img * fix-copies * add * add all other pipelines * add * add * add * add * add * make style * style + fix copies --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> * [docs] Improve safetensors docstring (#3368) * clarify safetensor docstring * fix typo * apply feedback * add: a warning message when using xformers in a PT 2.0 env. (#3365) * add: a warning message when using xformers in a PT 2.0 env. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]> * StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322) * StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy. * Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution * Added a resolution test to StableDiffusionInpaintPipelineSlowTests this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [docs] Adapt a model (#3326) * first draft * apply feedback * conv_in.weight thrown away * [docs] Load safetensors (#3333) * safetensors * apply feedback * apply feedback * Apply suggestions from code review --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [Docs] Fix stable_diffusion.mdx typo (#3398) Fix typo in last code block. Correct "prommpts" to "prompt" * Support ControlNet v1.1 shuffle properly (#3340) * add inferring_controlnet_cond_batch * Revert "add inferring_controlnet_cond_batch" This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9. * set guess_mode to True whenever global_pool_conditions is True Co-authored-by: Patrick von Platen <[email protected]> * nit * add integration test --------- Co-authored-by: Patrick von Platen <[email protected]> * [Tests] better determinism (#3374) * enable deterministic pytorch and cuda operations. * disable manual seeding. * make style && make quality for unet_2d tests. * enable determinism for the unet2dconditional model. * add CUBLAS_WORKSPACE_CONFIG for better reproducibility. * relax tolerance (very weird issue, though). * revert to torch manual_seed() where needed. * relax more tolerance. * better placement of the cuda variable and relax more tolerance. * enable determinism for 3d condition model. * relax tolerance. * add: determinism to alt_diffusion. * relax tolerance for alt diffusion. * dance diffusion. * dance diffusion is flaky. * test_dict_tuple_outputs_equivalent edit. * fix two more tests. * fix more ddim tests. * fix: argument. * change to diff in place of difference. * fix: test_save_load call. * test_save_load_float16 call. * fix: expected_max_diff * fix: paint by example. * relax tolerance. * add determinism to 1d unet model. * torch 2.0 regressions seem to be brutal * determinism to vae. * add reason to skipping. * up tolerance. * determinism to vq. * determinism to cuda. * determinism to the generic test pipeline file. * refactor general pipelines testing a bit. * determinism to alt diffusion i2i * up tolerance for alt diff i2i and audio diff * up tolerance. * determinism to audioldm * increase tolerance for audioldm lms. * increase tolerance for paint by paint. * increase tolerance for repaint. * determinism to cycle diffusion and sd 1. * relax tol for cycle diffusion 🚲 * relax tol for sd 1.0 * relax tol for controlnet. * determinism to img var. * relax tol for img variation. * tolerance to i2i sd * make style * determinism to inpaint. * relax tolerance for inpaiting. * determinism for inpainting legacy * relax tolerance. * determinism to instruct pix2pix * determinism to model editing. * model editing tolerance. * panorama determinism * determinism to pix2pix zero. * determinism to sag. * sd 2. determinism * sd. tolerance * disallow tf32 matmul. * relax tolerance is all you need. * make style and determinism to sd 2 depth * relax tolerance for depth. * tolerance to diffedit. * tolerance to sd 2 inpaint. * up tolerance. * determinism in upscaling. * tolerance in upscaler. * more tolerance relaxation. * determinism to v pred. * up tol for v_pred * unclip determinism * determinism to unclip img2img * determinism to text to video. * determinism to last set of tests * up tol. * vq cumsum doesn't have a deterministic kernel * relax tol * relax tol * [docs] Add transformers to install (#3388) add transformers to install * [deepspeed] partial ZeRO-3 support (#3076) * [deepspeed] partial ZeRO-3 support * cleanup * improve deepspeed fixes * Improve * make style --------- Co-authored-by: Patrick von Platen <[email protected]> * Add omegaconf for tests (#3400) Add omegaconfg * Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353) * Improve checkpointing lora * fix more * Improve doc string * Update src/diffusers/loaders.py * make stytle * Apply suggestions from code review * Update src/diffusers/loaders.py * Apply suggestions from code review * Apply suggestions from code review * better * Fix all * Fix multi-GPU dreambooth * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Fix all * make style * make style --------- Co-authored-by: Pedro Cuenca <[email protected]> * Fix docker file (#3402) * up * up * fix: deepseepd_plugin retrieval from accelerate state (#3410) * [Docs] Add `sigmoid` beta_scheduler to docstrings of relevant Schedulers (#3399) * Add `sigmoid` beta scheduler to `DDPMScheduler` docstring * Add `sigmoid` beta scheduler to `RePaintScheduler` docstring --------- Co-authored-by: Patrick von Platen <[email protected]> * Don't install accelerate and transformers from source (#3415) * Don't install transformers and accelerate from source (#3414) * Improve fast tests (#3416) Update pr_tests.yml * attention refactor: the trilogy (#3387) * Replace `AttentionBlock` with `Attention` * use _from_deprecated_attn_block check re: @patrickvonplaten * [Docs] update the PT 2.0 optimization doc with latest findings (#3370) * add: benchmarking stats for A100 and V100. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * address patrick's comments. * add: rtx 4090 stats * ⚔ benchmark reports done * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * 3313 pr link. * add: plots. Co-authored-by: Pedro <[email protected]> * fix formattimg * update number percent. --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * Fix style rendering (#3433) * Fix style rendering. * Fix typo * unCLIP scheduler do not use note (#3417) * Replace deprecated command with environment file (#3409) Co-authored-by: Patrick von Platen <[email protected]> * …
* Fix a bug of pano when not doing CFG (#3030) * Fix a bug of pano when not doing CFG * enhance code quality * apply formatting. --------- Co-authored-by: Sayak Paul <[email protected]> * Text2video zero refinements (#3070) * fix progress bar issue in pipeline_text_to_video_zero.py. Copy scheduler after first backward * fix tensor loading in test_text_to_video_zero.py * make style && make quality * Release: v0.15.0 * [Tests] Speed up panorama tests (#3067) * fix: norm group test for UNet3D. * chore: speed up the panorama tests (fast). * set default value of _test_inference_batch_single_identical. * fix: batch_sizes default value. * [Post release] v0.16.0dev (#3072) * Adds profiling flags, computes train metrics average. (#3053) * WIP controlnet training - bugfix --streaming - bugfix running report_to!='wandb' - adds memory profile before validation * Adds final logging statement. * Sets train epochs to 11. Looking at a longer ~16ep run, we see only good validation images after ~11ep: https://wandb.ai/andsteing/controlnet_fill50k/runs/3j2hx6n8 * Removes --logging_dir (it's not used). * Adds --profile flags. * Updates --output_dir=runs/fill-circle-{timestamp}. * Compute mean of `train_metrics`. Previously `train_metrics[-1]` was logged, resulting in very bumpy train metrics. * Improves logging a bit. - adds l2_grads gradient norm logging - adds steps_per_sec - sets walltime as x coordinate of train/step - logs controlnet_params config * Adds --ccache (doesn't really help though). * minor fix in controlnet flax example (#2986) * fix the error when push_to_hub but not log validation * contronet_from_pt & controlnet_revision * add intermediate checkpointing to the guide * Bugfix --profile_steps * Sets `RACKER_PROJECT_NAME='controlnet_fill50k'`. * Logs fractional epoch. * Adds relative `walltime` metric. * Adds `StepTraceAnnotation` and uses `global_step` insetad of `step`. * Applied `black`. * Streamlines commands in README a bit. * Removes `--ccache`. This makes only a very small difference (~1 min) with this model size, so removing the option introduced in cdb3cc. * Re-ran `black`. * Update examples/controlnet/README.md Co-authored-by: Sayak Paul <[email protected]> * Converts spaces to tab. * Removes repeated args. * Skips first step (compilation) in profiling * Updates README with profiling instructions. * Unifies tabs/spaces in README. * Re-ran style & quality. --------- Co-authored-by: Sayak Paul <[email protected]> * [Pipelines] Make sure that None functions are correctly not saved (#3080) * doc string example remove from_pt (#3083) * [Tests] parallelize (#3078) * [Tests] parallelize * finish folder structuring * Parallelize tests more * Correct saving of pipelines * make sure logging level is correct * try again * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> --------- Co-authored-by: Pedro Cuenca <[email protected]> * Throw deprecation warning for return_cached_folder (#3092) Throw deprecation warning * Allow SD attend and excite pipeline to work with any size output images (#2835) Allow stable diffusion attend and excite pipeline to work with any size output image. Re: #2476, #2603 * [docs] Update community pipeline docs (#2989) * update community pipeline docs * fix formatting * explain sharing workflows * Add to support Guess Mode for StableDiffusionControlnetPipleline (#2998) * add guess mode (WIP) * fix uncond/cond order * support guidance_scale=1.0 and batch != 1 * remove magic coeff * add docstring * add intergration test * add document to controlnet.mdx * made the comments a bit more explanatory * fix table * fix default value for attend-and-excite (#3099) * fix default * remvoe one line as requested by gc team (#3077) remvoe one line * ddpm custom timesteps (#3007) add custom timesteps test add custom timesteps descending order check docs timesteps -> custom_timesteps can only pass one of num_inference_steps and timesteps * Fix breaking change in `pipeline_stable_diffusion_controlnet.py` (#3118) fix breaking change * Add global pooling to controlnet (#3121) * [Bug fix] Fix img2img processor with safety checker (#3127) Fix img2img processor with safety checker * [Bug fix] Make sure correct timesteps are chosen for img2img (#3128) Make sure correct timesteps are chosen for img2img * Improve deprecation warnings (#3131) * Fix config deprecation (#3129) * Better deprecation message * Better deprecation message * Better doc string * Fixes * fix more * fix more * Improve __getattr__ * correct more * fix more * fix * Improve more * more improvements * fix more * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * make style * Fix all rest & add tests & remove old deprecation fns --------- Co-authored-by: Pedro Cuenca <[email protected]> * feat: verfication of multi-gpu support for select examples. (#3126) * feat: verfication of multi-gpu support for select examples. * add: multi-gpu training sections to the relvant doc pages. * speed up attend-and-excite fast tests (#3079) * Optimize log_validation in train_controlnet_flax (#3110) extract pipeline from log_validation * make style * Correct textual inversion readme (#3145) * Update README.md * Apply suggestions from code review * Add unet act fn to other model components (#3136) Adding act fn config to the unet timestep class embedding and conv activation. The custom activation defaults to silu which is the default activation function for both the conv act and the timestep class embeddings so default behavior is not changed. The only unet which use the custom activation is the stable diffusion latent upscaler https://huggingface.co/stabilityai/sd-x2-latent-upscaler/blob/main/unet/config.json (I ran a script against the hub to confirm). The latent upscaler does not use the conv activation nor the timestep class embeddings so we don't change its behavior. * class labels timestep embeddings projection dtype cast (#3137) This mimics the dtype cast for the standard time embeddings * [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model (#2705) * [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model * Address review comment from PR * PyLint formatting * Some more pylint fixes, unrelated to our change * Another pylint fix * Styling fix * add from_ckpt method as Mixin (#2318) * add mixin class for pipeline from original sd ckpt * Improve * make style * merge main into * Improve more * fix more * up * Apply suggestions from code review * finish docs * rename * make style --------- Co-authored-by: Patrick von Platen <[email protected]> * Add TensorRT SD/txt2img Community Pipeline to diffusers along with TensorRT utils (#2974) * Add SD/txt2img Community Pipeline to diffusers along with TensorRT utils Signed-off-by: Asfiya Baig <[email protected]> * update installation command Signed-off-by: Asfiya Baig <[email protected]> * update tensorrt installation Signed-off-by: Asfiya Baig <[email protected]> * changes 1. Update setting of cache directory 2. Address comments: merge utils and pipeline code. 3. Address comments: Add section in README Signed-off-by: Asfiya Baig <[email protected]> * apply make style Signed-off-by: Asfiya Baig <[email protected]> --------- Signed-off-by: Asfiya Baig <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> * Correct `Transformer2DModel.forward` docstring (#3074) ⚙️chore(transformer_2d) update function signature for encoder_hidden_states * Update pipeline_stable_diffusion_inpaint_legacy.py (#2903) * Update pipeline_stable_diffusion_inpaint_legacy.py * fix preprocessing of Pil images with adequate batch size * revert map * add tests * reformat * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * next try to fix the style * wth is this * Update testing_utils.py * Update testing_utils.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py * Update test_stable_diffusion_inpaint_legacy.py --------- Co-authored-by: Patrick von Platen <[email protected]> * Modified altdiffusion pipline to support altdiffusion-m18 (#2993) * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 * Modified altdiffusion pipline to support altdiffusion-m18 --------- Co-authored-by: root <[email protected]> * controlnet training resize inputs to multiple of 8 (#3135) controlnet training center crop input images to multiple of 8 The pipeline code resizes inputs to multiples of 8. Not doing this resizing in the training script is causing the encoded image to have different height/width dimensions than the encoded conditioning image (which uses a separate encoder that's part of the controlnet model). We resize and center crop the inputs to make sure they're the same size (as well as all other images in the batch). We also check that the initial resolution is a multiple of 8. * adding custom diffusion training to diffusers examples (#3031) * diffusers==0.14.0 update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion * custom diffusion * custom diffusion * custom diffusion * custom diffusion * apply formatting and get rid of bare except. * refactor readme and other minor changes. * misc refactor. * fix: repo_id issue and loaders logging bug. * fix: save_model_card. * fix: save_model_card. * fix: save_model_card. * add: doc entry. * refactor doc,. * custom diffusion * custom diffusion * custom diffusion * apply style. * remove tralining whitespace. * fix: toctree entry. * remove unnecessary print. * custom diffusion * custom diffusion * custom diffusion test * custom diffusion xformer update * custom diffusion xformer update * custom diffusion xformer update --------- Co-authored-by: Nupur Kumari <[email protected]> Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Nupur Kumari <[email protected]> * make style * Update custom_diffusion.mdx (#3165) Add missing newlines for rendering the links correctly * Added distillation for quantization example on textual inversion. (#2760) * Added distillation for quantization example on textual inversion. Signed-off-by: Ye, Xinyu <[email protected]> * refined readme and code style. Signed-off-by: Ye, Xinyu <[email protected]> * Update text2images.py * refined code of model load and added compatibility check. Signed-off-by: Ye, Xinyu <[email protected]> * fixed code style. Signed-off-by: Ye, Xinyu <[email protected]> * fix C403 [*] Unnecessary `list` comprehension (rewrite as a `set` comprehension) Signed-off-by: Ye, Xinyu <[email protected]> --------- Signed-off-by: Ye, Xinyu <[email protected]> * Update Noise Autocorrelation Loss Function for Pix2PixZero Pipeline (#2942) * Update Pix2PixZero Auto-correlation Loss * Add fast inversion tests * Clarify purpose and mark as deprecated Fix inversion prompt broadcasting * Register modules set to `None` in config for `test_save_load_optional_components` * Update new tests to coordinate with #2953 * [DreamBooth] add text encoder LoRA support in the DreamBooth training script (#3130) * add: LoRA text encoder support for DreamBooth example. * fix initialization. * fix: modification call. * add: entry in the readme. * use dog dataset from hub. * fix: params to clip. * add entry to the LoRA doc. * add: tests for lora. * remove unnecessary list comprehension./ * Update Habana Gaudi documentation (#3169) * Update Habana Gaudi doc * Fix tables * Add model offload to x4 upscaler (#3187) * Add model offload to x4 upscaler * fix * [docs] Deterministic algorithms (#3172) deterministic algos * Update custom_diffusion.mdx to credit the author (#3163) * Update custom_diffusion.mdx * fix: unnecessary list comprehension. * Fix TensorRT community pipeline device set function (#3157) pass silence_dtype_warnings as kwarg Signed-off-by: Asfiya Baig <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> * make `from_flax` work for controlnet (#3161) fix from_flax Co-authored-by: Patrick von Platen <[email protected]> * [docs] Clarify training args (#3146) * clarify training arg * apply feedback * Multi Vector Textual Inversion (#3144) * Multi Vector * Improve * fix multi token * improve test * make style * Update examples/test_examples.py * Apply suggestions from code review Co-authored-by: Suraj Patil <[email protected]> * update * Finish * Apply suggestions from code review --------- Co-authored-by: Suraj Patil <[email protected]> * Add `Karras sigmas` to HeunDiscreteScheduler (#3160) * Add karras pattern to discrete heun scheduler * Add integration test * Fix failing CI on pytorch test on M1 (mps) --------- Co-authored-by: Patrick von Platen <[email protected]> * [AudioLDM] Fix dtype of returned waveform (#3189) * Fix bug in train_dreambooth_lora (#3183) * Update train_dreambooth_lora.py fix bug * Update train_dreambooth_lora.py * [Community Pipelines] Update lpw_stable_diffusion pipeline (#3197) * Update lpw_stable_diffusion.py * fix cpu offload * Make sure VAE attention works with Torch 2_0 (#3200) * Make sure attention works with Torch 2_0 * make style * Fix more * Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline" (#3201) Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline (#3197)" This reverts commit 9965cb50eac12e397473f01535aab43aae76b4ab. * [Bug fix] Fix batch size attention head size mismatch (#3214) * fix mixed precision training on train_dreambooth_inpaint_lora (#3138) cast to weight dtype * adding enable_vae_tiling and disable_vae_tiling functions (#3225) adding enable_vae_tiling and disable_val_tiling functions * Add ControlNet v1.1 docs (#3226) Add v1.1 docs * Fix issue in maybe_convert_prompt (#3188) When the token used for textual inversion does not have any special symbols (e.g. it is not surrounded by <>), the tokenizer does not properly split the replacement tokens. Adding a space for the padding tokens fixes this. * Sync cache version check from transformers (#3179) sync cache version check from transformers * Fix docs text inversion (#3166) * Fix docs text inversion * Apply suggestions from code review * add model (#3230) * add * clean * up * clean up more * fix more tests * Improve docs further * improve * more fixes docs * Improve docs more * Update src/diffusers/models/unet_2d_condition.py * fix * up * update doc links * make fix-copies * add safety checker and watermarker to stage 3 doc page code snippets * speed optimizations docs * memory optimization docs * make style * add watermarking snippets to doc string examples * make style * use pt_to_pil helper functions in doc strings * skip mps tests * Improve safety * make style * new logic * fix * fix bad onnx design * make new stable diffusion upscale pipeline model arguments optional * define has_nsfw_concept when non-pil output type * lowercase linked to notebook name --------- Co-authored-by: William Berman <[email protected]> * Allow return pt x4 (#3236) * Add all files * update * Allow fp16 attn for x4 upscaler (#3239) * Add all files * update * Make sure vae is memory efficient for PT 1 * make style * fix fast test (#3241) * Adds a document on token merging (#3208) * add document on token merging. * fix headline. * fix: headline. * add some samples for comparison. * [AudioLDM] Update docs to use updated ckpt (#3240) * [AudioLDM] Update docs to use updated ckpt * make style * Release: v0.16.0 * Post release for 0.16.0 (#3244) * Post release * fix more * [docs] only mention one stage (#3246) * [docs] only mention one stage * add blurb on auto accepting --------- Co-authored-by: William Berman <[email protected]> * Write model card in controlnet training script (#3229) Write model card in controlnet training script. * [2064]: Add stochastic sampler (sample_dpmpp_sde) (#3020) * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * [2064]: Add stochastic sampler * Review comments * [Review comment]: Add is_torchsde_available() * [Review comment]: Test and docs * [Review comment] * [Review comment] * [Review comment] * [Review comment] * [Review comment] --------- Co-authored-by: njindal <[email protected]> * [Stochastic Sampler][Slow Test]: Cuda test fixes (#3257) [Slow Test]: Cuda test fixes Co-authored-by: njindal <[email protected]> * Remove required from tracker_project_name (#3260) Remove required from tracker_project_name. As observed by https://github.com/off99555 in https://github.com/huggingface/diffusers/issues/2695#issuecomment-1470755050, it already has a default value. * adding required parameters while calling the get_up_block and get_down_block (#3210) * removed unnecessary parameters from get_up_block and get_down_block functions * adding resnet_skip_time_act, resnet_out_scale_factor and cross_attention_norm to get_up_block and get_down_block functions --------- Co-authored-by: Sayak Paul <[email protected]> * [docs] Update interface in repaint.mdx (#3119) Update repaint.mdx accomodate to #1701 * Update IF name to XL (#3262) Co-authored-by: multimodalart <[email protected]> * fix typo in score sde pipeline (#3132) * Fix typo in textual inversion JAX training script (#3123) The pipeline is built as `pipe` but then used as `pipeline`. * AudioDiffusionPipeline - fix encode method after config changes (#3114) * config fixes * deprecate get_input_dims * Revert "Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline"" (#3265) Revert "Revert "[Community Pipelines] Update lpw_stable_diffusion pipeline" (#3201)" This reverts commit 91a2a80eb2f98a9f64b9e287715add244dc6f2f3. * Fix community pipelines (#3266) * update notebook (#3259) Co-authored-by: yiyixuxu <[email protected]> * [docs] add notes for stateful model changes (#3252) * [docs] add notes for stateful model changes * Update docs/source/en/optimization/fp16.mdx Co-authored-by: Pedro Cuenca <[email protected]> * link to accelerate docs for discarding hooks --------- Co-authored-by: Pedro Cuenca <[email protected]> * [LoRA] quality of life improvements in the loading semantics and docs (#3180) * 👽 qol improvements for LoRA. * better function name? * fix: LoRA weight loading with the new format. * address Patrick's comments. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * change wording around encouraging the use of load_lora_weights(). * fix: function name. --------- Co-authored-by: Patrick von Platen <[email protected]> * [Community Pipelines] EDICT pipeline implementation (#3153) * EDICT pipeline initial commit - Starting point taking from https://github.com/Joqsan/edict-diffusion * refactor __init__() method * minor refactoring * refactor scheduler code - remove scheduler and move its methods to the EDICTPipeline class * make CFG optional - refactor encode_prompt(). - include optional generator for sampling with vae. - minor variable renaming * add EDICT pipeline description to README.md * replace preprocess() with VaeImageProcessor * run make style and make quality commands --------- Co-authored-by: Patrick von Platen <[email protected]> * [Docs]zh translated docs update (#3245) * zh translated docs update * update _toctree * Update logging.mdx (#2863) Fix typos * Add multiple conditions to StableDiffusionControlNetInpaintPipeline (#3125) * try multi controlnet inpaint * multi controlnet inpaint * multi controlnet inpaint * Let's make sure that dreambooth always uploads to the Hub (#3272) * Update Dreambooth README * Adapt all docs as well * automatically write model card * fix * make style * Diffedit Zero-Shot Inpainting Pipeline (#2837) * Update Pix2PixZero Auto-correlation Loss * Add Stable Diffusion DiffEdit pipeline * Add draft documentation and import code * Bugfixes and refactoring * Add option to not decode latents in the inversion process * Harmonize preprocessing * Revert "Update Pix2PixZero Auto-correlation Loss" This reverts commit b218062fed08d6cc164206d6cb852b2b7b00847a. * Update annotations * rename `compute_mask` to `generate_mask` * Update documentation * Update docs * Update Docs * Fix copy * Change shape of output latents to batch first * Update docs * Add first draft for tests * Bugfix and update tests * Add `cross_attention_kwargs` support for all pipeline methods * Fix Copies * Add support for PIL image latents Add support for mask broadcasting Update docs and tests Align `mask` argument to `mask_image` Remove height and width arguments * Enable MPS Tests * Move example docstrings * Fix test * Fix test * fix pipeline inheritance * Harmonize `prepare_image_latents` with StableDiffusionPix2PixZeroPipeline * Register modules set to `None` in config for `test_save_load_optional_components` * Move fixed logic to specific test class * Clean changes to other pipelines * Update new tests to coordinate with #2953 * Update slow tests for better results * Safety to avoid potential problems with torch.inference_mode * Add reference in SD Pipeline Overview * Fix tests again * Enforce determinism in noise for generate_mask * Fix copies * Widen test tolerance for fp16 based on `test_stable_diffusion_upscale_pipeline_fp16` * Add LoraLoaderMixin and update `prepare_image_latents` * clean up repeat and reg * bugfix * Remove invalid args from docs Suppress spurious warning by repeating image before latent to mask gen * add constant learning rate with custom rule (#3133) * add constant lr with rules * add constant with rules in TYPE_TO_SCHEDULER_FUNCTION * add constant lr rate with rule * hotfix code quality * fix doc style * change name constant_with_rules to piecewise constant * Allow disabling torch 2_0 attention (#3273) * Allow disabling torch 2_0 attention * make style * Update src/diffusers/models/attention.py * [doc] add link to training script (#3271) add link to training script Co-authored-by: yiyixuxu <[email protected]> * temp disable spectogram diffusion tests (#3278) The note-seq package throws an error on import because the default installed version of Ipython is not compatible with python 3.8 which we run in the CI. https://github.com/huggingface/diffusers/actions/runs/4830121056/jobs/8605954838#step:7:9 * Changed sample[0] to images[0] (#3304) A pipeline object stores the results in `images` not in `sample`. Current code blocks don't work. * Typo in tutorial (#3295) * Torch compile graph fix (#3286) * fix more * Fix more * fix more * Apply suggestions from code review * fix * make style * make fix-copies * fix * make sure torch compile * Clean * fix test * Postprocessing refactor img2img (#3268) * refactor img2img VaeImageProcessor.postprocess * remove copy from for init, run_safety_checker, decode_latents Co-authored-by: Sayak Paul <[email protected]> --------- Co-authored-by: yiyixuxu <[email protected]> Co-authored-by: Sayak Paul <[email protected]> * [Torch 2.0 compile] Fix more torch compile breaks (#3313) * Fix more torch compile breaks * add tests * Fix all * fix controlnet * fix more * Add Horace He as co-author. > > Co-authored-by: Horace He <[email protected]> * Add Horace He as co-author. Co-authored-by: Horace He <[email protected]> --------- Co-authored-by: Horace He <[email protected]> * fix: scale_lr and sync example readme and docs. (#3299) * fix: scale_lr and sync example readme and docs. * fix doc link. * Update stable_diffusion.mdx (#3310) fixed import statement * Fix missing variable assign in DeepFloyd-IF-II (#3315) Fix missing variable assign lol * Correct doc build for patch releases (#3316) Update build_documentation.yml * Add Stable Diffusion RePaint to community pipelines (#3320) * Add Stable Diffsuion RePaint to community pipelines - Adds Stable Diffsuion RePaint to community pipelines - Add Readme enty for pipeline * Fix: Remove wrong import - Remove wrong import - Minor change in comments * Fix: Code formatting of stable_diffusion_repaint * Fix: ruff errors in stable_diffusion_repaint * Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314) * fix multistep dpmsolver for cosine schedule (deepfloy-if) * fix a typo * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule * add test, fix style --------- Co-authored-by: Patrick von Platen <[email protected]> * [docs] Improve LoRA docs (#3311) * update docs * add to toctree * apply feedback * Added input pretubation (#3292) * Added input pretubation * Fixed spelling * Update write_own_pipeline.mdx (#3323) * update controlling generation doc with latest goodies. (#3321) * [Quality] Make style (#3341) * Fix config dpm (#3343) * Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344) * add SDE variant of DPM-Solver and DPM-Solver++ * add test * fix typo * fix typo * Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275) The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument. * Add UniDiffuser classes to __init__ files, modify transformer block to support pre- and post-LN, add fast default tests, fix some bugs. * Update fast tests to use test checkpoints stored on the hub and to better match the reference UniDiffuser implementation. * Fix code with make style. * Revert "Fix code style with make style." This reverts commit 10a174a12c82e6abd3d5a57665719a03dbb85ca7. * Add self.image_encoder, self.text_decoder to list of models to offload to CPU in the enable_sequential_cpu_offload(...)/enable_model_cpu_offload(...) methods to make test_cpu_offload_forward_pass pass. * Fix code quality with make style. * Support using a data type embedding for UniDiffuser-v1. * Add fast test for checking UniDiffuser-v1 sampling. * Make changes so that the repository consistency tests pass. * Add UniDiffuser dummy objects via make fix-copies. * Fix bugs and make improvements to the UniDiffuser pipeline: - Improve batch size inference and fix bugs when num_images_per_prompt or num_prompts_per_image > 1 - Add tests for num_images_per_prompt, num_prompts_per_image > 1 - Improve check_inputs, especially regarding checking supplied latents - Add reset_mode method so that mode inference can be re-enabled after mode is set manually - Fix some warnings related to accessing class members directly instead of through their config - Small amount of refactoring in pipeline_unidiffuser.py * Fix code style with make style. * Add/edit docstrings for added classes and public pipeline methods. Also do some light refactoring. * Add documentation for UniDiffuser and fix some typos/formatting in docstrings. * Fix code with make style. * Refactor and improve the UniDiffuser convert_from_ckpt.py script. * Move the UniDiffusers convert_from_ckpy.py script to diffusers/scripts/convert_unidiffuser_to_diffusers.py * Fix code quality via make style. * Improve UniDiffuser slow tests. * make style * Fix some typos in the UniDiffuser docs. * Remove outdated logic based on transformers version in UniDiffuser pipeline __init__.py * Remove dependency on einops by refactoring einops operations to pure torch operations. * make style * Add slow test on full checkpoint for joint mode and correct expected image slices/text prefixes. * make style * Fix mixed precision issue by wrapping the offending code with the torch.autocast context manager. * Revert "Fix mixed precision issue by wrapping the offending code with the torch.autocast context manager." This reverts commit 1a58958ab4f024dbc4c90a6404c2e66210db6d00. * Add fast test for CUDA/fp16 model behavior (currently failing). * Fix the mixed precision issue and add additional tests of the pipeline cuda/fp16 functionality. * make style * Use a CLIPVisionModelWithProjection instead of CLIPVisionModel for image_encoder to better match the original UniDiffuser implementation. * Make style and remove some testing code. * Fix shape errors for the 'joint' and 'img2text' modes. * Fix tests and remove some testing code. * Add option to use fixed latents for UniDiffuserPipelineSlowTests and fix issue in modeling_text_decoder.py. * Improve UniDiffuser docs, particularly the usage examples, and improve slow tests with new expected outputs. * make style * Fix examples to load model in float16. * In image-to-text mode, sample from the autoencoder moment distribution instead of always getting its mode. * make style * When encoding the image using the VAE, scale the image latents by the VAE's scaling factor. * make style * Clean up code and make slow tests pass. * make fix-copies * [docs] Fix docstring (#3334) fix docstring Co-authored-by: Patrick von Platen <[email protected]> * if dreambooth lora (#3360) * update IF stage I pipelines add fixed variance schedulers and lora loading * added kv lora attn processor * allow loading into alternative lora attn processor * make vae optional * throw away predicted variance * allow loading into added kv lora layer * allow load T5 * allow pre compute text embeddings * set new variance type in schedulers * fix copies * refactor all prompt embedding code class prompts are now included in pre-encoding code max tokenizer length is now configurable embedding attention mask is now configurable * fix for when variance type is not defined on scheduler * do not pre compute validation prompt if not present * add example test for if lora dreambooth * add check for train text encoder and pre compute text embeddings * Postprocessing refactor all others (#3337) * add text2img * fix-copies * add * add all other pipelines * add * add * add * add * add * make style * style + fix copies --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> * [docs] Improve safetensors docstring (#3368) * clarify safetensor docstring * fix typo * apply feedback * add: a warning message when using xformers in a PT 2.0 env. (#3365) * add: a warning message when using xformers in a PT 2.0 env. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]> * StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322) * StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy. * Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution * Added a resolution test to StableDiffusionInpaintPipelineSlowTests this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [docs] Adapt a model (#3326) * first draft * apply feedback * conv_in.weight thrown away * [docs] Load safetensors (#3333) * safetensors * apply feedback * apply feedback * Apply suggestions from code review --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [Docs] Fix stable_diffusion.mdx typo (#3398) Fix typo in last code block. Correct "prommpts" to "prompt" * Support ControlNet v1.1 shuffle properly (#3340) * add inferring_controlnet_cond_batch * Revert "add inferring_controlnet_cond_batch" This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9. * set guess_mode to True whenever global_pool_conditions is True Co-authored-by: Patrick von Platen <[email protected]> * nit * add integration test --------- Co-authored-by: Patrick von Platen <[email protected]> * [Tests] better determinism (#3374) * enable deterministic pytorch and cuda operations. * disable manual seeding. * make style && make quality for unet_2d tests. * enable determinism for the unet2dconditional model. * add CUBLAS_WORKSPACE_CONFIG for better reproducibility. * relax tolerance (very weird issue, though). * revert to torch manual_seed() where needed. * relax more tolerance. * better placement of the cuda variable and relax more tolerance. * enable determinism for 3d condition model. * relax tolerance. * add: determinism to alt_diffusion. * relax tolerance for alt diffusion. * dance diffusion. * dance diffusion is flaky. * test_dict_tuple_outputs_equivalent edit. * fix two more tests. * fix more ddim tests. * fix: argument. * change to diff in place of difference. * fix: test_save_load call. * test_save_load_float16 call. * fix: expected_max_diff * fix: paint by example. * relax tolerance. * add determinism to 1d unet model. * torch 2.0 regressions seem to be brutal * determinism to vae. * add reason to skipping. * up tolerance. * determinism to vq. * determinism to cuda. * determinism to the generic test pipeline file. * refactor general pipelines testing a bit. * determinism to alt diffusion i2i * up tolerance for alt diff i2i and audio diff * up tolerance. * determinism to audioldm * increase tolerance for audioldm lms. * increase tolerance for paint by paint. * increase tolerance for repaint. * determinism to cycle diffusion and sd 1. * relax tol for cycle diffusion 🚲 * relax tol for sd 1.0 * relax tol for controlnet. * determinism to img var. * relax tol for img variation. * tolerance to i2i sd * make style * determinism to inpaint. * relax tolerance for inpaiting. * determinism for inpainting legacy * relax tolerance. * determinism to instruct pix2pix * determinism to model editing. * model editing tolerance. * panorama determinism * determinism to pix2pix zero. * determinism to sag. * sd 2. determinism * sd. tolerance * disallow tf32 matmul. * relax tolerance is all you need. * make style and determinism to sd 2 depth * relax tolerance for depth. * tolerance to diffedit. * tolerance to sd 2 inpaint. * up tolerance. * determinism in upscaling. * tolerance in upscaler. * more tolerance relaxation. * determinism to v pred. * up tol for v_pred * unclip determinism * determinism to unclip img2img * determinism to text to video. * determinism to last set of tests * up tol. * vq cumsum doesn't have a deterministic kernel * relax tol * relax tol * [docs] Add transformers to install (#3388) add transformers to install * [deepspeed] partial ZeRO-3 support (#3076) * [deepspeed] partial ZeRO-3 support * cleanup * improve deepspeed fixes * Improve * make style --------- Co-authored-by: Patrick von Platen <[email protected]> * Add omegaconf for tests (#3400) Add omegaconfg * Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353) * Improve checkpointing lora * fix more * Improve doc string * Update src/diffusers/loaders.py * make stytle * Apply suggestions from code review * Update src/diffusers/loaders.py * Apply suggestions from code review * Apply suggestions from code review * better * Fix all * Fix multi-GPU dreambooth * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Fix all * make style * make style --------- Co-authored-by: Pedro Cuenca <[email protected]> * Fix docker file (#3402) * up * up * fix: deepseepd_plugin retrieval from accelerate state (#3410) * [Docs] Add `sigmoid` beta_scheduler to docstrings of relevant Schedulers (#3399) * Add `sigmoid` beta scheduler to `DDPMScheduler` docstring * Add `sigmoid` beta scheduler to `RePaintScheduler` docstring --------- Co-authored-by: Patrick von Platen <[email protected]> * Don't install accelerate and transformers from source (#3415) * Don't install transformers and accelerate from source (#3414) * Improve fast tests (#3416) Update pr_tests.yml * attention refactor: the trilogy (#3387) * Replace `AttentionBlock` with `Attention` * use _from_deprecated_attn_block check re: @patrickvonplaten * [Docs] update the PT 2.0 optimization doc with latest findings (#3370) * add: benchmarking stats for A100 and V100. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * address patrick's comments. * add: rtx 4090 stats * ⚔ benchmark reports done * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * 3313 pr link. * add: plots. Co-authored-by: Pedro <[email protected]> * fix formattimg * update number percent. --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * Fix style rendering (#3433) * Fix style rendering. * Fix typo * unCLIP scheduler do not use note (#3417) * Replace deprecated command with environment file (#3409) Co-authored-by: Patrick von Platen <[email protected]> * fix warning message pipeline loading (#3446) * add stable diffusion tensorrt img2img pipeline (#3419) * add stable diffusion tensorrt img2img pipeline Signed-off-by: Asfiya Baig <[email protected]> * update docstrings Signed-off-by: Asfiya Baig <[email protected]> --------- Signed-off-by: Asfiya Baig <[email protected]> * Refactor controlnet and add img2img and inpaint (#3386) * refactor controlnet and add img2img and inpaint * First draft to get pipelines to work * make style * Fix more * Fix more * More tests * Fix more * Make inpainting work * make style and more tests * Apply suggestions from code review * up * make style * Fix imports * Fix more * Fix more * Improve examples * add test * Make sure import is correctly deprecated * Make sure everything works in compile mode * make sure authorship is correctly attributed * [Scheduler] DPM-Solver (++) Inverse Scheduler (#3335) * Add DPM-Solver Multistep Inverse Scheduler * Add draft tests for DiffEdit * Add inverse sde-dpmsolver steps to tune image diversity from inverted latents * Fix tests --------- Co-authored-by: Patrick von Platen <[email protected]> * [Docs] Fix incomplete docstring for resnet.py (#3438) Fix incomplete docstrings for resnet.py * fix tiled vae blend extent range (#3384) fix tiled vae bleand extent range * Small update to "Next steps" section (#3443) Small update to "Next steps" section: - PyTorch 2 is recommended. - Updated improvement figures. * Allow arbitrary aspect ratio in IFSuperResolutionPipeline (#3298) * Update pipeline_if_superresolution.py Allow arbitrary aspect ratio in IFSuperResolutionPipeline by using the input image shape * IFSuperResolutionPipeline: allow the user to override the height and width through the arguments * update IFSuperResolutionPipeline width/height doc string to match StableDiffusionInpaintPipeline conventions --------- Co-authored-by: Patrick von Platen <[email protected]> * Adding 'strength' parameter to StableDiffusionInpaintingPipeline (#3424) * Added explanation of 'strength' parameter * Added get_timesteps function which relies on new strength parameter * Added `strength` parameter which defaults to 1. * Swapped ordering so `noise_timestep` can be calculated before masking the image this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1. * Added strength to check_inputs, throws error if out of range * Changed `prepare_latents` to initialise latents w.r.t strength inspired from the stable diffusion img2img pipeline, init latents are initialised by converting the init image into a VAE latent and adding noise (based upon the strength parameter passed in), e.g. random when strength = 1, or the init image at strength = 0. * WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline still need to add correct regression values * Created a is_strength_max to initialise from pure random noise * Updated unit tests w.r.t new strength parameter + fixed new strength unit test * renamed parameter to avoid confusion with variable of same name * Updated regression values for new strength test - now passes * removed 'copied from' comment as this method is now different and divergent from the cpy * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py Co-authored-by: Patrick von Platen <[email protected]> * Ensure backwards compatibility for prepare_mask_and_masked_image created a return_image boolean and initialised to false * Ensure backwards compatibility for prepare_latents * Fixed copy check typo * Fixes w.r.t backward compibility changes * make style * keep function argument ordering same for backwards compatibility in callees with copied from statements * make fix-copies --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: William Berman <[email protected]> * [WIP] Bugfix - Pipeline.from_pretrained is broken when the pipeline is partially downloaded (#3448) Added bugfix using f strings. * Fix gradient checkpointing bugs in freezing part of models (requires_grad=False) (#3404) * gradient checkpointing bug fix * bug fix; changes for reviews * reformat * reformat --------- Co-authored-by: Patrick von Platen <[email protected]> * Make dreambooth lora more robust to orig unet (#3462) * Make dreambooth lora more robust to orig unet * up * Reduce peak VRAM by releasing large attention tensors (as soon as they're unnecessary) (#3463) Release large tensors in attention (as soon as they're no longer required). Reduces peak VRAM by nearly 2 GB for 1024x1024 (even after slicing), and the savings scale up with image size. * Add min snr to text2img lora training script (#3459) add min snr to text2img lora training script * Add inpaint lora scale support (#3460) * add inpaint lora scale support * add inpaint lora scale test --------- Co-authored-by: yueyang.hyy <[email protected]> * [From ckpt] Fix from_ckpt (#3466) * Correct from_ckpt * make style * Update full dreambooth script to work with IF (#3425) * Add IF dreambooth docs (#3470) * parameterize pass single args through tuple (#3477) * attend and excite tests disable determinism on the class level (#3478) * dreambooth docs torch.compile note (#3471) * dreambooth docs torch.compile note * Update examples/dreambooth/README.md Co-authored-by: Sayak Paul <[email protected]> * Update examples/dreambooth/README.md Co-authored-by: Pedro Cuenca <[email protected]> --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * add: if entry in the dreambooth training docs. (#3472) * [docs] Textual inversion inference (#3473) * add textual inversion inference to docs * add to toctree --------- Co-authored-by: Sayak Paul <[email protected]> * [docs] Distributed inference (#3376) * distributed inference * move to inference section * apply feedback * update with split_between_processes * apply feedback * [{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (#3479) explicit view kernel size as number elements in flattened indices * mps & onnx tests rework (#3449) * Remove ONNX tests from PR. They are already a part of push_tests.yml. * Remove mps tests from PRs. They are already performed on push. * Fix workflow name for fast push tests. * Extract mps tests to a workflow. For better control/filtering. * Remove --extra-index-url from mps tests * Increase tolerance of mps test This test passes in my Mac (Ventura 13.3) but fails in the CI hardware (Ventura 13.2). I ran the local tests following the same steps that exist in the CI workflow. * Temporarily run mps tests on pr So we can test. * Revert "Temporarily run mps tests on pr" Tests passed, go back to running on push. * [Attention processor] Better warning message when shifting to `AttnProcessor2_0` (#3457) * add: debugging to enabling memory efficient processing * add: better warning message. * [Docs] add note on local directory path. (#3397) add note on local directory path. Co-authored-by: Patrick von Platen <[email protected]> * Refactor full determinism (#3485) * up * fix more * Apply suggestions from code review * fix more * fix more * Check it * Remove 16:8 * fix more * fix more * fix more * up * up * Test only stable diffusion * Test only two files * up * Try out spinning up processes that can be killed * up * Apply suggestions from code review * up * up * Fix DPM single (#3413) * Fix DPM single * add test * fix one more bug * Apply suggestions from code review Co-authored-by: StAlKeR7779 <[email protected]> --------- Co-authored-by: StAlKeR7779 <[email protected]> * Add `use_Karras_sigmas` to DPMSolverSinglestepScheduler (#3476) * add use_karras_sigmas * add karras test * add doc * Adds local_files_only bool to prevent forced online connection (#3486) * make style * [Docs] Korean translation (optimization, training) (#3488) * feat) optimization kr translation * fix) typo, italic setting * feat) dreambooth, text2image kr * feat) lora kr * fix) LoRA * fix) fp16 fix * fix) doc-builder style * fix) fp16 일부 단어 수정 * fix) fp16 style fix * fix) opt, training docs update * feat) toctree update * feat) toctree update --------- Co-authored-by: Chanran Kim <[email protected]> * DataLoader respecting EXIF data in Training Images (#3465) * DataLoader will now bake in any transforms or image manipulations contained in the EXIF Images may have rotations stored in EXIF. Training using such images will cause those transforms to be ignored while training and thus produce unexpected results * Fixed the Dataloading EXIF issue in main DreamBooth training as well * Run make style (black & isort) * make style * feat: allow disk offload for diffuser models (#3285) * allow disk offload for diffuser models * sort import * add max_memory argument * Changed sample[0] to images[0] (#3304) A pipeline object stores the results in `images` not in `sample`. Current code blocks don't work. * Typo in tutorial (#3295) * Torch compile graph fix (#3286) * fix more * Fix more * fix more * Apply suggestions from code review * fix * make style * make fix-copies * fix * make sure torch compile * Clean * fix test * Postprocessing refactor img2img (#3268) * refactor img2img VaeImageProcessor.postprocess * remove copy from for init, run_safety_checker, decode_latents Co-authored-by: Sayak Paul <[email protected]> --------- Co-authored-by: yiyixuxu <[email protected]> Co-authored-by: Sayak Paul <[email protected]> * [Torch 2.0 compile] Fix more torch compile breaks (#3313) * Fix more torch compile breaks * add tests * Fix all * fix controlnet * fix more * Add Horace He as co-author. > > Co-authored-by: Horace He <[email protected]> * Add Horace He as co-author. Co-authored-by: Horace He <[email protected]> --------- Co-authored-by: Horace He <[email protected]> * fix: scale_lr and sync example readme and docs. (#3299) * fix: scale_lr and sync example readme and docs. * fix doc link. * Update stable_diffusion.mdx (#3310) fixed import statement * Fix missing variable assign in DeepFloyd-IF-II (#3315) Fix missing variable assign lol * Correct doc build for patch releases (#3316) Update build_documentation.yml * Add Stable Diffusion RePaint to community pipelines (#3320) * Add Stable Diffsuion RePaint to community pipelines - Adds Stable Diffsuion RePaint to community pipelines - Add Readme enty for pipeline * Fix: Remove wrong import - Remove wrong import - Minor change in comments * Fix: Code formatting of stable_diffusion_repaint * Fix: ruff errors in stable_diffusion_repaint * Fix multistep dpmsolver for cosine schedule (suitable for deepfloyd-if) (#3314) * fix multistep dpmsolver for cosine schedule (deepfloy-if) * fix a typo * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/schedulers/scheduling_dpmsolver_multistep.py Co-authored-by: Patrick von Platen <[email protected]> * update all dpmsolver (singlestep, multistep, dpm, dpm++) for cosine noise schedule * add test, fix style --------- Co-authored-by: Patrick von Platen <[email protected]> * [docs] Improve LoRA docs (#3311) * update docs * add to toctree * apply feedback * Added input pretubation (#3292) * Added input pretubation * Fixed spelling * Update write_own_pipeline.mdx (#3323) * update controlling generation doc with latest goodies. (#3321) * [Quality] Make style (#3341) * Fix config dpm (#3343) * Add the SDE variant of DPM-Solver and DPM-Solver++ (#3344) * add SDE variant of DPM-Solver and DPM-Solver++ * add test * fix typo * fix typo * Add upsample_size to AttnUpBlock2D, AttnDownBlock2D (#3275) The argument `upsample_size` needs to be added to these modules to allow compatibility with other blocks that require this argument. * Rename --only_save_embeds to --save_as_full_pipeline (#3206) * Set --only_save_embeds to False by default Due to how the option is named, it makes more sense to behave like this. * Refactor only_save_embeds to save_as_full_pipeline * [AudioLDM] Generalise conversion script (#3328) Co-authored-by: Patrick von Platen <[email protected]> * Fix TypeError when using prompt_embeds and negative_prompt (#2982) * test: Added test case * fix: fixed type checking issue on _encode_prompt * fix: fixed copies consistency * fix: one copy was not sufficient * Fix pipeline class on README (#3345) Update README.md * Inpainting: typo in docs (#3331) Typo in docs Co-authored-by: Patrick von Platen <[email protected]> * Add `use_Karras_sigmas` to LMSDiscreteScheduler (#3351) * add karras sigma to lms discrete scheduler * add test for lms_scheduler karras * reformat test lms * Batched load of textual inversions (#3277) * Batched load of textual inversions - Only call resize_token_embeddings once per batch as it is the most expensive operation - Allow pretrained_model_name_or_path and token to be an optional list - Remove Dict from type annotation pretrained_model_name_or_path as it was not supported in this function - Add comment that single files (e.g. .pt/.safetensors) are supported - Add comment for token parameter - Convert token override log message from warning to info * Update src/diffusers/loaders.py Check for duplicate tokens Co-authored-by: Patrick von Platen <[email protected]> * Update condition for None tokens --------- Co-authored-by: Patrick von Platen <[email protected]> * make fix-copies * [docs] Fix docstring (#3334) fix docstring Co-authored-by: Patrick von Platen <[email protected]> * if dreambooth lora (#3360) * update IF stage I pipelines add fixed variance schedulers and lora loading * added kv lora attn processor * allow loading into alternative lora attn processor * make vae optional * throw away predicted variance * allow loading into added kv lora layer * allow load T5 * allow pre compute text embeddings * set new variance type in schedulers * fix copies * refactor all prompt embedding code class prompts are now included in pre-encoding code max tokenizer length is now configurable embedding attention mask is now configurable * fix for when variance type is not defined on scheduler * do not pre compute validation prompt if not present * add example test for if lora dreambooth * add check for train text encoder and pre compute text embeddings * Postprocessing refactor all others (#3337) * add text2img * fix-copies * add * add all other pipelines * add * add * add * add * add * make style * style + fix copies --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> * [docs] Improve safetensors docstring (#3368) * clarify safetensor docstring * fix typo * apply feedback * add: a warning message when using xformers in a PT 2.0 env. (#3365) * add: a warning message when using xformers in a PT 2.0 env. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> --------- Co-authored-by: Patrick von Platen <[email protected]> * StableDiffusionInpaintingPipeline - resize image w.r.t height and width (#3322) * StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t to passed input height and width. Default is already set to 512. This addresses the common tensor mismatch error. Also moved type check into relevant funciton to keep main pipeline body tidy. * Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests Due to previous commit these tests were failing as height and width need to be passed into the prepare_mask_and_masked_image function, I have updated the code and added a height/width variable per unit test as it seemed more appropriate than the current hard coded solution * Added a resolution test to StableDiffusionInpaintPipelineSlowTests this unit test simply gets the input and resizes it into some that would fail (e.g. would throw a tensor mismatch error/not a mult of 8). Then passes it through the pipeline and verifies it produces output with correct dims w.r.t the passed height and width --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [docs] Adapt a model (#3326) * first draft * apply feedback * conv_in.weight thrown away * [docs] Load safetensors (#3333) * safetensors * apply feedback * apply feedback * Apply suggestions from code review --------- Co-authored-by: Patrick von Platen <[email protected]> * make style * [Docs] Fix stable_diffusion.mdx typo (#3398) Fix typo in last code block. Correct "prommpts" to "prompt" * Support ControlNet v1.1 shuffle properly (#3340) * add inferring_controlnet_cond_batch * Revert "add inferring_controlnet_cond_batch" This reverts commit abe8d6311d4b7f5b9409ca709c7fabf80d06c1a9. * set guess_mode to True whenever global_pool_conditions is True Co-authored-by: Patrick von Platen <[email protected]> * nit * add integration test --------- Co-authored-by: Patrick von Platen <[email protected]> * [Tests] better determinism (#3374) * enable deterministic pytorch and cuda operations. * disable manual seeding. * make style && make quality for unet_2d tests. * enable determinism for the unet2dconditional model. * add CUBLAS_WORKSPACE_CONFIG for better reproducibility. * relax tolerance (very weird issue, though). * revert to torch manual_seed() where needed. * relax more tolerance. * better placement of the cuda variable and relax more tolerance. * enable determinism for 3d condition model. * relax tolerance. * add: determinism to alt_diffusion. * relax tolerance for alt diffusion. * dance diffusion. * dance diffusion is flaky. * test_dict_tuple_outputs_equivalent edit. * fix two more tests. * fix more ddim tests. * fix: argument. * change to diff in place of difference. * fix: test_save_load call. * test_save_load_float16 call. * fix: expected_max_diff * fix: paint by example. * relax tolerance. * add determinism to 1d unet model. * torch 2.0 regressions seem to be brutal * determinism to vae. * add reason to skipping. * up tolerance. * determinism to vq. * determinism to cuda. * determinism to the generic test pipeline file. * refactor general pipelines testing a bit. * determinism to alt diffusion i2i * up tolerance for alt diff i2i and audio diff * up tolerance. * determinism to audioldm * increase tolerance for audioldm lms. * increase tolerance for paint by paint. * increase tolerance for repaint. * determinism to cycle diffusion and sd 1. * relax tol for cycle diffusion 🚲 * relax tol for sd 1.0 * relax tol for controlnet. * determinism to img var. * relax tol for img variation. * tolerance to i2i sd * make style * determinism to inpaint. * relax tolerance for inpaiting. * determinism for inpainting legacy * relax tolerance. * determinism to instruct pix2pix * determinism to model editing. * model editing tolerance. * panorama determinism * determinism to pix2pix zero. * determinism to sag. * sd 2. determinism * sd. tolerance * disallow tf32 matmul. * relax tolerance is all you need. * make style and determinism to sd 2 depth * relax tolerance for depth. * tolerance to diffedit. * tolerance to sd 2 inpaint. * up tolerance. * determinism in upscaling. * tolerance in upscaler. * more tolerance relaxation. * determinism to v pred. * up tol for v_pred * unclip determinism * determinism to unclip img2img * determinism to text to video. * determinism to last set of tests * up tol. * vq cumsum doesn't have a deterministic kernel * relax tol * relax tol * [docs] Add transformers to install (#3388) add transformers to install * [deepspeed] partial ZeRO-3 support (#3076) * [deepspeed] partial ZeRO-3 support * cleanup * improve deepspeed fixes * Improve * make style --------- Co-authored-by: Patrick von Platen <[email protected]> * Add omegaconf for tests (#3400) Add omegaconfg * Fix various bugs with LoRA Dreambooth and Dreambooth script (#3353) * Improve checkpointing lora * fix more * Improve doc string * Update src/diffusers/loaders.py * make stytle * Apply suggestions from code review * Update src/diffusers/loaders.py * Apply suggestions from code review * Apply suggestions from code review * better * Fix all * Fix multi-GPU dreambooth * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Fix all * make style * make style --------- Co-authored-by: Pedro Cuenca <[email protected]> * Fix docker file (#3402) * up * up * fix: deepseepd_plugin retrieval from accelerate state (#3410) * [Docs] Add `sigmoid` beta_scheduler to docstrings of relevant Schedulers (#3399) * Add `sigmoid` beta scheduler to `DDPMScheduler` docstring * Add `sigmoid` beta scheduler to `RePaintScheduler` docstring --------- Co-authored-by: Patrick von Platen <[email protected]> * Don't install accelerate and transformers from source (#3415) * Don't install transformers and accelerate from source (#3414) * Improve fast tests (#3416) Update pr_tests.yml * attention refactor: the trilogy (#3387) * Replace `AttentionBlock` with `Attention` * use _from_deprecated_attn_block check re: @patrickvonplaten * [Docs] update the PT 2.0 optimization doc with latest findings (#3370) * add: benchmarking stats for A100 and V100. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * address patrick's comments. * add: rtx 4090 stats * ⚔ benchmark reports done * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * 3313 pr link. * add: plots. Co-authored-by: Pedro <[email protected]> * fix formattimg * update number percent. --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Pedro Cuenca <[email protected]> * Fix style rendering (#3433) * Fix style rendering. * Fix typo * unCLIP scheduler do not use note (#3417) * Replace deprecated command with environment file (#3409) Co-authored-by: Patrick von Platen <[email protected]> * …
This PR implements a pipeline for the UniDiffuser model as discussed in #2857.
Model/Pipeline Description
The UniDiffuser model (paper, code) is a multi-modal model which extends the DDPM model to model all distributions relevant to a set of multi-modal data. From the paper abstract:
In this PR, we implement a image-text UniDiffuser model as described in the paper:
Usage Examples
TODO
Discussion
CC
@patrickvonplaten
@nemonameless
@baofff (author on original paper, author of original code)