Skip to content

Commit e6179af

Browse files
authored
Refactor generate.py and dream.py (#534)
* revert inadvertent change of conda env name (#528) * Refactor generate.py and dream.py * config file path (models.yaml) is parsed inside Generate() to simplify API * Better handling of keyboard interrupts in file loading mode vs interactive * Removed oodles of unused variables. * move nonfunctional inpainting out of the scripts directory * fix ugly ddim tqdm formatting
1 parent d15c75e commit e6179af

File tree

4 files changed

+142
-163
lines changed

4 files changed

+142
-163
lines changed

ldm/generate.py

Lines changed: 103 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from omegaconf import OmegaConf
1818
from PIL import Image, ImageOps
1919
from torch import nn
20-
from pytorch_lightning import seed_everything
20+
from pytorch_lightning import seed_everything, logging
2121

2222
from ldm.util import instantiate_from_config
2323
from ldm.models.diffusion.ddim import DDIMSampler
@@ -35,7 +35,7 @@
3535
from ldm.generate import Generate
3636
3737
# Create an object with default values
38-
gr = Generate()
38+
gr = Generate('stable-diffusion-1.4')
3939
4040
# do the slow model initialization
4141
gr.load_model()
@@ -79,16 +79,17 @@
7979
8080
The full list of arguments to Generate() is:
8181
gr = Generate(
82+
# these values are set once and shouldn't be changed
83+
conf = path to configuration file ('configs/models.yaml')
84+
model = symbolic name of the model in the configuration file
85+
full_precision = False
86+
87+
# this value is sticky and maintained between generation calls
88+
sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
89+
90+
# these are deprecated - use conf and model instead
8291
weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
83-
config = path to model configuration ('configs/stable-diffusion/v1-inference.yaml')
84-
iterations = <integer> // how many times to run the sampling (1)
85-
steps = <integer> // 50
86-
seed = <integer> // current system time
87-
sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
88-
grid = <boolean> // false
89-
width = <integer> // image width, multiple of 64 (512)
90-
height = <integer> // image height, multiple of 64 (512)
91-
cfg_scale = <float> // condition-free guidance scale (7.5)
92+
config = path to model configuration ('configs/stable-diffusion/v1-inference.yaml')
9293
)
9394
9495
"""
@@ -101,66 +102,62 @@ class Generate:
101102

102103
def __init__(
103104
self,
104-
iterations = 1,
105-
steps = 50,
106-
cfg_scale = 7.5,
107-
weights = 'models/ldm/stable-diffusion-v1/model.ckpt',
108-
config = 'configs/stable-diffusion/v1-inference.yaml',
109-
grid = False,
110-
width = 512,
111-
height = 512,
105+
model = 'stable-diffusion-1.4',
106+
conf = 'configs/models.yaml',
107+
embedding_path = None,
112108
sampler_name = 'k_lms',
113109
ddim_eta = 0.0, # deterministic
114110
full_precision = False,
115-
strength = 0.75, # default in scripts/img2img.py
116-
seamless = False,
117-
embedding_path = None,
118-
device_type = 'cuda',
119-
ignore_ctrl_c = False,
111+
# these are deprecated; if present they override values in the conf file
112+
weights = None,
113+
config = None,
120114
):
121-
self.iterations = iterations
122-
self.width = width
123-
self.height = height
124-
self.steps = steps
125-
self.cfg_scale = cfg_scale
126-
self.weights = weights
127-
self.config = config
128-
self.sampler_name = sampler_name
129-
self.grid = grid
130-
self.ddim_eta = ddim_eta
131-
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
132-
self.strength = strength
133-
self.seamless = seamless
134-
self.embedding_path = embedding_path
135-
self.device_type = device_type
136-
self.ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here...
137-
self.model = None # empty for now
138-
self.sampler = None
139-
self.device = None
140-
self.generators = {}
141-
self.base_generator = None
142-
self.seed = None
143-
144-
if device_type == 'cuda' and not torch.cuda.is_available():
145-
device_type = choose_torch_device()
146-
print(">> cuda not available, using device", device_type)
115+
models = OmegaConf.load(conf)
116+
mconfig = models[model]
117+
self.weights = mconfig.weights if weights is None else weights
118+
self.config = mconfig.config if config is None else config
119+
self.height = mconfig.height
120+
self.width = mconfig.width
121+
self.iterations = 1
122+
self.steps = 50
123+
self.cfg_scale = 7.5
124+
self.sampler_name = sampler_name
125+
self.ddim_eta = 0.0 # same seed always produces same image
126+
self.full_precision = True if choose_torch_device() == 'mps' else full_precision
127+
self.strength = 0.75
128+
self.seamless = False
129+
self.embedding_path = embedding_path
130+
self.model = None # empty for now
131+
self.sampler = None
132+
self.device = None
133+
self.session_peakmem = None
134+
self.generators = {}
135+
self.base_generator = None
136+
self.seed = None
137+
138+
# Note that in previous versions, there was an option to pass the
139+
# device to Generate(). However the device was then ignored, so
140+
# it wasn't actually doing anything. This logic could be reinstated.
141+
device_type = choose_torch_device()
147142
self.device = torch.device(device_type)
148143

149144
# for VRAM usage statistics
150-
device_type = choose_torch_device()
151-
self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
145+
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
152146
transformers.logging.set_verbosity_error()
153147

148+
# gets rid of annoying messages about random seed
149+
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
150+
154151
def prompt2png(self, prompt, outdir, **kwargs):
155152
"""
156153
Takes a prompt and an output directory, writes out the requested number
157154
of PNG files, and returns an array of [[filename,seed],[filename,seed]...]
158155
Optional named arguments are the same as those passed to Generate and prompt2image()
159156
"""
160-
results = self.prompt2image(prompt, **kwargs)
157+
results = self.prompt2image(prompt, **kwargs)
161158
pngwriter = PngWriter(outdir)
162-
prefix = pngwriter.unique_prefix()
163-
outputs = []
159+
prefix = pngwriter.unique_prefix()
160+
outputs = []
164161
for image, seed in results:
165162
name = f'{prefix}.{seed}.png'
166163
path = pngwriter.save_image_and_prompt_to_png(
@@ -183,33 +180,35 @@ def prompt2image(
183180
self,
184181
# these are common
185182
prompt,
186-
iterations = None,
187-
steps = None,
188-
seed = None,
189-
cfg_scale = None,
190-
ddim_eta = None,
191-
skip_normalize = False,
192-
image_callback = None,
193-
step_callback = None,
194-
width = None,
195-
height = None,
196-
sampler_name = None,
197-
seamless = False,
198-
log_tokenization= False,
199-
with_variations = None,
200-
variation_amount = 0.0,
183+
iterations = None,
184+
steps = None,
185+
seed = None,
186+
cfg_scale = None,
187+
ddim_eta = None,
188+
skip_normalize = False,
189+
image_callback = None,
190+
step_callback = None,
191+
width = None,
192+
height = None,
193+
sampler_name = None,
194+
seamless = False,
195+
log_tokenization = False,
196+
with_variations = None,
197+
variation_amount = 0.0,
201198
# these are specific to img2img and inpaint
202-
init_img = None,
203-
init_mask = None,
204-
fit = False,
205-
strength = None,
199+
init_img = None,
200+
init_mask = None,
201+
fit = False,
202+
strength = None,
206203
# these are specific to embiggen (which also relies on img2img args)
207204
embiggen = None,
208205
embiggen_tiles = None,
209206
# these are specific to GFPGAN/ESRGAN
210-
gfpgan_strength= 0,
211-
save_original = False,
212-
upscale = None,
207+
gfpgan_strength = 0,
208+
save_original = False,
209+
upscale = None,
210+
# Set this True to handle KeyboardInterrupt internally
211+
catch_interrupts = False,
213212
**args,
214213
): # eat up additional cruft
215214
"""
@@ -262,10 +261,9 @@ def process_image(image,seed):
262261
self.log_tokenization = log_tokenization
263262
with_variations = [] if with_variations is None else with_variations
264263

265-
model = (
266-
self.load_model()
267-
) # will instantiate the model or return it from cache
268-
264+
# will instantiate the model or return it from cache
265+
model = self.load_model()
266+
269267
for m in model.modules():
270268
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
271269
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
@@ -281,7 +279,6 @@ def process_image(image,seed):
281279
(embiggen == None and embiggen_tiles == None) or ((embiggen != None or embiggen_tiles != None) and init_img != None)
282280
), 'Embiggen requires an init/input image to be specified'
283281

284-
# check this logic - doesn't look right
285282
if len(with_variations) > 0 or variation_amount > 1.0:
286283
assert seed is not None,\
287284
'seed must be specified when using with_variations'
@@ -298,7 +295,7 @@ def process_image(image,seed):
298295
self._set_sampler()
299296

300297
tic = time.time()
301-
if torch.cuda.is_available():
298+
if self._has_cuda():
302299
torch.cuda.reset_peak_memory_stats()
303300

304301
results = list()
@@ -307,9 +304,9 @@ def process_image(image,seed):
307304

308305
try:
309306
uc, c = get_uc_and_c(
310-
prompt, model=self.model,
307+
prompt, model =self.model,
311308
skip_normalize=skip_normalize,
312-
log_tokens=self.log_tokenization
309+
log_tokens =self.log_tokenization
313310
)
314311

315312
(init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit)
@@ -352,27 +349,25 @@ def process_image(image,seed):
352349
save_original = save_original,
353350
image_callback = image_callback)
354351

