17
17
from omegaconf import OmegaConf
18
18
from PIL import Image , ImageOps
19
19
from torch import nn
20
- from pytorch_lightning import seed_everything
20
+ from pytorch_lightning import seed_everything , logging
21
21
22
22
from ldm .util import instantiate_from_config
23
23
from ldm .models .diffusion .ddim import DDIMSampler
35
35
from ldm.generate import Generate
36
36
37
37
# Create an object with default values
38
- gr = Generate()
38
+ gr = Generate('stable-diffusion-1.4' )
39
39
40
40
# do the slow model initialization
41
41
gr.load_model()
79
79
80
80
The full list of arguments to Generate() are:
81
81
gr = Generate(
82
+ # these values are set once and shouldn't be changed
83
+ conf = path to configuration file ('configs/models.yaml')
84
+ model = symbolic name of the model in the configuration file
85
+ full_precision = False
86
+
87
+ # this value is sticky and maintained between generation calls
88
+ sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
89
+
90
+ # these are deprecated - use conf and model instead
82
91
weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
83
- config = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
84
- iterations = <integer> // how many times to run the sampling (1)
85
- steps = <integer> // 50
86
- seed = <integer> // current system time
87
- sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
88
- grid = <boolean> // false
89
- width = <integer> // image width, multiple of 64 (512)
90
- height = <integer> // image height, multiple of 64 (512)
91
- cfg_scale = <float> // condition-free guidance scale (7.5)
92
+ config = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
92
93
)
93
94
94
95
"""
@@ -101,66 +102,62 @@ class Generate:
101
102
102
103
def __init__ (
103
104
self ,
104
- iterations = 1 ,
105
- steps = 50 ,
106
- cfg_scale = 7.5 ,
107
- weights = 'models/ldm/stable-diffusion-v1/model.ckpt' ,
108
- config = 'configs/stable-diffusion/v1-inference.yaml' ,
109
- grid = False ,
110
- width = 512 ,
111
- height = 512 ,
105
+ model = 'stable-diffusion-1.4' ,
106
+ conf = 'configs/models.yaml' ,
107
+ embedding_path = None ,
112
108
sampler_name = 'k_lms' ,
113
109
ddim_eta = 0.0 , # deterministic
114
110
full_precision = False ,
115
- strength = 0.75 , # default in scripts/img2img.py
116
- seamless = False ,
117
- embedding_path = None ,
118
- device_type = 'cuda' ,
119
- ignore_ctrl_c = False ,
111
+ # these are deprecated; if present they override values in the conf file
112
+ weights = None ,
113
+ config = None ,
120
114
):
121
- self .iterations = iterations
122
- self .width = width
123
- self .height = height
124
- self .steps = steps
125
- self .cfg_scale = cfg_scale
126
- self .weights = weights
127
- self .config = config
128
- self .sampler_name = sampler_name
129
- self .grid = grid
130
- self .ddim_eta = ddim_eta
131
- self .full_precision = True if choose_torch_device () == 'mps' else full_precision
132
- self .strength = strength
133
- self .seamless = seamless
134
- self .embedding_path = embedding_path
135
- self .device_type = device_type
136
- self .ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here...
137
- self .model = None # empty for now
138
- self .sampler = None
139
- self .device = None
140
- self .generators = {}
141
- self .base_generator = None
142
- self .seed = None
143
-
144
- if device_type == 'cuda' and not torch .cuda .is_available ():
145
- device_type = choose_torch_device ()
146
- print (">> cuda not available, using device" , device_type )
115
+ models = OmegaConf .load (conf )
116
+ mconfig = models [model ]
117
+ self .weights = mconfig .weights if weights is None else weights
118
+ self .config = mconfig .config if config is None else config
119
+ self .height = mconfig .height
120
+ self .width = mconfig .width
121
+ self .iterations = 1
122
+ self .steps = 50
123
+ self .cfg_scale = 7.5
124
+ self .sampler_name = sampler_name
125
+ self .ddim_eta = 0.0 # same seed always produces same image
126
+ self .full_precision = True if choose_torch_device () == 'mps' else full_precision
127
+ self .strength = 0.75
128
+ self .seamless = False
129
+ self .embedding_path = embedding_path
130
+ self .model = None # empty for now
131
+ self .sampler = None
132
+ self .device = None
133
+ self .session_peakmem = None
134
+ self .generators = {}
135
+ self .base_generator = None
136
+ self .seed = None
137
+
138
+ # Note that in previous versions, there was an option to pass the
139
+ # device to Generate(). However the device was then ignored, so
140
+ # it wasn't actually doing anything. This logic could be reinstated.
141
+ device_type = choose_torch_device ()
147
142
self .device = torch .device (device_type )
148
143
149
144
# for VRAM usage statistics
150
- device_type = choose_torch_device ()
151
- self .session_peakmem = torch .cuda .max_memory_allocated () if device_type == 'cuda' else None
145
+ self .session_peakmem = torch .cuda .max_memory_allocated () if self ._has_cuda else None
152
146
transformers .logging .set_verbosity_error ()
153
147
148
+ # gets rid of annoying messages about random seed
149
+ logging .getLogger ('pytorch_lightning' ).setLevel (logging .ERROR )
150
+
154
151
def prompt2png (self , prompt , outdir , ** kwargs ):
155
152
"""
156
153
Takes a prompt and an output directory, writes out the requested number
157
154
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
158
155
Optional named arguments are the same as those passed to Generate and prompt2image()
159
156
"""
160
- results = self .prompt2image (prompt , ** kwargs )
157
+ results = self .prompt2image (prompt , ** kwargs )
161
158
pngwriter = PngWriter (outdir )
162
- prefix = pngwriter .unique_prefix ()
163
- outputs = []
159
+ prefix = pngwriter .unique_prefix ()
160
+ outputs = []
164
161
for image , seed in results :
165
162
name = f'{ prefix } .{ seed } .png'
166
163
path = pngwriter .save_image_and_prompt_to_png (
@@ -183,33 +180,35 @@ def prompt2image(
183
180
self ,
184
181
# these are common
185
182
prompt ,
186
- iterations = None ,
187
- steps = None ,
188
- seed = None ,
189
- cfg_scale = None ,
190
- ddim_eta = None ,
191
- skip_normalize = False ,
192
- image_callback = None ,
193
- step_callback = None ,
194
- width = None ,
195
- height = None ,
196
- sampler_name = None ,
197
- seamless = False ,
198
- log_tokenization = False ,
199
- with_variations = None ,
200
- variation_amount = 0.0 ,
183
+ iterations = None ,
184
+ steps = None ,
185
+ seed = None ,
186
+ cfg_scale = None ,
187
+ ddim_eta = None ,
188
+ skip_normalize = False ,
189
+ image_callback = None ,
190
+ step_callback = None ,
191
+ width = None ,
192
+ height = None ,
193
+ sampler_name = None ,
194
+ seamless = False ,
195
+ log_tokenization = False ,
196
+ with_variations = None ,
197
+ variation_amount = 0.0 ,
201
198
# these are specific to img2img and inpaint
202
- init_img = None ,
203
- init_mask = None ,
204
- fit = False ,
205
- strength = None ,
199
+ init_img = None ,
200
+ init_mask = None ,
201
+ fit = False ,
202
+ strength = None ,
206
203
# these are specific to embiggen (which also relies on img2img args)
207
204
embiggen = None ,
208
205
embiggen_tiles = None ,
209
206
# these are specific to GFPGAN/ESRGAN
210
- gfpgan_strength = 0 ,
211
- save_original = False ,
212
- upscale = None ,
207
+ gfpgan_strength = 0 ,
208
+ save_original = False ,
209
+ upscale = None ,
210
+ # Set this True to handle KeyboardInterrupt internally
211
+ catch_interrupts = False ,
213
212
** args ,
214
213
): # eat up additional cruft
215
214
"""
@@ -262,10 +261,9 @@ def process_image(image,seed):
262
261
self .log_tokenization = log_tokenization
263
262
with_variations = [] if with_variations is None else with_variations
264
263
265
- model = (
266
- self .load_model ()
267
- ) # will instantiate the model or return it from cache
268
-
264
+ # will instantiate the model or return it from cache
265
+ model = self .load_model ()
266
+
269
267
for m in model .modules ():
270
268
if isinstance (m , (nn .Conv2d , nn .ConvTranspose2d )):
271
269
m .padding_mode = 'circular' if seamless else m ._orig_padding_mode
@@ -281,7 +279,6 @@ def process_image(image,seed):
281
279
(embiggen == None and embiggen_tiles == None ) or ((embiggen != None or embiggen_tiles != None ) and init_img != None )
282
280
), 'Embiggen requires an init/input image to be specified'
283
281
284
- # check this logic - doesn't look right
285
282
if len (with_variations ) > 0 or variation_amount > 1.0 :
286
283
assert seed is not None ,\
287
284
'seed must be specified when using with_variations'
@@ -298,7 +295,7 @@ def process_image(image,seed):
298
295
self ._set_sampler ()
299
296
300
297
tic = time .time ()
301
- if torch . cuda . is_available ():
298
+ if self . _has_cuda ():
302
299
torch .cuda .reset_peak_memory_stats ()
303
300
304
301
results = list ()
@@ -307,9 +304,9 @@ def process_image(image,seed):
307
304
308
305
try :
309
306
uc , c = get_uc_and_c (
310
- prompt , model = self .model ,
307
+ prompt , model = self .model ,
311
308
skip_normalize = skip_normalize ,
312
- log_tokens = self .log_tokenization
309
+ log_tokens = self .log_tokenization
313
310
)
314
311
315
312
(init_image ,mask_image ) = self ._make_images (init_img ,init_mask , width , height , fit )
@@ -352,27 +349,25 @@ def process_image(image,seed):
352
349
save_original = save_original ,
353
350
image_callback = image_callback )
354
351
355
- except KeyboardInterrupt :
356
- print ('*interrupted*' )
357
- if not self .ignore_ctrl_c :
358
- raise KeyboardInterrupt
359
- print (
360
- '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
361
- )
362
352
except RuntimeError as e :
363
353
print (traceback .format_exc (), file = sys .stderr )
364
354
print ('>> Could not generate image.' )
355
+ except KeyboardInterrupt :
356
+ if catch_interrupts :
357
+ print ('**Interrupted** Partial results will be returned.' )
358
+ else :
359
+ raise KeyboardInterrupt
365
360
366
361
toc = time .time ()
367
362
print ('>> Usage stats:' )
368
363
print (
369
364
f'>> { len (results )} image(s) generated in' , '%4.2fs' % (toc - tic )
370
365
)
371
- if torch . cuda . is_available () and self . device . type == 'cuda' :
366
+ if self . _has_cuda () :
372
367
print (
373
368
f'>> Max VRAM used for this generation:' ,
374
369
'%4.2fG.' % (torch .cuda .max_memory_allocated () / 1e9 ),
375
- 'Current VRAM utilization:'
370
+ 'Current VRAM utilization:' ,
376
371
'%4.2fG' % (torch .cuda .memory_allocated () / 1e9 ),
377
372
)
378
373
@@ -439,8 +434,7 @@ def load_model(self):
439
434
if self .model is None :
440
435
seed_everything (random .randrange (0 , np .iinfo (np .uint32 ).max ))
441
436
try :
442
- config = OmegaConf .load (self .config )
443
- model = self ._load_model_from_config (config , self .weights )
437
+ model = self ._load_model_from_config (self .config , self .weights )
444
438
if self .embedding_path is not None :
445
439
model .embedding_manager .load (
446
440
self .embedding_path , self .full_precision
@@ -541,8 +535,11 @@ def _set_sampler(self):
541
535
542
536
print (msg )
543
537
544
- def _load_model_from_config (self , config , ckpt ):
545
- print (f'>> Loading model from { ckpt } ' )
538
+ # Be warned: config is the path to the model config file, not the dream conf file!
539
+ # Also note that we can get config and weights from self, so why do we need to
540
+ # pass them as args?
541
+ def _load_model_from_config (self , config , weights ):
542
+ print (f'>> Loading model from { weights } ' )
546
543
547
544
# for usage statistics
548
545
device_type = choose_torch_device ()
@@ -551,10 +548,11 @@ def _load_model_from_config(self, config, ckpt):
551
548
tic = time .time ()
552
549
553
550
# this does the work
554
- pl_sd = torch .load (ckpt , map_location = 'cpu' )
555
- sd = pl_sd ['state_dict' ]
556
- model = instantiate_from_config (config .model )
557
- m , u = model .load_state_dict (sd , strict = False )
551
+ c = OmegaConf .load (config )
552
+ pl_sd = torch .load (weights , map_location = 'cpu' )
553
+ sd = pl_sd ['state_dict' ]
554
+ model = instantiate_from_config (c .model )
555
+ m , u = model .load_state_dict (sd , strict = False )
558
556
559
557
if self .full_precision :
560
558
print (
@@ -573,7 +571,7 @@ def _load_model_from_config(self, config, ckpt):
573
571
print (
574
572
f'>> Model loaded in' , '%4.2fs' % (toc - tic )
575
573
)
576
- if device_type == 'cuda' :
574
+ if self . _has_cuda () :
577
575
print (
578
576
'>> Max VRAM used to load the model:' ,
579
577
'%4.2fG' % (torch .cuda .max_memory_allocated () / 1e9 ),
@@ -710,3 +708,5 @@ def _resolution_check(self, width, height, log=False):
710
708
return width , height , resize_needed
711
709
712
710
711
+ def _has_cuda (self ):
712
+ return self .device .type == 'cuda'
0 commit comments