@@ -23,15 +23,19 @@

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
+from ...loaders import SanaLoraLoaderMixin
 from ...models import AutoencoderDC, SanaTransformer2DModel
 from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     BACKENDS_MAPPING,
+    USE_PEFT_BACKEND,
     is_bs4_available,
     is_ftfy_available,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -185,6 +189,7 @@ def encode_prompt(
         clean_caption: bool = False,
         max_sequence_length: int = 300,
         complex_human_instruction: Optional[List[str]] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -218,6 +223,15 @@ def encode_prompt(
         if device is None:
             device = self._execution_device

+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
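The hunk above scales the text encoder's LoRA layers before the prompt is encoded, and the hunk below restores the original scale once the embeddings have been computed. A minimal toy sketch of that bracketing pattern (my own illustration with a hypothetical ToyLoraLinear module, not the diffusers/PEFT implementation):

import torch

class ToyLoraLinear(torch.nn.Module):
    # Hypothetical stand-in for a PEFT-wrapped Linear layer.
    def __init__(self, base: torch.nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        self.lora_a = torch.nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = torch.nn.Linear(rank, base.out_features, bias=False)
        self.scaling = 1.0  # the knob the scale/unscale helpers conceptually turn

    def forward(self, x):
        # base output plus the LoRA update, weighted by the current scaling
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))

layer = ToyLoraLinear(torch.nn.Linear(8, 8))
x = torch.randn(2, 8)

lora_scale = 0.8
layer.scaling = lora_scale   # analogous to scale_lora_layers(self.text_encoder, lora_scale)
features = layer(x)          # the prompt is encoded while the scaled LoRA is active
layer.scaling = 1.0          # analogous to unscale_lora_layers(self.text_encoder, lora_scale)

Because the scale is undone right after encoding, a lora_scale passed for one call does not leak into later encode_prompt calls.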
@@ -313,6 +327,11 @@ def encode_prompt(
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None

+        if self.text_encoder is not None:
+            if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
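With these changes, encode_prompt accepts a lora_scale argument and applies it to the text encoder only for the duration of the call; the pipeline's __call__ would typically forward this value from its attention kwargs, but that wiring is not part of this diff. A hedged usage sketch that calls encode_prompt directly (the checkpoint id and LoRA path are placeholders, and it assumes the pipeline exposes the SanaLoraLoaderMixin loading API and a LoRA that includes text encoder weights):

import torch
from diffusers import SanaPAGPipeline

# Placeholder checkpoint id; any Sana checkpoint usable with the PAG pipeline works here.
pipe = SanaPAGPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Placeholder path; assumes the LoRA also targets the text encoder.
pipe.load_lora_weights("path/to/sana_lora")

# lora_scale is the new argument introduced by this diff; the scaling is
# reverted by unscale_lora_layers before encode_prompt returns.
prompt_embeds, prompt_attention_mask, negative_embeds, negative_attention_mask = pipe.encode_prompt(
    "a cyberpunk cat",
    lora_scale=0.8,
)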