Commit d079464

Merge remote-tracking branch 'upstream/main'
2 parents 120a9f9 + 2858d7e

File tree: 5 files changed, +110 -16 lines


examples/text_to_image/train_text_to_image_lora.py

Lines changed: 48 additions & 1 deletion
@@ -239,6 +239,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
@@ -472,6 +479,30 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    def compute_snr(timesteps):
+        """
+        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+        """
+        alphas_cumprod = noise_scheduler.alphas_cumprod
+        sqrt_alphas_cumprod = alphas_cumprod**0.5
+        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+        # Expand the tensors.
+        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+        # Compute SNR.
+        snr = (alpha / sigma) ** 2
+        return snr
+
     lora_layers = AttnProcsLayers(unet.attn_processors)
 
     # Enable TF32 for faster training on Ampere GPUs,
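For intuition: `compute_snr` implements the standard diffusion identity SNR(t) = alpha_bar_t / (1 - alpha_bar_t). With alpha = sqrt(alpha_bar_t) and sigma = sqrt(1 - alpha_bar_t), the returned (alpha / sigma)**2 is exactly that ratio. A minimal standalone sketch with a toy alpha-bar schedule (not the scheduler's real values):

```python
import torch

# Toy cumulative-alpha schedule; the real script uses noise_scheduler.alphas_cumprod.
alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)
timesteps = torch.tensor([10, 500, 990])

alpha = alphas_cumprod[timesteps] ** 0.5
sigma = (1.0 - alphas_cumprod[timesteps]) ** 0.5
snr = (alpha / sigma) ** 2  # equals alphas_cumprod[t] / (1 - alphas_cumprod[t])
```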
@@ -727,7 +758,23 @@ def collate_fn(examples):
 
                 # Predict the noise residual and compute loss
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+                if args.snr_gamma is None:
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                else:
+                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                    # This is discussed in Section 4.2 of the same paper.
+                    snr = compute_snr(timesteps)
+                    mse_loss_weights = (
+                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+                    )
+                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
+                    # rebalance the sample-wise losses with their respective loss weights.
+                    # Finally, we take the mean of the rebalanced loss.
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+                    loss = loss.mean()
 
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
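The `torch.stack(...).min(dim=1)` construction above is an element-wise min(SNR(t), gamma) / SNR(t). A standalone sketch of the weighting (Section 3.4 of https://arxiv.org/abs/2303.09556, epsilon-prediction form), assuming an `snr` tensor like the one `compute_snr` returns:

```python
import torch

def min_snr_weights(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    # Equivalent to torch.stack([snr, gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
    return torch.clamp(snr, max=gamma) / snr

snr = torch.tensor([0.5, 5.0, 50.0])
print(min_snr_weights(snr))  # tensor([1.0000, 1.0000, 0.1000])
```

High-SNR (low-noise) timesteps are down-weighted while low-SNR ones keep weight 1, which is the rebalancing the paper reports as speeding up convergence.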

src/diffusers/loaders.py

Lines changed: 1 addition & 1 deletion
@@ -1326,7 +1326,7 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs):
         file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
         from_safetensors = file_extension == "safetensors"
 
-        if from_safetensors and use_safetensors is True:
+        if from_safetensors and use_safetensors is False:
             raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")
 
         # TODO: For now we only support stable diffusion
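The old condition was inverted: it raised precisely when the caller wanted safetensors. A minimal sketch of the corrected gating as a hypothetical standalone helper (`check_safetensors_choice` is not a diffusers API):

```python
def check_safetensors_choice(path: str, use_safetensors=None) -> None:
    # Raise only when the file is a .safetensors checkpoint but safetensors
    # loading was explicitly ruled out (e.g. the library is unavailable).
    from_safetensors = path.rsplit(".", 1)[-1] == "safetensors"
    if from_safetensors and use_safetensors is False:
        raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")

check_safetensors_choice("model.safetensors")                  # ok: no explicit opt-out
check_safetensors_choice("model.ckpt", use_safetensors=False)  # ok: not a safetensors file
# check_safetensors_choice("model.safetensors", use_safetensors=False)  # raises
```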

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 13 additions & 9 deletions
@@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("norm.weight", "group_norm.weight")
         new_item = new_item.replace("norm.bias", "group_norm.bias")
 
-        new_item = new_item.replace("q.weight", "query.weight")
-        new_item = new_item.replace("q.bias", "query.bias")
+        new_item = new_item.replace("q.weight", "to_q.weight")
+        new_item = new_item.replace("q.bias", "to_q.bias")
 
-        new_item = new_item.replace("k.weight", "key.weight")
-        new_item = new_item.replace("k.bias", "key.bias")
+        new_item = new_item.replace("k.weight", "to_k.weight")
+        new_item = new_item.replace("k.bias", "to_k.bias")
 
-        new_item = new_item.replace("v.weight", "value.weight")
-        new_item = new_item.replace("v.bias", "value.bias")
+        new_item = new_item.replace("v.weight", "to_v.weight")
+        new_item = new_item.replace("v.bias", "to_v.bias")
 
-        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
-        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
 
         new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
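These renames point converted VAE attention keys at diffusers' `Attention` module layout (`to_q`/`to_k`/`to_v`/`to_out.0`) instead of the deprecated `query`/`key`/`value`/`proj_attn` names. A hypothetical key for illustration (real checkpoint keys carry longer block prefixes):

```python
old_key = "mid_block.attentions.0.q.weight"
new_key = old_key.replace("q.weight", "to_q.weight")
print(new_key)  # mid_block.attentions.0.to_q.weight
```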

@@ -204,8 +204,12 @@ def assign_to_checkpoint(
             new_path = new_path.replace(replacement["old"], replacement["new"])
 
         # proj_attn.weight has to be converted from conv 1D to linear
-        if "proj_attn.weight" in new_path:
+        is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
+        shape = old_checkpoint[path["old"]].shape
+        if is_attn_weight and len(shape) == 3:
             checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        elif is_attn_weight and len(shape) == 4:
+            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
         else:
             checkpoint[new_path] = old_checkpoint[path["old"]]
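The new branches handle checkpoints that store attention projections as 1x1 convolutions. A sketch of why dropping the trailing singleton dimensions is lossless, under the assumption that the kernels really are 1x1:

```python
import torch

w_conv1d = torch.randn(32, 32, 1)     # Conv1d weight: (out, in, kernel=1)
w_conv2d = torch.randn(32, 32, 1, 1)  # Conv2d weight: (out, in, 1, 1)

w_linear_a = w_conv1d[:, :, 0]        # (32, 32) -- same parameters as a Linear weight
w_linear_b = w_conv2d[:, :, 0, 0]     # (32, 32)

x = torch.randn(4, 32)
# A 1x1 convolution applies the same per-position matrix multiply as a Linear layer.
conv_out = torch.nn.functional.conv2d(x[..., None, None], w_conv2d).squeeze(-1).squeeze(-1)
assert torch.allclose(x @ w_linear_b.T, conv_out, atol=1e-5)
```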

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 13 additions & 5 deletions
@@ -14,7 +14,7 @@
 
 import inspect
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL

@@ -744,6 +744,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.

@@ -815,7 +816,10 @@
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
-
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
         Examples:
 
         ```py

@@ -966,9 +970,13 @@
                 latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
 
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
-                    0
-                ]
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
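With `cross_attention_kwargs` threaded through to the UNet, LoRA attention processors attached to the inpainting pipeline can be scaled per call. A hypothetical usage sketch (the LoRA path and blank images are placeholders):

```python
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
pipe.unet.load_attn_procs("path/to/lora_weights")  # placeholder LoRA checkpoint

init_image = Image.new("RGB", (512, 512))   # stand-ins for real image/mask inputs
mask_image = Image.new("L", (512, 512), 255)

result = pipe(
    prompt="a red sofa",
    image=init_image,
    mask_image=mask_image,
    cross_attention_kwargs={"scale": 0.5},  # 0.0 disables the LoRA path, 1.0 is full strength
).images[0]
```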

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 35 additions & 0 deletions
@@ -35,6 +35,7 @@
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 
+from ...models.test_models_unet_2d_condition import create_lora_layers
 from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
 from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

@@ -155,6 +156,40 @@ def test_stable_diffusion_inpaint_image_tensor(self):
         assert out_pil.shape == (1, 64, 64, 3)
         assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2
 
+    def test_stable_diffusion_inpaint_lora(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionInpaintPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward 1
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        # set lora layers
+        lora_attn_procs = create_lora_layers(sd_pipe.unet)
+        sd_pipe.unet.set_attn_processor(lora_attn_procs)
+        sd_pipe = sd_pipe.to(torch_device)
+
+        # forward 2
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0})
+        image = output.images
+        image_slice_1 = image[0, -3:, -3:, -1]
+
+        # forward 3
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5})
+        image = output.images
+        image_slice_2 = image[0, -3:, -3:, -1]
+
+        assert np.abs(image_slice - image_slice_1).max() < 1e-2
+        assert np.abs(image_slice - image_slice_2).max() > 1e-2
+
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=3e-3)
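The test's two assertions encode the LoRA invariant: with `scale=0.0` the LoRA path contributes nothing and the output must match the base pipeline, while `scale=0.5` must change it. A toy illustration of that invariant on a single LoRA-augmented linear layer (not diffusers code):

```python
import torch

def lora_linear(x, base_w, lora_down, lora_up, scale):
    # Base projection plus a scaled low-rank update: x @ W.T + scale * x @ (up @ down).T
    return x @ base_w.T + scale * (x @ lora_down.T @ lora_up.T)

x = torch.randn(2, 8)
base_w = torch.randn(8, 8)
down, up = torch.randn(4, 8), torch.randn(8, 4)

assert torch.allclose(lora_linear(x, base_w, down, up, 0.0), x @ base_w.T)      # scale 0 == base
assert not torch.allclose(lora_linear(x, base_w, down, up, 0.5), x @ base_w.T)  # scale 0.5 differs
```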
