Add Colab callback function for display, auto-download model checkpoints, and Colab notebook #7

Open

wants to merge 7 commits into master

Changes from all commits
125 changes: 121 additions & 4 deletions cfg_sample.py
@@ -5,17 +5,30 @@
import argparse
from functools import partial
from pathlib import Path

import os
from types import SimpleNamespace
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from torchvision.utils import save_image
from tqdm.auto import trange

from CLIP import clip
from diffusion import get_model, get_models, sampling, utils
from diffusion import get_model, get_models, sampling, utils, download_model

def isnotebook():
try:
shell = get_ipython().__class__.__name__
return shell=='ZMQInteractiveShell' or shell=='Shell'
except NameError:
return False
IS_NOTEBOOK = isnotebook()
if IS_NOTEBOOK:
from IPython import display


MODULE_DIR = Path(__file__).resolve().parent

@@ -35,6 +48,12 @@ def resize_and_center_crop(image, size):
image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
return TF.center_crop(image, size[::-1])

def callback_fn(info):
if info['i'] % 50==0:
out = info['pred'].add(1).div(2)
save_image(out, f"interm_output_{info['i']:05d}.png")
if IS_NOTEBOOK:
display.display(display.Image(f"interm_output_{info['i']:05d}.png",height=300))

def main():
p = argparse.ArgumentParser(description=__doc__,
@@ -80,6 +99,8 @@ def main():
checkpoint = args.checkpoint
if not checkpoint:
checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pth'
if not os.path.isfile(checkpoint):
download_model(args.model, checkpoint)
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
if device.type == 'cuda':
model = model.half()
@@ -128,7 +149,103 @@ def cfg_model_fn(x, t):
return v

def run(x, steps):
return sampling.sample(cfg_model_fn, x, steps, args.eta, {})
return sampling.sample(cfg_model_fn, x, steps, args.eta, {}, callback=callback_fn)

def run_all(n, batch_size):
x = torch.randn([args.n, 3, side_y, side_x], device=device)
t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
steps = utils.get_spliced_ddpm_cosine_schedule(t)
if args.init:
steps = steps[steps < args.starting_timestep]
alpha, sigma = utils.t_to_alpha_sigma(steps[0])
x = init * alpha + x * sigma
for i in trange(0, n, batch_size):
cur_batch_size = min(n - i, batch_size)
outs = run(x[i:i+cur_batch_size], steps)
for j, out in enumerate(outs):
utils.to_pil_image(out).save(f'out_{i + j:05}.png')

try:
run_all(args.n, args.batch_size)
except KeyboardInterrupt:
pass


def run_diffusion_cfg(prompts,images=None,steps=1000,init=None,model="cc12m_1_cfg",size=[512,512], checkpoint=None, device=None, eta=1.0, n=1, seed=42,starting_timestep=0.9, batch_size=1,display_freq=50):

args = SimpleNamespace(prompts=prompts,images=images,steps=steps,init=init,model=model,size=size, checkpoint=checkpoint, device=device, eta=eta, n=n, seed=seed,starting_timestep=starting_timestep, batch_size=batch_size)
print(args)

if args.device:
device = torch.device(args.device)
else:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

model = get_model(args.model)()
_, side_y, side_x = model.shape
if args.size:
side_x, side_y = args.size
checkpoint = args.checkpoint
if not checkpoint:
checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pth'
if not os.path.isfile(checkpoint):
download_model(args.model, checkpoint)
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
if device.type == 'cuda':
model = model.half()
model = model.to(device).eval().requires_grad_(False)
clip_model_name = model.clip_model if hasattr(model, 'clip_model') else 'ViT-B/16'
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
clip_model.eval().requires_grad_(False)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])

if args.init:
init = Image.open(utils.fetch(args.init)).convert('RGB')
init = resize_and_center_crop(init, (side_x, side_y))
init = utils.from_pil_image(init).cuda()[None].repeat([args.n, 1, 1, 1])

zero_embed = torch.zeros([1, clip_model.visual.output_dim], device=device)
target_embeds, weights = [zero_embed], []

if args.prompts:
if isinstance(args.prompts, str):
args.prompts = [args.prompts,]
for prompt in args.prompts:
txt, weight = parse_prompt(prompt)
target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
weights.append(weight)

if args.images:
if isinstance(args.images, str):
args.images = [args.images,]
for prompt in args.images:
path, weight = parse_prompt(prompt)
img = Image.open(utils.fetch(path)).convert('RGB')
clip_size = clip_model.visual.input_resolution
img = resize_and_center_crop(img, (clip_size, clip_size))
batch = TF.to_tensor(img)[None].to(device)
embed = F.normalize(clip_model.encode_image(normalize(batch)).float(), dim=-1)
target_embeds.append(embed)
weights.append(weight)

weights = torch.tensor([1 - sum(weights), *weights], device=device)

torch.manual_seed(args.seed)

def cfg_model_fn(x, t):
n = x.shape[0]
n_conds = len(target_embeds)
x_in = x.repeat([n_conds, 1, 1, 1])
t_in = t.repeat([n_conds])
clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
v = vs.mul(weights[:, None, None, None, None]).sum(0)
return v

def run(x, steps):
return sampling.sample(cfg_model_fn, x, steps, args.eta, {}, callback=callback_fn)

def run_all(n, batch_size):
x = torch.randn([n, 3, side_y, side_x], device=device)
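Note: cfg_sample.py now exposes a notebook-facing wrapper, run_diffusion_cfg, around the same flow that main() drives from the CLI. A minimal usage sketch for a Colab cell is shown below; it assumes the repository root is the working directory and that the function keeps the signature shown in the diff (prompt, size, and step values are illustrative).

from cfg_sample import run_diffusion_cfg

# Classifier-free guidance sampling from a text prompt. The cc12m_1_cfg
# checkpoint is downloaded automatically if checkpoints/cc12m_1_cfg.pth is
# missing, and callback_fn writes interm_output_*.png previews every 50 steps.
run_diffusion_cfg(
    prompts='a fantasy landscape, concept art',
    model='cc12m_1_cfg',
    size=[256, 256],
    steps=500,
    n=1,
    batch_size=1,
    seed=0,
)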
146 changes: 140 additions & 6 deletions clip_sample.py
@@ -5,20 +5,34 @@
import argparse
from functools import partial
from pathlib import Path

import os
from types import SimpleNamespace
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm import trange
from torchvision.utils import save_image
from tqdm.auto import trange

from CLIP import clip
from diffusion import get_model, get_models, sampling, utils
from diffusion import get_model, get_models, sampling, utils, download_model


def isnotebook():
try:
shell = get_ipython().__class__.__name__
return shell=='ZMQInteractiveShell' or shell=='Shell'
except NameError:
return False
IS_NOTEBOOK = isnotebook()
if IS_NOTEBOOK:
from IPython import display


MODULE_DIR = Path(__file__).resolve().parent


class MakeCutouts(nn.Module):
def __init__(self, cut_size, cutn, cut_pow=1.):
@@ -63,6 +77,12 @@ def resize_and_center_crop(image, size):
image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
return TF.center_crop(image, size[::-1])

def callback_fn(info):
if info['i'] % 50==0:
out = info['pred'].add(1).div(2)
save_image(out, f"interm_output_{info['i']:05d}.png")
if IS_NOTEBOOK:
display.display(display.Image(f"interm_output_{info['i']:05d}.png",height=300))

def main():
p = argparse.ArgumentParser(description=__doc__,
@@ -114,6 +134,8 @@ def main():
checkpoint = args.checkpoint
if not checkpoint:
checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pth'
if not os.path.isfile(checkpoint):
download_model(args.model, checkpoint)
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
if device.type == 'cuda':
model = model.half()
@@ -176,8 +198,120 @@ def run(x, steps, clip_embed):
extra_args = {}
cond_fn_ = partial(cond_fn, clip_embed=clip_embed)
if not args.clip_guidance_scale:
return sampling.sample(model, x, steps, args.eta, extra_args)
return sampling.cond_sample(model, x, steps, args.eta, extra_args, cond_fn_)
return sampling.sample(model, x, steps, args.eta, extra_args, callback=callback_fn)
return sampling.cond_sample(model, x, steps, args.eta, extra_args, cond_fn_, callback=callback_fn)

def run_all(n, batch_size):
x = torch.randn([args.n, 3, side_y, side_x], device=device)
t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
steps = utils.get_spliced_ddpm_cosine_schedule(t)
if args.init:
steps = steps[steps < args.starting_timestep]
alpha, sigma = utils.t_to_alpha_sigma(steps[0])
x = init * alpha + x * sigma
for i in trange(0, n, batch_size):
cur_batch_size = min(n - i, batch_size)
outs = run(x[i:i+cur_batch_size], steps, clip_embed[i:i+cur_batch_size])
for j, out in enumerate(outs):
utils.to_pil_image(out).save(f'out_{i + j:05}.png')

try:
run_all(args.n, args.batch_size)
except KeyboardInterrupt:
pass



def run_diffusion(prompts,images=None,steps=1000,init=None,model="yfcc_2",size=[512,512], checkpoint=None, clip_guidance_scale=500, cutn=16, cut_pow=1, device=None, eta=1.0, n=1, seed=42,starting_timestep=0.9, batch_size=1,display_freq=50):

args = SimpleNamespace(prompts=prompts,images=images,steps=steps,init=init,model=model,size=size, checkpoint=checkpoint, clip_guidance_scale=clip_guidance_scale, cutn=cutn, cut_pow=cut_pow, device=device, eta=eta, n=n, seed=seed,starting_timestep=starting_timestep, batch_size=batch_size)
print(args)
if args.device:
device = torch.device(args.device)
else:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

model = get_model(args.model)()
_, side_y, side_x = model.shape
if args.size:
side_x, side_y = args.size
checkpoint = args.checkpoint
if not checkpoint:
checkpoint = MODULE_DIR / f'checkpoints/{args.model}.pth'
if not os.path.isfile(checkpoint):
download_model(args.model, checkpoint)
model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
if device.type == 'cuda':
model = model.half()
model = model.to(device).eval().requires_grad_(False)
clip_model_name = model.clip_model if hasattr(model, 'clip_model') else 'ViT-B/16'
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
clip_model.eval().requires_grad_(False)
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
make_cutouts = MakeCutouts(clip_model.visual.input_resolution, args.cutn, args.cut_pow)

if args.init:
print(args.init)
init = Image.open(utils.fetch(args.init)).convert('RGB')
init = resize_and_center_crop(init, (side_x, side_y))
init = utils.from_pil_image(init).cuda()[None].repeat([args.n, 1, 1, 1])

target_embeds, weights = [], []

if args.prompts:
if isinstance(args.prompts, str):
args.prompts = [args.prompts,]
for prompt in args.prompts:
txt, weight = parse_prompt(prompt)
target_embeds.append(clip_model.encode_text(clip.tokenize(txt).to(device)).float())
weights.append(weight)

if args.images:
if isinstance(args.images, str):
args.images = [args.images,]
for prompt in args.images:
path, weight = parse_prompt(prompt)
img = Image.open(utils.fetch(path)).convert('RGB')
img = TF.resize(img, min(side_x, side_y, *img.size),
transforms.InterpolationMode.LANCZOS)
batch = make_cutouts(TF.to_tensor(img)[None].to(device))
embeds = F.normalize(clip_model.encode_image(normalize(batch)).float(), dim=-1)
target_embeds.append(embeds)
weights.extend([weight / args.cutn] * args.cutn)

if not target_embeds:
raise RuntimeError('At least one text or image prompt must be specified.')
target_embeds = torch.cat(target_embeds)
weights = torch.tensor(weights, device=device)
if weights.sum().abs() < 1e-3:
raise RuntimeError('The weights must not sum to 0.')
weights /= weights.sum().abs()

clip_embed = F.normalize(target_embeds.mul(weights[:, None]).sum(0, keepdim=True), dim=-1)
clip_embed = clip_embed.repeat([args.n, 1])

torch.manual_seed(args.seed)

def cond_fn(x, t, pred, clip_embed):
clip_in = normalize(make_cutouts((pred + 1) / 2))
image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
losses = spherical_dist_loss(image_embeds, clip_embed[None])
loss = losses.mean(0).sum() * args.clip_guidance_scale
grad = -torch.autograd.grad(loss, x)[0]
return grad

def run(x, steps, clip_embed):
if hasattr(model, 'clip_model'):
extra_args = {'clip_embed': clip_embed}
cond_fn_ = cond_fn
else:
extra_args = {}
cond_fn_ = partial(cond_fn, clip_embed=clip_embed)
if not args.clip_guidance_scale:
return sampling.sample(model, x, steps, args.eta, extra_args, callback=callback_fn)
return sampling.cond_sample(model, x, steps, args.eta, extra_args, cond_fn_, callback=callback_fn)

def run_all(n, batch_size):
x = torch.randn([n, 3, side_y, side_x], device=device)
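clip_sample.py gains the analogous run_diffusion wrapper for CLIP-guided sampling. A minimal sketch under the same assumptions (signature as in the diff, argument values illustrative):

from clip_sample import run_diffusion

# CLIP-guided sampling with the yfcc_2 model; the checkpoint is fetched on
# first use and callback_fn writes interm_output_*.png previews every 50 steps.
run_diffusion(
    prompts='a watercolor painting of a lighthouse',
    model='yfcc_2',
    size=[256, 256],
    steps=500,
    clip_guidance_scale=500,
    cutn=16,
    seed=0,
)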
2 changes: 1 addition & 1 deletion diffusion/__init__.py
@@ -1,2 +1,2 @@
from . import sampling, utils
from .models import get_model, get_models
from .models import get_model, get_models, download_model
2 changes: 1 addition & 1 deletion diffusion/models/__init__.py
@@ -1 +1 @@
from .models import get_model, get_models
from .models import get_model, get_models, download_model
24 changes: 23 additions & 1 deletion diffusion/models/models.py
@@ -1,5 +1,7 @@
from . import cc12m_1, yfcc_1, yfcc_2

import requests
import shutil
import os

models = {
'cc12m_1': cc12m_1.CC12M1Model,
@@ -8,10 +10,30 @@
'yfcc_2': yfcc_2.YFCC2Model,
}

model_to_url = {
"yfcc_2":"https://v-diffusion.s3.us-west-2.amazonaws.com/yfcc_2.pth",
"yfcc_1":"https://v-diffusion.s3.us-west-2.amazonaws.com/yfcc_1.pth",
"cc12m_1":"https://v-diffusion.s3.us-west-2.amazonaws.com/cc12m_1.pth",
"cc12m_1_cfg":"https://v-diffusion.s3.us-west-2.amazonaws.com/cc12m_1_cfg.pth",
}


def get_model(model):
return models[model]


def get_models():
return list(models.keys())



def download_model(model_name, file_path=None):
model_url = model_to_url[model_name]
if file_path is None:
file_path = f"checkpoints/{model_name}.pth"
if not os.path.exists(file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with requests.get(model_url, stream=True) as r:
with open(file_path, 'wb') as f:
shutil.copyfileobj(r.raw, f)

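The new download_model helper is also usable on its own. A short sketch, assuming the model_to_url mapping and default checkpoint location shown above:

from diffusion import download_model, get_model
import torch

# Fetch the checkpoint by model name; this is a no-op when
# checkpoints/cc12m_1_cfg.pth already exists.
download_model('cc12m_1_cfg')

model = get_model('cc12m_1_cfg')()
model.load_state_dict(torch.load('checkpoints/cc12m_1_cfg.pth', map_location='cpu'))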