Commit bd80c2c

lstein authored and afiaka87 committed
Refactor generate.py and dream.py (invoke-ai#534)
* revert inadvertent change of conda env name (invoke-ai#528)
* Refactor generate.py and dream.py
* config file path (models.yaml) is parsed inside Generate() to simplify API
* Better handling of keyboard interrupts in file loading mode vs interactive
* Removed oodles of unused variables.
* move nonfunctional inpainting out of the scripts directory
* fix ugly ddim tqdm formatting
1 parent a45ffbc commit bd80c2c
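
The headline change: Generate() now takes a symbolic model name and resolves the weights and config paths itself from configs/models.yaml. A minimal before/after sketch of the caller-side difference, using the default paths quoted in the diff below:

```python
from ldm.generate import Generate

# Before this commit: weight and config paths were passed explicitly.
# gr = Generate(
#     weights = 'models/ldm/stable-diffusion-v1/model.ckpt',
#     config  = 'configs/stable-diffusion/v1-inference.yaml',
# )

# After: a symbolic name, looked up in configs/models.yaml.
gr = Generate('stable-diffusion-1.4')
gr.load_model()   # the slow initialization still happens here, not in __init__()
```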

File tree

4 files changed: +138 -159 lines changed


ldm/generate.py

Lines changed: 99 additions & 99 deletions
@@ -17,7 +17,7 @@
 from omegaconf import OmegaConf
 from PIL import Image, ImageOps
 from torch import nn
-from pytorch_lightning import seed_everything
+from pytorch_lightning import seed_everything, logging

 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
@@ -35,7 +35,7 @@
 from ldm.generate import Generate

 # Create an object with default values
-gr = Generate()
+gr = Generate('stable-diffusion-1.4')

 # do the slow model initialization
 gr.load_model()
@@ -79,16 +79,17 @@

 The full list of arguments to Generate() are:
 gr = Generate(
+          # these values are set once and shouldn't be changed
+          conf            = path to configuration file ('configs/models.yaml')
+          model           = symbolic name of the model in the configuration file
+          full_precision  = False
+
+          # this value is sticky and maintained between generation calls
+          sampler_name    = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
+
+          # these are deprecated - use conf and model instead
           weights     = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
-          config      = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
-          iterations  = <integer>     // how many times to run the sampling (1)
-          steps       = <integer>     // 50
-          seed        = <integer>     // current system time
-          sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
-          grid        = <boolean>     // false
-          width       = <integer>     // image width, multiple of 64 (512)
-          height      = <integer>     // image height, multiple of 64 (512)
-          cfg_scale   = <float>       // condition-free guidance scale (7.5)
+          config      = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml')
 )

 """
@@ -101,66 +102,62 @@ class Generate:

     def __init__(
             self,
-            iterations            = 1,
-            steps                 = 50,
-            cfg_scale             = 7.5,
-            weights               = 'models/ldm/stable-diffusion-v1/model.ckpt',
-            config                = 'configs/stable-diffusion/v1-inference.yaml',
-            grid                  = False,
-            width                 = 512,
-            height                = 512,
+            model                 = 'stable-diffusion-1.4',
+            conf                  = 'configs/models.yaml',
+            embedding_path        = None,
             sampler_name          = 'k_lms',
             ddim_eta              = 0.0,  # deterministic
             full_precision        = False,
-            strength              = 0.75, # default in scripts/img2img.py
-            seamless              = False,
-            embedding_path        = None,
-            device_type           = 'cuda',
-            ignore_ctrl_c         = False,
+            # these are deprecated; if present they override values in the conf file
+            weights               = None,
+            config                = None,
     ):
-        self.iterations      = iterations
-        self.width           = width
-        self.height          = height
-        self.steps           = steps
-        self.cfg_scale       = cfg_scale
-        self.weights         = weights
-        self.config          = config
-        self.sampler_name    = sampler_name
-        self.grid            = grid
-        self.ddim_eta        = ddim_eta
-        self.full_precision  = True if choose_torch_device() == 'mps' else full_precision
-        self.strength        = strength
-        self.seamless        = seamless
-        self.embedding_path  = embedding_path
-        self.device_type     = device_type
-        self.ignore_ctrl_c   = ignore_ctrl_c   # note, this logic probably doesn't belong here...
-        self.model           = None     # empty for now
-        self.sampler         = None
-        self.device          = None
-        self.generators      = {}
-        self.base_generator  = None
-        self.seed            = None
-
-        if device_type == 'cuda' and not torch.cuda.is_available():
-            device_type = choose_torch_device()
-            print(">> cuda not available, using device", device_type)
+        models               = OmegaConf.load(conf)
+        mconfig              = models[model]
+        self.weights         = mconfig.weights if weights is None else weights
+        self.config          = mconfig.config if config is None else config
+        self.height          = mconfig.height
+        self.width           = mconfig.width
+        self.iterations      = 1
+        self.steps           = 50
+        self.cfg_scale       = 7.5
+        self.sampler_name    = sampler_name
+        self.ddim_eta        = 0.0   # same seed always produces same image
+        self.full_precision  = True if choose_torch_device() == 'mps' else full_precision
+        self.strength        = 0.75
+        self.seamless        = False
+        self.embedding_path  = embedding_path
+        self.model           = None     # empty for now
+        self.sampler         = None
+        self.device          = None
+        self.session_peakmem = None
+        self.generators      = {}
+        self.base_generator  = None
+        self.seed            = None
+
+        # Note that in previous versions, there was an option to pass the
+        # device to Generate(). However the device was then ignored, so
+        # it wasn't actually doing anything. This logic could be reinstated.
+        device_type = choose_torch_device()
         self.device = torch.device(device_type)

         # for VRAM usage statistics
-        device_type = choose_torch_device()
-        self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
+        self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
         transformers.logging.set_verbosity_error()

+        # gets rid of annoying messages about random seed
+        logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
+
     def prompt2png(self, prompt, outdir, **kwargs):
         """
         Takes a prompt and an output directory, writes out the requested number
         of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
         Optional named arguments are the same as those passed to Generate and prompt2image()
         """
-        results = self.prompt2image(prompt, **kwargs)
+        results   = self.prompt2image(prompt, **kwargs)
         pngwriter = PngWriter(outdir)
-        prefix = pngwriter.unique_prefix()
-        outputs = []
+        prefix    = pngwriter.unique_prefix()
+        outputs   = []
         for image, seed in results:
             name = f'{prefix}.{seed}.png'
             path = pngwriter.save_image_and_prompt_to_png(
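
For reference, the only keys the new __init__() reads from a models.yaml stanza are weights, config, width, and height. The shipped configs/models.yaml is not part of this diff, so the entry below is an assumption that merely satisfies those four lookups. (As an aside, `if self._has_cuda else None` above tests the bound method itself, which is always truthy; the parenthesized self._has_cuda() calls elsewhere in the diff do invoke it.)

```python
from omegaconf import OmegaConf

# Hypothetical models.yaml stanza; only these four keys are read by __init__().
models = OmegaConf.create("""
stable-diffusion-1.4:
    config:  configs/stable-diffusion/v1-inference.yaml
    weights: models/ldm/stable-diffusion-v1/model.ckpt
    width:   512
    height:  512
""")

mconfig = models['stable-diffusion-1.4']
print(mconfig.weights, mconfig.config, mconfig.width, mconfig.height)
```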
@@ -183,21 +180,21 @@ def prompt2image(
         self,
         # these are common
         prompt,
-        iterations     =    None,
-        steps          =    None,
-        seed           =    None,
-        cfg_scale      =    None,
-        ddim_eta       =    None,
-        skip_normalize =    False,
-        image_callback =    None,
-        step_callback  =    None,
-        width          =    None,
-        height         =    None,
-        sampler_name   =    None,
-        seamless       =    False,
-        log_tokenization=  False,
-        with_variations =   None,
-        variation_amount =  0.0,
+        iterations       = None,
+        steps            = None,
+        seed             = None,
+        cfg_scale        = None,
+        ddim_eta         = None,
+        skip_normalize   = False,
+        image_callback   = None,
+        step_callback    = None,
+        width            = None,
+        height           = None,
+        sampler_name     = None,
+        seamless         = False,
+        log_tokenization = False,
+        with_variations  = None,
+        variation_amount = 0.0,
         # these are specific to img2img and inpaint
         threshold        = 0.0,
         perlin           = 0.0,
@@ -209,9 +206,11 @@ def prompt2image(
         embiggen         = None,
         embiggen_tiles   = None,
         # these are specific to GFPGAN/ESRGAN
-        gfpgan_strength= 0,
-        save_original  = False,
-        upscale        = None,
+        gfpgan_strength  = 0,
+        save_original    = False,
+        upscale          = None,
+        # Set this True to handle KeyboardInterrupt internally
+        catch_interrupts = False,
         **args,
     ):   # eat up additional cruft
         """
@@ -266,10 +265,9 @@ def process_image(image,seed):
         self.log_tokenization = log_tokenization
         with_variations = [] if with_variations is None else with_variations

-        model = (
-            self.load_model()
-        )  # will instantiate the model or return it from cache
-
+        # will instantiate the model or return it from cache
+        model = self.load_model()
+
         for m in model.modules():
             if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                 m.padding_mode = 'circular' if seamless else m._orig_padding_mode
@@ -289,7 +287,6 @@ def process_image(image,seed):
             (embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
         ), 'Embiggen requires an init/input image to be specified'

-        # check this logic - doesn't look right
         if len(with_variations) > 0 or variation_amount > 1.0:
             assert seed is not None,\
                 'seed must be specified when using with_variations'
@@ -306,7 +303,7 @@ def process_image(image,seed):
             self._set_sampler()

         tic = time.time()
-        if torch.cuda.is_available():
+        if self._has_cuda():
             torch.cuda.reset_peak_memory_stats()

         results = list()
@@ -315,9 +312,9 @@ def process_image(image,seed):

         try:
             uc, c = get_uc_and_c(
-                prompt, model=self.model,
+                prompt, model =self.model,
                 skip_normalize=skip_normalize,
-                log_tokens=self.log_tokenization
+                log_tokens    =self.log_tokenization
             )

             (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
@@ -362,27 +359,25 @@ def process_image(image,seed):
                     save_original  = save_original,
                     image_callback = image_callback)

-        except KeyboardInterrupt:
-            print('*interrupted*')
-            if not self.ignore_ctrl_c:
-                raise KeyboardInterrupt
-            print(
-                '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
-            )
         except RuntimeError as e:
             print(traceback.format_exc(), file=sys.stderr)
             print('>> Could not generate image.')
+        except KeyboardInterrupt:
+            if catch_interrupts:
+                print('**Interrupted** Partial results will be returned.')
+            else:
+                raise KeyboardInterrupt

         toc = time.time()
         print('>> Usage stats:')
         print(
             f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
         )
-        if torch.cuda.is_available() and self.device.type == 'cuda':
+        if self._has_cuda():
             print(
                 f'>> Max VRAM used for this generation:',
                 '%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
-                'Current VRAM utilization:'
+                'Current VRAM utilization:',
                 '%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
             )

@@ -449,8 +444,7 @@ def load_model(self):
         if self.model is None:
             seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
             try:
-                config = OmegaConf.load(self.config)
-                model = self._load_model_from_config(config, self.weights)
+                model = self._load_model_from_config(self.config, self.weights)
                 if self.embedding_path is not None:
                     model.embedding_manager.load(
                         self.embedding_path, self.full_precision
@@ -551,8 +545,11 @@ def _set_sampler(self):

         print(msg)

-    def _load_model_from_config(self, config, ckpt):
-        print(f'>> Loading model from {ckpt}')
+    # Be warned: config is the path to the model config file, not the dream conf file!
+    # Also note that we can get config and weights from self, so why do we need to
+    # pass them as args?
+    def _load_model_from_config(self, config, weights):
+        print(f'>> Loading model from {weights}')

         # for usage statistics
         device_type = choose_torch_device()
@@ -561,10 +558,11 @@ def _load_model_from_config(self, config, ckpt):
         tic = time.time()

         # this does the work
-        pl_sd = torch.load(ckpt, map_location='cpu')
-        sd = pl_sd['state_dict']
-        model = instantiate_from_config(config.model)
-        m, u = model.load_state_dict(sd, strict=False)
+        c     = OmegaConf.load(config)
+        pl_sd = torch.load(weights, map_location='cpu')
+        sd    = pl_sd['state_dict']
+        model = instantiate_from_config(c.model)
+        m, u  = model.load_state_dict(sd, strict=False)

         if self.full_precision:
             print(
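
Checkpoint loading is now self-contained in _load_model_from_config(); load_model() no longer pre-parses the OmegaConf file itself. A standalone sketch of the same sequence, mirroring the five lines added above rather than documenting a public API:

```python
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

def load_checkpoint(config_path, weights_path):
    # Mirrors the new body of _load_model_from_config() in this diff.
    c     = OmegaConf.load(config_path)                    # model architecture config
    pl_sd = torch.load(weights_path, map_location='cpu')   # checkpoint, loaded CPU-side
    sd    = pl_sd['state_dict']
    model = instantiate_from_config(c.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    return model
```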
@@ -583,7 +581,7 @@ def _load_model_from_config(self, config, ckpt):
         print(
             f'>> Model loaded in', '%4.2fs' % (toc - tic)
         )
-        if device_type == 'cuda':
+        if self._has_cuda():
             print(
                 '>> Max VRAM used to load the model:',
                 '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
@@ -720,3 +718,5 @@ def _resolution_check(self, width, height, log=False):
         return width, height, resize_needed


+    def _has_cuda(self):
+        return self.device.type == 'cuda'

ldm/models/diffusion/ddim.py

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ def ddim_sampling(
         total_steps = (
             timesteps if ddim_use_original_steps else timesteps.shape[0]
         )
-        print(f'Running DDIM Sampling with {total_steps} timesteps')
+        print(f'\nRunning DDIM Sampling with {total_steps} timesteps')

         iterator = tqdm(
             time_range,
