from omegaconf import OmegaConf
from PIL import Image, ImageOps
from torch import nn
- from pytorch_lightning import seed_everything
+ from pytorch_lightning import seed_everything, logging

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.generate import Generate

# Create an object with default values
- gr = Generate()
+ gr = Generate('stable-diffusion-1.4')

# do the slow model initialization
gr.load_model()
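
# generate and save an image (a minimal sketch; prompt2image() accepts many
# optional arguments, listed below, and returns a list of (image, seed) pairs)
results = gr.prompt2image(prompt='an astronaut riding a horse')
for image, seed in results:
    image.save(f'astronaut.{seed}.png')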

The full list of arguments to Generate() is:
gr = Generate(
+          # these values are set once and shouldn't be changed
+          conf             = path to configuration file ('configs/models.yaml')
+          model            = symbolic name of the model in the configuration file
+          full_precision   = False
+
+          # this value is sticky and maintained between generation calls
+          sampler_name     = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
+
+          # these are deprecated - use conf and model instead
           weights          = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
-          config           = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
-          iterations       = <integer>    // how many times to run the sampling (1)
-          steps            = <integer>    // 50
-          seed             = <integer>    // current system time
-          sampler_name     = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
-          grid             = <boolean>    // false
-          width            = <integer>    // image width, multiple of 64 (512)
-          height           = <integer>    // image height, multiple of 64 (512)
-          cfg_scale        = <float>      // condition-free guidance scale (7.5)
+          config           = path to model configuration ('configs/stable-diffusion/v1-inference.yaml')
)
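
For example, one might select a model from the configuration file and
override the sticky sampler like this (a hypothetical call; the values
shown are illustrative, not defaults):

    gr = Generate(
        model        = 'stable-diffusion-1.4',
        conf         = 'configs/models.yaml',
        sampler_name = 'k_euler_a',
    )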

"""
@@ -101,66 +102,62 @@ class Generate:

    def __init__(
            self,
-             iterations     = 1,
-             steps          = 50,
-             cfg_scale      = 7.5,
-             weights        = 'models/ldm/stable-diffusion-v1/model.ckpt',
-             config         = 'configs/stable-diffusion/v1-inference.yaml',
-             grid           = False,
-             width          = 512,
-             height         = 512,
+             model          = 'stable-diffusion-1.4',
+             conf           = 'configs/models.yaml',
+             embedding_path = None,
              sampler_name   = 'k_lms',
              ddim_eta       = 0.0,  # deterministic
              full_precision = False,
-             strength       = 0.75,  # default in scripts/img2img.py
-             seamless       = False,
-             embedding_path = None,
-             device_type    = 'cuda',
-             ignore_ctrl_c  = False,
+             # these are deprecated; if present they override values in the conf file
+             weights        = None,
+             config         = None,
    ):
-         self.iterations     = iterations
-         self.width          = width
-         self.height         = height
-         self.steps          = steps
-         self.cfg_scale      = cfg_scale
-         self.weights        = weights
-         self.config         = config
-         self.sampler_name   = sampler_name
-         self.grid           = grid
-         self.ddim_eta       = ddim_eta
-         self.full_precision = True if choose_torch_device() == 'mps' else full_precision
-         self.strength       = strength
-         self.seamless       = seamless
-         self.embedding_path = embedding_path
-         self.device_type    = device_type
-         self.ignore_ctrl_c  = ignore_ctrl_c  # note, this logic probably doesn't belong here...
-         self.model          = None  # empty for now
-         self.sampler        = None
-         self.device         = None
-         self.generators     = {}
-         self.base_generator = None
-         self.seed           = None
-
-         if device_type == 'cuda' and not torch.cuda.is_available():
-             device_type = choose_torch_device()
-             print(">> cuda not available, using device", device_type)
+         models  = OmegaConf.load(conf)
+         mconfig = models[model]
+         self.weights         = mconfig.weights if weights is None else weights
+         self.config          = mconfig.config if config is None else config
+         self.height          = mconfig.height
+         self.width           = mconfig.width
+         self.iterations      = 1
+         self.steps           = 50
+         self.cfg_scale       = 7.5
+         self.sampler_name    = sampler_name
+         self.ddim_eta        = 0.0  # same seed always produces same image
+         self.full_precision  = True if choose_torch_device() == 'mps' else full_precision
+         self.strength        = 0.75
+         self.seamless        = False
+         self.embedding_path  = embedding_path
+         self.model           = None  # empty for now
+         self.sampler         = None
+         self.device          = None
+         self.session_peakmem = None
+         self.generators      = {}
+         self.base_generator  = None
+         self.seed            = None
+
+         # Note that in previous versions, there was an option to pass the
+         # device to Generate(). However the device was then ignored, so
+         # it wasn't actually doing anything. This logic could be reinstated.
+         device_type = choose_torch_device()
        self.device = torch.device(device_type)

        # for VRAM usage statistics
-         device_type = choose_torch_device()
-         self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
+         self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda() else None
        transformers.logging.set_verbosity_error()

+         # gets rid of annoying messages about random seed
+         logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
+
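    # For reference, a minimal stanza in configs/models.yaml might look like the
    # sketch below (an assumption; only the keys read in __init__ above --
    # weights, config, width and height -- come from this code):
    #
    #   stable-diffusion-1.4:
    #       config: configs/stable-diffusion/v1-inference.yaml
    #       weights: models/ldm/stable-diffusion-v1/model.ckpt
    #       width: 512
    #       height: 512
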
    def prompt2png(self, prompt, outdir, **kwargs):
        """
        Takes a prompt and an output directory, writes out the requested number
        of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
        Optional named arguments are the same as those passed to Generate and prompt2image()
        """
        results   = self.prompt2image(prompt, **kwargs)
        pngwriter = PngWriter(outdir)
        prefix    = pngwriter.unique_prefix()
        outputs   = []
        for image, seed in results:
            name = f'{prefix}.{seed}.png'
            path = pngwriter.save_image_and_prompt_to_png(
@@ -183,21 +180,21 @@ def prompt2image(
            self,
            # these are common
            prompt,
            iterations       = None,
            steps            = None,
            seed             = None,
            cfg_scale        = None,
            ddim_eta         = None,
            skip_normalize   = False,
            image_callback   = None,
            step_callback    = None,
            width            = None,
            height           = None,
            sampler_name     = None,
            seamless         = False,
            log_tokenization = False,
            with_variations  = None,
            variation_amount = 0.0,
            # these are specific to img2img and inpaint
            threshold        = 0.0,
            perlin           = 0.0,
@@ -209,9 +206,11 @@ def prompt2image(
            embiggen         = None,
            embiggen_tiles   = None,
            # these are specific to GFPGAN/ESRGAN
            gfpgan_strength  = 0,
            save_original    = False,
            upscale          = None,
+             # Set this True to handle KeyboardInterrupt internally
+             catch_interrupts = False,
            **args,
    ):  # eat up additional cruft
        """
@@ -266,10 +265,9 @@ def process_image(image,seed):
        self.log_tokenization = log_tokenization
        with_variations = [] if with_variations is None else with_variations

-         model = (
-             self.load_model()
-         )  # will instantiate the model or return it from cache
-
+         # will instantiate the model or return it from cache
+         model = self.load_model()
+
        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                m.padding_mode = 'circular' if seamless else m._orig_padding_mode
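        # ('circular' padding makes each convolution wrap around the image
        # edges, which is what makes seamless output tileable)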
@@ -289,7 +287,6 @@ def process_image(image,seed):
            (embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
        ), 'Embiggen requires an init/input image to be specified'

-         # check this logic - doesn't look right
        if len(with_variations) > 0 or variation_amount > 1.0:
            assert seed is not None,\
                'seed must be specified when using with_variations'
@@ -306,7 +303,7 @@ def process_image(image,seed):
        self._set_sampler()

        tic = time.time()
-         if torch.cuda.is_available():
+         if self._has_cuda():
            torch.cuda.reset_peak_memory_stats()

        results = list()
@@ -315,9 +312,9 @@ def process_image(image,seed):
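        # get_uc_and_c() returns the unconditioned and prompt-conditioned text
        # embeddings; cfg_scale blends the corresponding predictions during sampling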
        try:
            uc, c = get_uc_and_c(
                prompt, model=self.model,
                skip_normalize=skip_normalize,
                log_tokens=self.log_tokenization
            )

            (init_image, mask_image) = self._make_images(init_img, init_mask, width, height, fit)
@@ -362,27 +359,25 @@ def process_image(image,seed):
                save_original=save_original,
                image_callback=image_callback)

-         except KeyboardInterrupt:
-             print('*interrupted*')
-             if not self.ignore_ctrl_c:
-                 raise KeyboardInterrupt
-             print(
-                 '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
-             )
        except RuntimeError as e:
            print(traceback.format_exc(), file=sys.stderr)
            print('>> Could not generate image.')
+         except KeyboardInterrupt:
+             if catch_interrupts:
+                 print('**Interrupted** Partial results will be returned.')
+             else:
+                 raise KeyboardInterrupt

        toc = time.time()
        print('>> Usage stats:')
        print(
            f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
        )
-         if torch.cuda.is_available() and self.device.type == 'cuda':
+         if self._has_cuda():
            print(
                f'>> Max VRAM used for this generation:',
                '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
-                 'Current VRAM utilization:'
+                 'Current VRAM utilization:',
                '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
            )
@@ -449,8 +444,7 @@ def load_model(self):
        if self.model is None:
            seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
            try:
-                 config = OmegaConf.load(self.config)
-                 model = self._load_model_from_config(config, self.weights)
+                 model = self._load_model_from_config(self.config, self.weights)
                if self.embedding_path is not None:
                    model.embedding_manager.load(
                        self.embedding_path, self.full_precision
@@ -551,8 +545,11 @@ def _set_sampler(self):
        print(msg)

-     def _load_model_from_config(self, config, ckpt):
-         print(f'>> Loading model from {ckpt}')
+     # Be warned: config is the path to the model config file, not the dream conf file!
+     # Also note that we can get config and weights from self, so why do we need to
+     # pass them as args?
+     def _load_model_from_config(self, config, weights):
+         print(f'>> Loading model from {weights}')

        # for usage statistics
        device_type = choose_torch_device()
@@ -561,10 +558,11 @@ def _load_model_from_config(self, config, ckpt):
        tic = time.time()

        # this does the work
-         pl_sd = torch.load(ckpt, map_location='cpu')
-         sd = pl_sd['state_dict']
-         model = instantiate_from_config(config.model)
-         m, u = model.load_state_dict(sd, strict=False)
+         c     = OmegaConf.load(config)
+         pl_sd = torch.load(weights, map_location='cpu')
+         sd    = pl_sd['state_dict']
+         model = instantiate_from_config(c.model)
+         m, u  = model.load_state_dict(sd, strict=False)

        if self.full_precision:
            print(
@@ -583,7 +581,7 @@ def _load_model_from_config(self, config, ckpt):
        print(
            f'>> Model loaded in', '%4.2fs' % (toc - tic)
        )
-         if device_type == 'cuda':
+         if self._has_cuda():
            print(
                '>> Max VRAM used to load the model:',
                '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
@@ -720,3 +718,5 @@ def _resolution_check(self, width, height, log=False):
        return width, height, resize_needed


+     def _has_cuda(self):
+         return self.device.type == 'cuda'