Commit 6f5b822

Factor out encode text with Copied from (huggingface#1224)
* up
* more fixes
* fix
* finalize
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
* upload models
* up
1 parent 701484b commit 6f5b822

8 files changed (+710, -480 lines)
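
Background for the diffs below: the new `# Copied from ...` comments follow diffusers' copy-consistency convention, which lets a shared method body such as `_encode_prompt` be kept mechanically in sync across pipelines (the repo's `check_copies` tooling, e.g. `make fix-copies`, regenerates such blocks from their source). Under classifier-free guidance, the factored-out `_encode_prompt` returns the unconditional and conditional embeddings concatenated along the batch axis so the UNet can score both in a single forward pass. A minimal sketch of that shape contract, with toy dimensions assumed purely for illustration:

import torch

# Toy dimensions, assumed for illustration of the _encode_prompt contract.
batch_size, num_images_per_prompt, seq_len, hidden = 2, 3, 77, 768

cond = torch.randn(batch_size, seq_len, hidden)
uncond = torch.randn(batch_size, seq_len, hidden)

# mps-friendly duplication used in the diff: repeat along dim=1 then reshape,
# instead of repeat_interleave on dim=0
cond = cond.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)
uncond = uncond.repeat(1, num_images_per_prompt, 1).view(batch_size * num_images_per_prompt, seq_len, -1)

# classifier-free guidance: [uncond; cond] stacked on the batch axis,
# so one UNet forward pass scores both halves
text_embeddings = torch.cat([uncond, cond])
assert text_embeddings.shape == (2 * batch_size * num_images_per_prompt, seq_len, hidden)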

pipelines/stable_diffusion/pipeline_cycle_diffusion.py (+109, -77)
@@ -205,6 +205,110 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `set_attention_slice`
         self.enable_attention_slicing(None)
 
+    @property
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+    def _execution_device(self):
+        r"""
+        Returns the device on which the pipeline's models will be executed. After calling
+        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+        hooks.
+        """
+        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+            return self.device
+        for module in self.unet.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+        return self.device
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        bs_embed, seq_len, _ = text_embeddings.shape
+        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            max_length = text_input_ids.shape[-1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
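
Note on the `_execution_device` property copied in above: after `enable_sequential_cpu_offload()`, the pipeline's nominal `device` no longer tells you where Accelerate's hooks will actually run each forward pass, which is what the hook-walking loop recovers. A hedged sketch of the intended behavior, assuming a CUDA machine with accelerate installed and an illustrative checkpoint id:

import torch
from diffusers import StableDiffusionPipeline

# Checkpoint id assumed for illustration; offload support per the docstring above.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.enable_sequential_cpu_offload()  # installs Accelerate hooks; weights leave the GPU

print(pipe.device)             # reports torch.device("meta") once offload hooks are installed
print(pipe._execution_device)  # device recovered from the hooks, e.g. cuda:0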
@@ -309,89 +413,17 @@ def __call__(
         if isinstance(init_image, PIL.Image.Image):
             init_image = preprocess(init_image)
 
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        source_text_inputs = self.tokenizer(
-            source_prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-        source_text_input_ids = source_text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        if source_text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(source_text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            source_text_input_ids = source_text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-        source_text_embeddings = self.text_encoder(source_text_input_ids.to(self.device))[0]
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-        source_text_embeddings = source_text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+        device = self._execution_device
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
 
-        # get unconditional embeddings for classifier free guidance
-        uncond_tokens = [""]
-
-        max_length = text_input_ids.shape[-1]
-        uncond_input = self.tokenizer(
-            uncond_tokens,
-            padding="max_length",
-            max_length=max_length,
-            truncation=True,
-            return_tensors="pt",
+        text_embeddings = self._encode_prompt(prompt, device, num_images_per_prompt, do_classifier_free_guidance, None)
+        source_text_embeddings = self._encode_prompt(
+            source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None
         )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
-
-        # duplicate unconditional embeddings for each generation per prompt
-        uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size * num_images_per_prompt, dim=0)
-
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-        source_uncond_tokens = [""]
-
-        max_length = source_text_input_ids.shape[-1]
-        source_uncond_input = self.tokenizer(
-            source_uncond_tokens,
-            padding="max_length",
-            max_length=max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
-        source_uncond_embeddings = self.text_encoder(source_uncond_input.input_ids.to(self.device))[0]
-
-        # duplicate unconditional embeddings for each generation per prompt
-        source_uncond_embeddings = source_uncond_embeddings.repeat_interleave(
-            batch_size * num_images_per_prompt, dim=0
-        )
-
-        # For classifier free guidance, we need to do two forward passes.
-        # Here we concatenate the unconditional and text embeddings into a single batch
-        # to avoid doing two forward passes
-        source_text_embeddings = torch.cat([source_uncond_embeddings, source_text_embeddings])
 
         # encode the init image into latents and scale the latents
         latents_dtype = text_embeddings.dtype
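
For reference, a hedged end-to-end sketch of driving the refactored CycleDiffusion pipeline; the checkpoint id, input image, and parameter values are assumptions for illustration, not part of the commit. Both `prompt` and `source_prompt` now flow through the single `_encode_prompt` shown above:

import torch
import PIL.Image
from diffusers import CycleDiffusionPipeline, DDIMScheduler

# CycleDiffusion is meant to be used with a DDIM scheduler; checkpoint id assumed.
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", scheduler=scheduler
).to("cuda")

init_image = PIL.Image.open("horse.png").convert("RGB").resize((512, 512))

image = pipe(
    prompt="An astronaut riding an elephant",
    source_prompt="An astronaut riding a horse",
    init_image=init_image,  # argument name as of this commit
    num_inference_steps=100,
    strength=0.8,
    guidance_scale=2.0,
    eta=0.1,
).images[0]
image.save("edited.png")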

pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py (+66, -42)
@@ -92,44 +92,22 @@ def __init__(
             feature_extractor=feature_extractor,
         )
 
-    def __call__(
-        self,
-        prompt: Union[str, List[str]],
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
-        eta: Optional[float] = 0.0,
-        generator: Optional[np.random.RandomState] = None,
-        latents: Optional[np.ndarray] = None,
-        output_type: Optional[str] = "pil",
-        return_dict: bool = True,
-        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
-        callback_steps: Optional[int] = 1,
-        **kwargs,
-    ):
-        if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
-        if generator is None:
-            generator = np.random
+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `list(int)`):
+                prompt to be encoded
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
+        """
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
 
         # get prompt text embeddings
         text_inputs = self.tokenizer(
@@ -150,10 +128,6 @@ def __call__(
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
 
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance:
             uncond_tokens: List[str]
@@ -191,6 +165,56 @@ def __call__(
         # to avoid doing two forward passes
         text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
 
+        return text_embeddings
+
+    def __call__(
+        self,
+        prompt: Union[str, List[str]],
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
+        num_inference_steps: Optional[int] = 50,
+        guidance_scale: Optional[float] = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: Optional[float] = 0.0,
+        generator: Optional[np.random.RandomState] = None,
+        latents: Optional[np.ndarray] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
+        callback_steps: Optional[int] = 1,
+        **kwargs,
+    ):
+        if isinstance(prompt, str):
+            batch_size = 1
+        elif isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if (callback_steps is None) or (
+            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+        ):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if generator is None:
+            generator = np.random
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(
+            prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+        )
+
         # get the initial random noise unless the user supplied it
         latents_dtype = text_embeddings.dtype
         latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
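
Similarly, a hedged sketch of calling the ONNX pipeline whose `__call__` body now sits below `_encode_prompt`; the model id, `revision`, and execution provider are assumptions that depend on what is published and installed locally:

from diffusers import OnnxStableDiffusionPipeline

# An ONNX export of SD v1.4 was published under the "onnx" revision at the time;
# pick the provider matching the local onnxruntime build.
pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="onnx", provider="CPUExecutionProvider"
)

# negative_prompt is forwarded into the refactored _encode_prompt
image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    negative_prompt="blurry, low quality",
    num_inference_steps=25,
    guidance_scale=7.5,
).images[0]
image.save("astronaut.png")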