355-
except KeyboardInterrupt:
356-
print('*interrupted*')
357-
if not self.ignore_ctrl_c:
358-
raise KeyboardInterrupt
359-
print(
360-
'>> Partial results will be returned; if --grid was requested, nothing will be returned.'
361-
)
362352
except RuntimeError as e:
363353
print(traceback.format_exc(), file=sys.stderr)
364354
print('>> Could not generate image.')
355+
except KeyboardInterrupt:
356+
if catch_interrupts:
357+
print('**Interrupted** Partial results will be returned.')
358+
else:
359+
raise KeyboardInterrupt
365360

366361
toc = time.time()
367362
print('>> Usage stats:')
368363
print(
369364
f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic)
370365
)
371-
if torch.cuda.is_available() and self.device.type == 'cuda':
366+
if self._has_cuda():
372367
print(
373368
f'>> Max VRAM used for this generation:',
374369
'%4.2fG.' % (torch.cuda.max_memory_allocated() / 1e9),
375-
'Current VRAM utilization:'
370+
'Current VRAM utilization:',
376371
'%4.2fG' % (torch.cuda.memory_allocated() / 1e9),
377372
)
378373

@@ -439,8 +434,7 @@ def load_model(self):
439434
if self.model is None:
440435
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
441436
try:
442-
config = OmegaConf.load(self.config)
443-
model = self._load_model_from_config(config, self.weights)
437+
model = self._load_model_from_config(self.config, self.weights)
444438
if self.embedding_path is not None:
445439
model.embedding_manager.load(
446440
self.embedding_path, self.full_precision
@@ -541,8 +535,11 @@ def _set_sampler(self):
541535

542536
print(msg)
543537

544-
def _load_model_from_config(self, config, ckpt):
545-
print(f'>> Loading model from {ckpt}')
538+
# Be warned: config is the path to the model config file, not the dream conf file!
539+
# Also note that we can get config and weights from self, so why do we need to
540+
# pass them as args?
541+
def _load_model_from_config(self, config, weights):
542+
print(f'>> Loading model from {weights}')
546543

547544
# for usage statistics
548545
device_type = choose_torch_device()
@@ -551,10 +548,11 @@ def _load_model_from_config(self, config, ckpt):
551548
tic = time.time()
552549

553550
# this does the work
554-
pl_sd = torch.load(ckpt, map_location='cpu')
555-
sd = pl_sd['state_dict']
556-
model = instantiate_from_config(config.model)
557-
m, u = model.load_state_dict(sd, strict=False)
551+
c = OmegaConf.load(config)
552+
pl_sd = torch.load(weights, map_location='cpu')
553+
sd = pl_sd['state_dict']
554+
model = instantiate_from_config(c.model)
555+
m, u = model.load_state_dict(sd, strict=False)
558556

559557
if self.full_precision:
560558
print(
@@ -573,7 +571,7 @@ def _load_model_from_config(self, config, ckpt):
573571
print(
574572
f'>> Model loaded in', '%4.2fs' % (toc - tic)
575573
)
576-
if device_type == 'cuda':
574+
if self._has_cuda():
577575
print(
578576
'>> Max VRAM used to load the model:',
579577
'%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9),
@@ -710,3 +708,5 @@ def _resolution_check(self, width, height, log=False):
710708
return width, height, resize_needed
711709

712710

711+
def _has_cuda(self):
712+
return self.device.type == 'cuda'

ldm/models/diffusion/ddim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def ddim_sampling(
225225
total_steps = (
226226
timesteps if ddim_use_original_steps else timesteps.shape[0]
227227
)
228-
print(f'Running DDIM Sampling with {total_steps} timesteps')
228+
print(f'\nRunning DDIM Sampling with {total_steps} timesteps')
229229

230230
iterator = tqdm(
231231
time_range,

0 commit comments

Comments
 (0)