Commit 01a865a: Initial commit (0 parents)

13 files changed: +627, -0 lines

.gitignore

+7

venv*
__pycache__
.ipynb_checkpoints
*.pth
out*
*.egg-info
*.ini
.gitmodules

+3

[submodule "CLIP"]
	path = CLIP
	url = https://github.com/openai/CLIP

CLIP

Submodule CLIP added at 573315e
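CLIP is vendored here as a git submodule pinned at 573315e rather than declared as a PyPI dependency, which is why clip_sample.py below imports it as a top-level package. A minimal sketch of that import path, assuming the submodule has been checked out next to the scripts (the available model names depend on the pinned revision):

from CLIP import clip                      # the vendored submodule, not a pip-installed package

print(clip.available_models())             # e.g. ['RN50', ..., 'ViT-B/16'], per the pinned CLIP revision
tokens = clip.tokenize('a text prompt')    # the tokenizer clip_sample.py uses to embed the prompt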

LICENSE

+19

Copyright (c) 2021 Katherine Crowson and John David Pressman

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

README.md

+1

# v-diffusion-pytorch

clip_sample.py

+130

#!/usr/bin/env python3

"""CLIP guided sampling from a diffusion model."""

import argparse
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from tqdm import trange

from CLIP import clip
from diffusion import get_model, get_models, sampling, utils

MODULE_DIR = Path(__file__).resolve().parent


class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutout = F.adaptive_avg_pool2d(cutout, self.cut_size)
            cutouts.append(cutout)
        return torch.cat(cutouts)


def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def main():
    p = argparse.ArgumentParser(description=__doc__,
                                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('prompt', type=str,
                   help='the text prompt')
    p.add_argument('--batch-size', '-bs', type=int, default=1,
                   help='the number of images per batch')
    p.add_argument('--checkpoint', type=str,
                   help='the checkpoint to use')
    p.add_argument('--clip-guidance-scale', '-cs', type=float, default=500.,
                   help='the CLIP guidance scale')
    p.add_argument('--device', type=str,
                   help='the device to use')
    p.add_argument('--eta', type=float, default=1.,
                   help='the amount of noise to add during sampling (0-1)')
    p.add_argument('--model', type=str, default='cc12m_1', choices=get_models(),
                   help='the model to use')
    p.add_argument('-n', type=int, default=1,
                   help='the number of images to sample')
    p.add_argument('--seed', type=int, default=0,
                   help='the random seed')
    p.add_argument('--steps', type=int, default=1000,
                   help='the number of timesteps')
    args = p.parse_args()

    if args.device:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    model = get_model(args.model)()
    checkpoint = args.checkpoint
    if not checkpoint:
        checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pth'
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
    if device.type == 'cuda':
        model = model.half()
    model = model.to(device).eval().requires_grad_(False)
    clip_model = clip.load(model.clip_model, jit=False, device=device)[0]
    clip_model.eval().requires_grad_(False)
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    cutn = 16
    make_cutouts = MakeCutouts(clip_model.visual.input_resolution, cutn=cutn, cut_pow=1)

    clip_embed = clip_model.encode_text(clip.tokenize(args.prompt).to(device))
    clip_embed = clip_embed.repeat([args.n, 1])

    torch.manual_seed(args.seed)

    def cond_fn(x, t, pred, clip_embed):
        clip_in = normalize(make_cutouts((pred + 1) / 2))
        image_embeds = clip_model.encode_image(clip_in).view([cutn, x.shape[0], -1])
        losses = spherical_dist_loss(image_embeds, clip_embed[None])
        loss = losses.mean(0).sum() * args.clip_guidance_scale
        grad = -torch.autograd.grad(loss, x)[0]
        return grad

    def run(x, clip_embed):
        t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
        steps = utils.get_spliced_ddpm_cosine_schedule(t)
        extra_args = {'clip_embed': clip_embed}
        if not args.clip_guidance_scale:
            return sampling.sample(model, x, steps, args.eta, extra_args)
        return sampling.cond_sample(model, x, steps, args.eta, extra_args, cond_fn)

    def run_all(n, batch_size):
        x = torch.randn([args.n, *model.shape], device=device)
        for i in trange(0, n, batch_size):
            cur_batch_size = min(n - i, batch_size)
            outs = run(x[i:i+cur_batch_size], clip_embed[i:i+cur_batch_size])
            for j, out in enumerate(outs):
                utils.to_pil_image(out).save(f'out_{i + j:05}.png')

    try:
        run_all(args.n, args.batch_size)
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()
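Two notes on the helpers above: MakeCutouts stacks cutn random crops of the whole batch (so its output batch dimension is cutn times the input batch size), and spherical_dist_loss is the squared great-circle distance between L2-normalized embeddings, i.e. half the squared angle between them. A minimal sketch, assuming it runs in the same module as the definitions above (the tensor sizes are made up for illustration):

import torch
from torch.nn import functional as F

# MakeCutouts: cutn crops per forward pass, each pooled to cut_size and
# concatenated along the batch dimension.
make_cutouts = MakeCutouts(cut_size=224, cutn=16)
images = torch.zeros(2, 3, 256, 256)       # hypothetical batch of 2 RGB images
print(make_cutouts(images).shape)          # torch.Size([32, 3, 224, 224])

# spherical_dist_loss: for unit vectors, ||x - y|| = 2 sin(theta / 2), so
# 2 * arcsin(||x - y|| / 2) ** 2 equals theta ** 2 / 2, where theta is the
# angle between the embeddings.
x, y = torch.randn(4, 512), torch.randn(4, 512)
theta = torch.arccos(F.cosine_similarity(x, y, dim=-1).clamp(-1, 1))
print(torch.allclose(spherical_dist_loss(x, y), theta ** 2 / 2, atol=1e-4))  # True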

diffusion/__init__.py

+2

from . import sampling, utils
from .models import get_model, get_models

diffusion/models/__init__.py

+1

from .models import get_model, get_models
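These two __init__ files re-export the model registry and the sampling/utility modules, which is what lets clip_sample.py write "from diffusion import get_model, get_models, sampling, utils". A minimal sketch of that import surface; the registry itself lives in diffusion/models/models.py, which is among the files not shown in this excerpt, so the registered names are an assumption based on the CLI default above:

from diffusion import get_model, get_models, sampling, utils

print(get_models())                 # registered model names; 'cc12m_1' is the CLI default above
model = get_model('cc12m_1')()      # look up a model class by name and instantiate it, as main() does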
