
Commit 4839075: change callback_fn

Author: Dazhi Zhong
Parent: 40f4bd2

3 files changed: +12 -32 lines

Diff for: cfg_sample.py (+6 -14)
@@ -48,6 +48,12 @@ def resize_and_center_crop(image, size):
     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
     return TF.center_crop(image, size[::-1])

+def callback_fn(info):
+    if info['i'] % 50==0 or info['i']==args.steps:
+        out = info['pred'].add(1).div(2)
+        save_image(out, f"interm_output_{info['i']:05d}.png")
+        if IS_NOTEBOOK:
+            display.display(display.Image(f"interm_output_{info['i']:05d}.png",height=300))

 def main():
     p = argparse.ArgumentParser(description=__doc__,
@@ -132,13 +138,6 @@ def main():

     torch.manual_seed(args.seed)

-    def callback_fn(pred, i):
-        if i % 50==0 or i==args.steps:
-            out = pred.add(1).div(2)
-            save_image(out, f"interm_output_{i:05d}.png")
-            if IS_NOTEBOOK:
-                display.display(display.Image(f"interm_output_{i:05d}.png",height=300))
-
     def cfg_model_fn(x, t):
         n = x.shape[0]
         n_conds = len(target_embeds)
@@ -235,13 +234,6 @@ def run_diffusion_cfg(prompts,images=None,steps=1000,init=None,model="cc12m_1_cf

     torch.manual_seed(args.seed)

-    def callback_fn(pred, i):
-        if i % display_freq==0 or i==args.steps:
-            out = pred.add(1).div(2)
-            save_image(out, f"interm_output_{i:05d}.png")
-            if IS_NOTEBOOK:
-                display.display(display.Image(f"interm_output_{i:05d}.png",height=300))
-
     def cfg_model_fn(x, t):
         n = x.shape[0]
         n_conds = len(target_embeds)
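In both scripts the callback contract changes from two positional arguments (pred, i) to a single info dict whose 'i' and 'pred' keys are read inside callback_fn. A minimal sketch, assuming only those two keys, of how an old-style callback could be adapted to the new contract; adapt_old_callback and old_style are hypothetical names, not part of this commit:

import torch

def adapt_old_callback(old_cb):
    """Hypothetical helper: wraps an old-style callback(pred, i) so it can be
    used wherever the new single-argument, info-dict callback is expected."""
    def new_cb(info):
        old_cb(info['pred'], info['i'])
    return new_cb

def old_style(pred, i):
    # Old-style callback: just report the step and the prediction's shape.
    print(f"step {i}: pred shape {tuple(pred.shape)}")

cb = adapt_old_callback(old_style)
cb({'i': 0, 'pred': torch.zeros(1, 3, 64, 64)})  # prints: step 0: pred shape (1, 3, 64, 64)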

Diff for: clip_sample.py (+6 -14)
@@ -77,6 +77,12 @@ def resize_and_center_crop(image, size):
     image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
     return TF.center_crop(image, size[::-1])

+def callback_fn(info):
+    if info['i'] % 50==0 or info['i']==args.steps:
+        out = info['pred'].add(1).div(2)
+        save_image(out, f"interm_output_{info['i']:05d}.png")
+        if IS_NOTEBOOK:
+            display.display(display.Image(f"interm_output_{info['i']:05d}.png",height=300))

 def main():
     p = argparse.ArgumentParser(description=__doc__,
@@ -176,13 +182,6 @@ def main():

     torch.manual_seed(args.seed)

-    def callback_fn(pred, i):
-        if i % 50==0 or i==args.steps:
-            out = pred.add(1).div(2)
-            save_image(out, f"interm_output_{i:05d}.png")
-            if IS_NOTEBOOK:
-                display.display(display.Image(f"interm_output_{i:05d}.png",height=300))
-
     def cond_fn(x, t, pred, clip_embed):
         clip_in = normalize(make_cutouts((pred + 1) / 2))
         image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
@@ -295,13 +294,6 @@ def run_diffusion(prompts,images=None,steps=1000,init=None,model="yfcc_2",size=[

     torch.manual_seed(args.seed)

-    def callback_fn(pred, i):
-        if i % display_freq==0 or i==args.steps:
-            out = pred.add(1).div(2)
-            save_image(out, f"interm_output_{i:05d}.png")
-            if IS_NOTEBOOK:
-                display.display(display.Image(f"interm_output_{i:05d}.png",height=300))
-
     def cond_fn(x, t, pred, clip_embed):
         clip_in = normalize(make_cutouts((pred + 1) / 2))
         image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
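The deleted nested callbacks in run_diffusion and run_diffusion_cfg used a configurable display_freq, while the new module-level callback_fn hard-codes a frequency of 50 and reads args.steps and IS_NOTEBOOK as globals. A short sketch, not part of this commit, of a closure factory that keeps the frequency configurable under the new info-dict signature; make_callback and total_steps are hypothetical names:

import torch
from torchvision.utils import save_image

def make_callback(display_freq, total_steps):
    """Hypothetical factory: returns an info-dict callback that saves an
    intermediate image every display_freq steps and on the final step."""
    def callback_fn(info):
        if info['i'] % display_freq == 0 or info['i'] == total_steps:
            out = info['pred'].add(1).div(2)  # map predictions from [-1, 1] to [0, 1]
            save_image(out, f"interm_output_{info['i']:05d}.png")
    return callback_fn

cb = make_callback(display_freq=50, total_steps=1000)
cb({'i': 100, 'pred': torch.zeros(1, 3, 64, 64)})  # writes interm_output_00100.png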

Diff for: diffusion/sampling.py (-4)
@@ -44,8 +44,6 @@ def sample(model, x, steps, eta, extra_args, callback=None):
         if eta:
             x += torch.randn_like(x) * ddim_sigma

-        if callback_fn:
-            callback_fn(pred,i)

     # If we are on the last timestep, output the denoised image
     return pred
@@ -101,8 +99,6 @@ def cond_sample(model, x, steps, eta, extra_args, cond_fn, callback=None):
         if eta:
             x += torch.randn_like(x) * ddim_sigma

-        if callback_fn:
-            callback_fn(pred,i)

     # If we are on the last timestep, output the denoised image
     return pred
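The deleted lines guarded on callback_fn, a name that is not defined inside sample() or cond_sample(), instead of the functions' callback parameter. A toy loop written as an assumption about how that parameter could now be invoked with an info dict carrying the 'i' and 'pred' keys the new callback_fn reads; sample_sketch is a hypothetical stand-in, not the repository's sampler:

import torch

def sample_sketch(model, x, steps, callback=None):
    """Hypothetical loop: guards on the `callback` parameter (not an undefined
    callback_fn) and passes it an info dict at every step."""
    pred = x
    for i in range(steps):
        pred = model(pred, i)  # stand-in for the real denoising update
        if callback is not None:
            callback({'i': i, 'pred': pred})
    return pred

# Usage with trivial stand-ins for the model and the callback:
out = sample_sketch(lambda p, i: p * 0.9, torch.randn(1, 3, 8, 8), steps=3,
                    callback=lambda info: print('step', info['i']))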
