Skip to content

Commit b47b8a0

Browse files
author
Dazhi Zhong
committed
Merge branch 'master' of https://github.com/crowsonkb/v-diffusion-pytorch into addcolabs
2 parents 95adf47 + bbc5275 commit b47b8a0

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

cfg_sample.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def run(x, steps):
258258
return sampling.sample(cfg_model_fn, x, steps, args.eta, {}, callback=callback_fn)
259259

260260
def run_all(n, batch_size):
261-
x = torch.randn([args.n, 3, side_y, side_x], device=device)
261+
x = torch.randn([n, 3, side_y, side_x], device=device)
262262
t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
263263
steps = utils.get_spliced_ddpm_cosine_schedule(t)
264264
if args.init:

clip_sample.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def run(x, steps, clip_embed):
324324
return sampling.cond_sample(model, x, steps, args.eta, extra_args, cond_fn_, callback=callback_fn)
325325

326326
def run_all(n, batch_size):
327-
x = torch.randn([args.n, 3, side_y, side_x], device=device)
327+
x = torch.randn([n, 3, side_y, side_x], device=device)
328328
t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
329329
steps = utils.get_spliced_ddpm_cosine_schedule(t)
330330
if args.init:

diffusion/sampling.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ def sample(model, x, steps, eta, extra_args, callback=None):
2323
pred = x * alphas[i] - v * sigmas[i]
2424
eps = x * sigmas[i] + v * alphas[i]
2525

26+
# Call the callback
27+
if callback is not None:
28+
callback({'x': x, 'i': i, 't': steps[i], 'v': v, 'pred': pred})
29+
2630
# If we are not on the last timestep, compute the noisy image for the
2731
# next timestep.
2832
if i < len(steps) - 1:
@@ -64,8 +68,13 @@ def cond_sample(model, x, steps, eta, extra_args, cond_fn, callback=None):
6468
with torch.cuda.amp.autocast():
6569
v = model(x, ts * steps[i], **extra_args)
6670

71+
pred = x * alphas[i] - v * sigmas[i]
72+
73+
# Call the callback
74+
if callback is not None:
75+
callback({'x': x, 'i': i, 't': steps[i], 'v': v.detach(), 'pred': pred.detach()})
76+
6777
if steps[i] < 1:
68-
pred = x * alphas[i] - v * sigmas[i]
6978
cond_grad = cond_fn(x, ts * steps[i], pred, **extra_args).detach()
7079
v = v.detach() - cond_grad * (sigmas[i] / alphas[i])
7180
else:

0 commit comments

Comments
 (0)