
Commit 23433bf

attention_kwargs -> cross_attention_kwargs.
1 parent f219198 commit 23433bf

2 files changed (+15 −13 lines)


src/diffusers/models/transformers/sana_transformer.py (+6 −6)
@@ -364,22 +364,22 @@ def forward(
         timestep: torch.LongTensor,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
+        if cross_attention_kwargs is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0

         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+            if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None:
                 logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                    "Passing `scale` via `cross_attention_kwargs` when not using the PEFT backend is ineffective."
                 )

         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
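
The hunk above only threads the rename through the existing LoRA-scale handling: the kwargs dict is copied before `scale` is popped, so the caller's dict is never mutated, and the popped value is what drives `scale_lora_layers` when the PEFT backend is active. A minimal standalone sketch of that pattern (the `extract_lora_scale` helper is illustrative, not part of this commit):

from typing import Any, Dict, Optional, Tuple


def extract_lora_scale(
    cross_attention_kwargs: Optional[Dict[str, Any]],
) -> Tuple[float, Optional[Dict[str, Any]]]:
    # Copy before popping so the caller's dict is left untouched.
    if cross_attention_kwargs is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    return lora_scale, cross_attention_kwargs


# The popped value is what `scale_lora_layers(self, lora_scale)` receives when
# USE_PEFT_BACKEND is True; without the PEFT backend the `scale` key is ignored
# (hence the warning added above).
scale, remaining = extract_lora_scale({"scale": 0.5})
assert scale == 0.5 and remaining == {}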

src/diffusers/pipelines/sana/pipeline_sana.py (+9 −7)
@@ -574,8 +574,8 @@ def guidance_scale(self):
         return self._guidance_scale

     @property
-    def attention_kwargs(self):
-        return self._attention_kwargs
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs

     @property
     def do_classifier_free_guidance(self):
@@ -613,7 +613,7 @@ def __call__(
         return_dict: bool = True,
         clean_caption: bool = True,
         use_resolution_binning: bool = True,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 300,
@@ -686,7 +686,7 @@ def __call__(
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
-            attention_kwargs: TODO
+            cross_attention_kwargs: TODO
             clean_caption (`bool`, *optional*, defaults to `True`):
                 Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                 be installed. If the dependencies are not installed, the embeddings will be created from the raw
@@ -747,7 +747,7 @@ def __call__(
             )

         self._guidance_scale = guidance_scale
-        self._attention_kwargs = attention_kwargs
+        self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False

         # 2. Default height and width to transformer
@@ -759,7 +759,9 @@ def __call__(
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+        lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )

         # 3. Encode input prompt
         (
@@ -829,7 +831,7 @@ def __call__(
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
                     return_dict=False,
-                    attention_kwargs=self.attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                 )[0]
                 noise_pred = noise_pred.float()
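
For callers, the practical effect of this commit is that the argument now has to be passed to the pipeline under its new name. A minimal usage sketch at this commit (the checkpoint ID and prompt are placeholders, not taken from the diff; the `scale` entry only affects LoRA layers when the PEFT backend is in use):

import torch
from diffusers import SanaPipeline

# Placeholder checkpoint ID: substitute any SANA checkpoint converted for diffusers.
pipe = SanaPipeline.from_pretrained("<your-sana-checkpoint>", torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    cross_attention_kwargs={"scale": 0.5},  # popped as `lora_scale` inside the transformer
).images[0]
image.save("sana.png")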

0 commit comments