Skip to content

Commit 701484b

Browse files
authored
apply repeat_interleave fix for mps to stable diffusion image2image pipeline (huggingface#1135)
copy from other pipeline
1 parent 9543cda commit 701484b

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,10 @@ def __call__(
337337
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
338338
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
339339

340-
# duplicate text embeddings for each generation per prompt
341-
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
340+
# duplicate text embeddings for each generation per prompt, using mps friendly method
341+
bs_embed, seq_len, _ = text_embeddings.shape
342+
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
343+
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
342344

343345
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
344346
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`

0 commit comments

Comments
 (0)