Skip to content

Commit 27ee939

Browse files
committed
with diffusers cac, always run the original prompt on the first step
1 parent 5e7ed96 commit 27ee939

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

ldm/models/diffusion/cross_attention_control.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
108108
return self.tokens_cross_attention_action == Context.Action.APPLY
109109
return False
110110

111-
def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\
111+
def get_active_cross_attention_control_types_for_step(self, percent_through:Optional[float]=None, step_size:Optional[float]=None)\
112112
-> list[CrossAttentionType]:
113113
"""
114114
Should cross-attention control be applied on the given step?
@@ -117,6 +117,11 @@ def get_active_cross_attention_control_types_for_step(self, percent_through:floa
117117
"""
118118
if percent_through is None:
119119
return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]
120+
if step_size is not None:
121+
# adjust percent_through to ignore the first step
122+
percent_through = (percent_through - step_size) / (1.0 - step_size)
123+
if percent_through < 0:
124+
return []
120125

121126
opts = self.arguments.edit_options
122127
to_control = []

ldm/models/diffusion/shared_invokeai_diffusion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,13 +141,16 @@ def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
141141
if step_index is not None and total_step_count is not None:
142142
# 🧨diffusers codepath
143143
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
144+
step_size_percent = 1 / total_step_count
144145
else:
145146
# legacy compvis codepath
146147
# TODO remove when compvis codepath support is dropped
147148
if step_index is None and sigma is None:
148149
raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
149150
percent_through = self.estimate_percent_through(step_index, sigma)
150-
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
151+
# legacy code path supports s_* so we don't need step_size_percent
152+
step_size_percent = None
153+
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through, step_size=step_size_percent)
151154

152155
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
153156
wants_hybrid_conditioning = isinstance(conditioning, dict)

0 commit comments

Comments
 (0)