@@ -97,10 +97,26 @@ def get_schedule(prompt):
 
 
 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
-ScheduledPromptBatch = namedtuple("ScheduledPromptBatch", ["shape", "schedules"])
 
 
 def get_learned_conditioning(model, prompts, steps):
+    """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
+    and the sampling step at which this condition is to be replaced by the next one.
+
+    Input:
+    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)
+
+    Output:
+    [
+        [
+            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0523, ..., -0.4901, -0.3066, 0.0674], ..., [ 0.3317, -0.5102, -0.4066, ..., 0.4119, -0.7647, -1.0160]], device='cuda:0'))
+        ],
+        [
+            ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.0192, 0.3867, -0.4644, ..., 0.1135, -0.3696, -0.4625]], device='cuda:0')),
+            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886, 0.0229, -0.0522, ..., -0.4901, -0.3067, 0.0673], ..., [-0.7352, -0.4356, -0.7888, ..., 0.6994, -0.4312, -1.2593]], device='cuda:0'))
+        ]
+    ]
+    """
     res = []
 
     prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps)
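The end_at_step field is what drives prompt editing: entries in a schedule are consumed in order, and one stays active until the sampler passes its end_at_step. A minimal self-contained sketch of that selection rule, using placeholder strings instead of real conditioning tensors (the helper name cond_at_step is illustrative only, not part of this commit):

from collections import namedtuple

ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])

# schedule that 'a [blue:green:5] jeweled crown' at 20 steps would produce,
# with placeholder strings standing in for the conditioning tensors
schedule = [
    ScheduledPromptConditioning(end_at_step=5, cond="cond('a blue jeweled crown')"),
    ScheduledPromptConditioning(end_at_step=20, cond="cond('a green jeweled crown')"),
]

def cond_at_step(schedule, current_step):
    # same selection rule as the loop in reconstruct_cond_batch below:
    # take the first entry whose end_at_step has not been passed yet
    # (falling back to index 0 if none matches)
    target_index = 0
    for current, entry in enumerate(schedule):
        if current_step <= entry.end_at_step:
            target_index = current
            break
    return schedule[target_index].cond

assert cond_at_step(schedule, 3) == "cond('a blue jeweled crown')"
assert cond_at_step(schedule, 12) == "cond('a green jeweled crown')"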
@@ -123,13 +139,75 @@ def get_learned_conditioning(model, prompts, steps):
         cache[prompt] = cond_schedule
         res.append(cond_schedule)
 
-    return ScheduledPromptBatch((len(prompts),) + res[0][0].cond.shape, res)
+    return res
+
+
+re_AND = re.compile(r"\bAND\b")
+re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$")
+
+
+def get_multicond_prompt_list(prompts):
+    res_indexes = []
+
+    prompt_flat_list = []
+    prompt_indexes = {}
+
+    for prompt in prompts:
+        subprompts = re_AND.split(prompt)
+
+        indexes = []
+        for subprompt in subprompts:
+            text, weight = re_weight.search(subprompt).groups()
+
+            weight = float(weight) if weight is not None else 1.0
+
+            index = prompt_indexes.get(text, None)
+            if index is None:
+                index = len(prompt_flat_list)
+                prompt_flat_list.append(text)
+                prompt_indexes[text] = index
+
+            indexes.append((index, weight))
+
+        res_indexes.append(indexes)
+
+    return res_indexes, prompt_flat_list, prompt_indexes
+
+
+class ComposableScheduledPromptConditioning:
+    def __init__(self, schedules, weight=1.0):
+        self.schedules: list[ScheduledPromptConditioning] = schedules
+        self.weight: float = weight
+
+
+class MulticondLearnedConditioning:
+    def __init__(self, shape, batch):
+        self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
+        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch
 
 
-def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
-    param = c.schedules[0][0].cond
-    res = torch.zeros(c.shape, device=param.device, dtype=param.dtype)
-    for i, cond_schedule in enumerate(c.schedules):
+def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
+    """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
+    For each prompt, the list is obtained by splitting the prompt using the AND separator.
+
+    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/
+    """
+
+    res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
+
+    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps)
+
+    res = []
+    for indexes in res_indexes:
+        res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])
+
+    return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
+
+
+def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
+    param = c[0][0].cond
+    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+    for i, cond_schedule in enumerate(c):
         target_index = 0
         for current, (end_at, cond) in enumerate(cond_schedule):
             if current_step <= end_at:
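The new AND/weight syntax can be traced with just the two regexes above (copied verbatim from this hunk); the printed result is what the parsing logic implies, not output from a live run. get_multicond_prompt_list then deduplicates the extracted texts into prompt_flat_list, so each unique subprompt is encoded only once, and keeps per-image (index, weight) pairs in res_indexes:

import re

# copied from the hunk above
re_AND = re.compile(r"\bAND\b")
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?\s*(?:\d+|\d*\.\d+)?))?\s*$")

prompt = "a cat playing piano AND oil painting :0.7"

parsed = []
for subprompt in re_AND.split(prompt):
    text, weight = re_weight.search(subprompt).groups()
    parsed.append((text, float(weight) if weight is not None else 1.0))

print(parsed)
# [('a cat playing piano', 1.0), (' oil painting', 0.7)]
# the regex trims whitespace only on the right of each text,
# so the second entry keeps its leading space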
@@ -140,6 +218,30 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
     return res
 
 
+def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
+    param = c.batch[0][0].schedules[0].cond
+
+    tensors = []
+    conds_list = []
+
+    for batch_no, composable_prompts in enumerate(c.batch):
+        conds_for_batch = []
+
+        for cond_index, composable_prompt in enumerate(composable_prompts):
+            target_index = 0
+            for current, (end_at, cond) in enumerate(composable_prompt.schedules):
+                if current_step <= end_at:
+                    target_index = current
+                    break
+
+            conds_for_batch.append((len(tensors), composable_prompt.weight))
+            tensors.append(composable_prompt.schedules[target_index].cond)
+
+        conds_list.append(conds_for_batch)
+
+    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+
+
 re_attention = re.compile(r"""
 \\\(|
 \\\)|
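reconstruct_multicond_batch deliberately returns flat data: one stacked tensor holding every selected conditioning, plus, per image, a list of (index, weight) pairs pointing into that stack. This lets a sampler run the model once over the whole stack and recombine the outputs per image afterwards. The sampler-side code is not part of this diff; the sketch below only illustrates, with toy tensors and no CFG scale, how such (index, weight) pairs could be applied following the composable-diffusion formulation linked in the docstring:

import torch

# pretend the denoiser already produced one output per row of the stacked
# conditioning tensor, plus an unconditional output for the image
x_out = torch.tensor([[1.0, 1.0], [3.0, 3.0]])  # row order matches indices in conds_list
uncond_out = torch.zeros(2)

# what reconstruct_multicond_batch returns for one image built
# from two AND-ed subprompts: (index into the stack, AND weight)
conds_list = [[(0, 1.0), (1, 0.7)]]

denoised = []
for conds_for_image in conds_list:
    combined = uncond_out.clone()
    for index, weight in conds_for_image:
        # each subprompt contributes its own guidance direction,
        # scaled by the weight written after ':' in the prompt
        combined += weight * (x_out[index] - uncond_out)
    denoised.append(combined)

print(denoised[0])  # tensor([3.1000, 3.1000]) = 1.0*1.0 + 0.7*3.0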