@@ -70,13 +70,8 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):
70
70
71
71
# Have to unwrap the inpainting conditioning here to perform pre-processing
72
72
image_conditioning = None
73
- uc_image_conditioning = None
74
73
if isinstance (cond , dict ):
75
- if self .conditioning_key == "crossattn-adm" :
76
- image_conditioning = cond ["c_adm" ]
77
- uc_image_conditioning = unconditional_conditioning ["c_adm" ]
78
- else :
79
- image_conditioning = cond ["c_concat" ][0 ]
74
+ image_conditioning = cond ["c_concat" ][0 ]
80
75
cond = cond ["c_crossattn" ][0 ]
81
76
unconditional_conditioning = unconditional_conditioning ["c_crossattn" ][0 ]
82
77
@@ -103,12 +98,8 @@ def before_sample(self, x, ts, cond, unconditional_conditioning):
103
98
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
104
99
# Note that they need to be lists because it just concatenates them later.
105
100
if image_conditioning is not None :
106
- if self .conditioning_key == "crossattn-adm" :
107
- cond = {"c_adm" : image_conditioning , "c_crossattn" : [cond ]}
108
- unconditional_conditioning = {"c_adm" : uc_image_conditioning , "c_crossattn" : [unconditional_conditioning ]}
109
- else :
110
- cond = {"c_concat" : [image_conditioning ], "c_crossattn" : [cond ]}
111
- unconditional_conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [unconditional_conditioning ]}
101
+ cond = {"c_concat" : [image_conditioning ], "c_crossattn" : [cond ]}
102
+ unconditional_conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [unconditional_conditioning ]}
112
103
113
104
return x , ts , cond , unconditional_conditioning
114
105
@@ -185,12 +176,8 @@ def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning,
185
176
186
177
# Wrap the conditioning models with additional image conditioning for inpainting model
187
178
if image_conditioning is not None :
188
- if self .conditioning_key == "crossattn-adm" :
189
- conditioning = {"c_adm" : image_conditioning , "c_crossattn" : [conditioning ]}
190
- unconditional_conditioning = {"c_adm" : torch .zeros_like (image_conditioning ), "c_crossattn" : [unconditional_conditioning ]}
191
- else :
192
- conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [conditioning ]}
193
- unconditional_conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [unconditional_conditioning ]}
179
+ conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [conditioning ]}
180
+ unconditional_conditioning = {"c_concat" : [image_conditioning ], "c_crossattn" : [unconditional_conditioning ]}
194
181
195
182
samples = self .launch_sampling (t_enc + 1 , lambda : self .sampler .decode (x1 , conditioning , t_enc , unconditional_guidance_scale = p .cfg_scale , unconditional_conditioning = unconditional_conditioning ))
196
183
@@ -208,12 +195,8 @@ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, ima
208
195
# Wrap the conditioning models with additional image conditioning for inpainting model
209
196
# dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
210
197
if image_conditioning is not None :
211
- if self .conditioning_key == "crossattn-adm" :
212
- conditioning = {"dummy_for_plms" : np .zeros ((conditioning .shape [0 ],)), "c_crossattn" : [conditioning ], "c_adm" : image_conditioning }
213
- unconditional_conditioning = {"c_crossattn" : [unconditional_conditioning ], "c_adm" : torch .zeros_like (image_conditioning )}
214
- else :
215
- conditioning = {"dummy_for_plms" : np .zeros ((conditioning .shape [0 ],)), "c_crossattn" : [conditioning ], "c_concat" : [image_conditioning ]}
216
- unconditional_conditioning = {"c_crossattn" : [unconditional_conditioning ], "c_concat" : [image_conditioning ]}
198
+ conditioning = {"dummy_for_plms" : np .zeros ((conditioning .shape [0 ],)), "c_crossattn" : [conditioning ], "c_concat" : [image_conditioning ]}
199
+ unconditional_conditioning = {"c_crossattn" : [unconditional_conditioning ], "c_concat" : [image_conditioning ]}
217
200
218
201
samples_ddim = self .launch_sampling (steps , lambda : self .sampler .sample (S = steps , conditioning = conditioning , batch_size = int (x .shape [0 ]), shape = x [0 ].shape , verbose = False , unconditional_guidance_scale = p .cfg_scale , unconditional_conditioning = unconditional_conditioning , x_T = x , eta = self .eta )[0 ])
219
202
0 commit comments