def load_into(f, model, prefix, device, dtype=None):
    """Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
    for key in f.keys():
-        if key.startswith(prefix):
+        if key.startswith(prefix) and not key.startswith("loss."):
            path = key[len(prefix):].split(".")
            obj = model
            for p in path:
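Note: the hunk above is truncated, so for reference here is a minimal, self-contained sketch (my own helper name, `apply_prefixed_weights`, not the repo's function) of what this prefix-filtered loading amounts to: walk the dotted key path with `getattr` and copy each safetensors tensor into the matching parameter, skipping keys outside the prefix and the `loss.` keys that the added condition now filters out.

```python
# Sketch under assumptions, not the repo's exact implementation.
import torch
from safetensors import safe_open


def apply_prefixed_weights(filename: str, model: torch.nn.Module, prefix: str, device: str = "cpu"):
    with safe_open(filename, framework="pt", device=device) as f:
        for key in f.keys():
            # Same filter as the change above: keep prefixed keys, skip "loss." tensors.
            if not key.startswith(prefix) or key.startswith("loss."):
                continue
            obj = model
            *parents, leaf = key[len(prefix):].split(".")
            for p in parents:
                obj = getattr(obj, p)  # descend submodules/attributes by name
            getattr(obj, leaf).data.copy_(f.get_tensor(key))  # overwrite the tensor in place
```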
@@ -133,6 +133,10 @@ def __init__(self, model):
MODEL = "models/sd3_beta.safetensors"
# VAE model file path, or set "None" to use the same model file
VAEFile = "models/sd3_vae.safetensors"
+# Optional init image file path
+INIT_IMAGE = None
+# If init_image is given, this is the percentage of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
+DENOISE = 0.6
# Output file path
OUTPUT = "output.png"
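Note: DENOISE is only consumed when an init image is set (see `gen_image` below), and it is the fraction of the noise schedule that actually runs, much like an img2img "strength" setting. A rough sketch of how it maps to executed sampler steps (my own helper, mirroring the slice added in `do_sampling`):

```python
# Rough sketch: how DENOISE maps to executed sampler steps, assuming the
# `sigmas[int(steps * (1 - denoise)):]` slice added in do_sampling below.
def steps_executed(steps: int, denoise: float) -> int:
    skipped = int(steps * (1 - denoise))  # schedule entries dropped from the front
    return steps - skipped


print(steps_executed(50, 0.6))  # 30 of 50 steps run; the 20 noisiest levels are skipped
print(steps_executed(50, 1.0))  # 50: full denoise from pure noise (plain txt2img behaviour)
```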
@@ -194,12 +198,13 @@ def fix_cond(self, cond):
        cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda())
        return {"c_crossattn": cond, "y": pooled}

-    def do_sampling(self, latent, seed, conditioning, neg_cond, steps, cfg_scale) -> torch.Tensor:
+    def do_sampling(self, latent, seed, conditioning, neg_cond, steps, cfg_scale, denoise=1.0) -> torch.Tensor:
        print("Sampling...")
        latent = latent.half().cuda()
        self.sd3.model = self.sd3.model.cuda()
        noise = self.get_noise(seed, latent).cuda()
        sigmas = self.get_sigmas(self.sd3.model.model_sampling, steps).cuda()
+        sigmas = sigmas[int(steps * (1 - denoise)):]
        conditioning = self.fix_cond(conditioning)
        neg_cond = self.fix_cond(neg_cond)
        extra_args = {"cond": conditioning, "uncond": neg_cond, "cond_scale": cfg_scale}
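Note: the single added `sigmas[...]` line above is what makes img2img work: dropping the front of the noise schedule means sampling starts from a partially noised version of the encoded init latent instead of from pure noise. A toy illustration with made-up, evenly spaced sigmas (the real schedule comes from `self.get_sigmas(...)` and is not linear):

```python
import torch

# Toy schedule purely for illustration; not the model's actual sigmas.
steps = 10
sigmas = torch.linspace(1.0, 0.0, steps + 1)

denoise = 0.6
sigmas = sigmas[int(steps * (1 - denoise)):]  # same slice as above
print(sigmas)  # tensor([0.6000, 0.5000, ..., 0.0000]): the sampler starts mid-schedule
```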
@@ -210,6 +215,21 @@ def do_sampling(self, latent, seed, conditioning, neg_cond, steps, cfg_scale) ->
        print("Sampling done")
        return latent

+    def vae_encode(self, image) -> torch.Tensor:
+        print("Encoding image to latent...")
+        image = image.convert("RGB")
+        image_np = np.array(image).astype(np.float32) / 255.0
+        image_np = np.moveaxis(image_np, 2, 0)
+        batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
+        image_torch = torch.from_numpy(batch_images)
+        image_torch = 2.0 * image_torch - 1.0
+        image_torch = image_torch.cuda()
+        self.vae.model = self.vae.model.cuda()
+        latent = self.vae.model.encode(image_torch).cpu()
+        self.vae.model = self.vae.model.cpu()
+        print("Encoded")
+        return latent
+
    def vae_decode(self, latent) -> Image.Image:
        print("Decoding latent to image...")
        latent = latent.cuda()
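Note: the new `vae_encode` is the mirror image of `vae_decode`: PIL image → float CHW batch in [-1, 1] → VAE latent. In case the axis shuffling looks opaque, here is just the preprocessing half as a standalone check (no VAE involved; the gray test image is a stand-in):

```python
import numpy as np
import torch
from PIL import Image

# Preprocessing only, mirroring vae_encode above; the encode() call itself is omitted.
image = Image.new("RGB", (1024, 1024), "gray")           # stand-in for the init image
image_np = np.array(image).astype(np.float32) / 255.0    # HWC in [0, 1]
image_np = np.moveaxis(image_np, 2, 0)                    # CHW
batch = np.expand_dims(image_np, axis=0)                  # 1 x C x H x W
image_torch = 2.0 * torch.from_numpy(batch) - 1.0         # rescale to [-1, 1] as above
print(image_torch.shape)  # torch.Size([1, 3, 1024, 1024])
print(image_torch.min().item(), image_torch.max().item())  # values inside [-1, 1]
```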
@@ -224,20 +244,25 @@ def vae_decode(self, latent) -> Image.Image:
        print("Decoded")
        return out_image

-    def gen_image(self, prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, seed=SEED, output=OUTPUT):
+    def gen_image(self, prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, seed=SEED, output=OUTPUT, init_image=INIT_IMAGE, denoise=DENOISE):
        latent = self.get_empty_latent(width, height)
+        if init_image:
+            image_data = Image.open(init_image)
+            image_data = image_data.resize((width, height), Image.LANCZOS)
+            latent = self.vae_encode(image_data)
+            latent = SD3LatentFormat().process_in(latent)
        conditioning = self.get_cond(prompt)
        neg_cond = self.get_cond("")
-        sampled_latent = self.do_sampling(latent, seed, conditioning, neg_cond, steps, cfg_scale)
+        sampled_latent = self.do_sampling(latent, seed, conditioning, neg_cond, steps, cfg_scale, denoise if init_image else 1.0)
        image = self.vae_decode(sampled_latent)
        print(f"Will save to {output}")
        image.save(output)
        print("Done")

@torch.no_grad()
-def main(prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, shift=SHIFT, model=MODEL, vae=VAEFile, seed=SEED, output=OUTPUT):
+def main(prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, shift=SHIFT, model=MODEL, vae=VAEFile, seed=SEED, output=OUTPUT, init_image=INIT_IMAGE, denoise=DENOISE):
    inferencer = SD3Inferencer()
    inferencer.load(model, vae, shift)
-    inferencer.gen_image(prompt, width, height, steps, cfg_scale, seed, output)
+    inferencer.gen_image(prompt, width, height, steps, cfg_scale, seed, output, init_image, denoise)

fire.Fire(main)
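Note: the new arguments can also be exercised without the fire CLI; below is a hypothetical programmatic run using the class directly (file paths are placeholders). On the command line the equivalent is the `--init_image` and `--denoise` flags, since `fire.Fire(main)` maps `main()`'s keyword arguments onto CLI options.

```python
# Hypothetical usage sketch; paths are placeholders, constants come from this script.
import torch

with torch.no_grad():
    inferencer = SD3Inferencer()
    inferencer.load(MODEL, VAEFile, SHIFT)
    inferencer.gen_image(
        prompt="a photo of a cat",
        init_image="inputs/cat.png",   # placeholder init image; None keeps plain txt2img
        denoise=0.6,                   # run only the final 60% of the noise schedule
        output="output_img2img.png",
    )
```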