
Commit 71574e9

make fix-copies

1 parent 580a6d5

1 file changed: +19 −0 lines

src/diffusers/pipelines/pag/pipeline_pag_sana.py

@@ -23,15 +23,19 @@
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
+from ...loaders import SanaLoraLoaderMixin
 from ...models import AutoencoderDC, SanaTransformer2DModel
 from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     BACKENDS_MAPPING,
+    USE_PEFT_BACKEND,
     is_bs4_available,
     is_ftfy_available,
     logging,
     replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -185,6 +189,7 @@ def encode_prompt(
         clean_caption: bool = False,
         max_sequence_length: int = 300,
         complex_human_instruction: Optional[List[str]] = None,
+        lora_scale: Optional[float] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -218,6 +223,15 @@ def encode_prompt(
         if device is None:
             device = self._execution_device
 
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
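For context, here is a minimal sketch of the scale/unscale pattern the hunk above relies on. ToyLoRALinear and its method bodies are illustrative stand-ins, not the real peft/diffusers internals: conceptually, each LoRA layer carries a scaling factor on its low-rank update, and scale_lora_layers multiplies that factor by lora_scale for the duration of the text-encoder forward pass.

import torch.nn as nn

class ToyLoRALinear(nn.Module):
    # Illustrative stand-in for a PEFT LoRA layer (not the real API):
    # forward computes base(x) + scaling * B(A(x)).
    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        self.scaling = 1.0

    def scale_layer(self, weight: float):
        # what scale_lora_layers does per layer: boost/attenuate the update
        self.scaling *= weight

    def unscale_layer(self, weight: float):
        # what unscale_lora_layers does per layer: restore the original scale
        self.scaling /= weight

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))

layer = ToyLoRALinear(nn.Linear(16, 16))
layer.scale_layer(0.8)    # before encoding: LoRA update runs at 0.8x strength
layer.unscale_layer(0.8)  # after encoding: original strength restored

Because unscale divides the same factor back out (as the next hunk does), repeated calls to encode_prompt leave the text encoder's effective LoRA strength unchanged.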
@@ -313,6 +327,11 @@ def encode_prompt(
             negative_prompt_embeds = None
             negative_prompt_attention_mask = None
 
+        if self.text_encoder is not None:
+            if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
         return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
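After this change, callers can thread a text-encoder LoRA strength through encode_prompt. A hypothetical usage sketch (not part of this commit), assuming the pipeline class is SanaPAGPipeline, that it inherits SanaLoraLoaderMixin, and that the model id and LoRA path shown exist:

import torch
from diffusers import SanaPAGPipeline  # assumed class name for this pipeline

pipe = SanaPAGPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # assumed model id
    torch_dtype=torch.bfloat16,
).to("cuda")
pipe.load_lora_weights("path/to/sana_lora")  # hypothetical LoRA checkpoint

# The text encoder's LoRA layers are scaled by 0.8 before encoding and
# restored to their original scale afterwards (see the diff above).
prompt_embeds, prompt_attn_mask, neg_embeds, neg_attn_mask = pipe.encode_prompt(
    "a photo of an astronaut riding a horse",
    lora_scale=0.8,
)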

0 commit comments