
Commit 7f6a36c

Merge pull request #2 from huggingface/add-glide
+ cosine schedule and unet config
2 parents 2db090d + 747f42d commit 7f6a36c

File tree: 3 files changed, +78 −13 lines changed

models/vision/glide/run_glide.py (+16)

@@ -0,0 +1,16 @@
+import torch
+from .modeling_glide import GLIDE
+from diffusers import UNetGLIDEModel, GaussianDDPMScheduler
+
+generator = torch.Generator()
+generator = generator.manual_seed(0)
+
+# 1. Load models
+scheduler = GaussianDDPMScheduler.from_config("fusing/glide-base")
+model = UNetGLIDEModel.from_pretrained("fusing/glide-base")
+
+pipeline = GLIDE(model, scheduler)
+
+img = pipeline(generator)
+
+print(img)
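Note: modeling_glide.py (the GLIDE pipeline class imported above) is not part of this diff, so the sampling internals are not shown here. As a rough, hypothetical sketch only — not the actual pipeline code — the call img = pipeline(generator) amounts to a standard DDPM ancestral-sampling loop over the scheduler's betas, simplified below to ignore the text-conditioning transformer_out argument of UNetGLIDEModel.forward:

import torch

def ddpm_sample_sketch(model, betas, shape=(1, 3, 64, 64), generator=None):
    # Illustrative reverse-diffusion loop; the function name, shapes, and the
    # unconditional model call are assumptions, not the GLIDE implementation.
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape, generator=generator)
    for t in reversed(range(len(betas))):
        eps = model(x, torch.tensor([t]))  # predicted noise residual
        # DDPM posterior mean (Ho et al. 2020), with sigma_t^2 = beta_t noise.
        x = (x - betas[t] / torch.sqrt(1.0 - alphas_cumprod[t]) * eps) / torch.sqrt(alphas[t])
        if t > 0:
            x = x + torch.sqrt(betas[t]) * torch.randn(shape, generator=generator)
    return x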

src/diffusers/models/unet_glide.py (+35 −13)

@@ -1,10 +1,13 @@
 import math
 from abc import abstractmethod
 
-import torch as th
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ..configuration_utils import Config
+from ..modeling_utils import PreTrainedModel
+
 
 def convert_module_to_f16(l):
     """
@@ -94,13 +97,13 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     :return: an [N x dim] Tensor of positional embeddings.
     """
     half = dim // 2
-    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
         device=timesteps.device
     )
     args = timesteps[:, None].float() * freqs[None]
-    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
     if dim % 2:
-        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     return embedding
 
 
@@ -298,7 +301,7 @@ def forward(self, x, emb):
             emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
             h = out_norm(h) * (1 + scale) + shift
             h = out_rest(h)
         else:
@@ -376,16 +379,16 @@ def forward(self, qkv, encoder_kv=None):
         if encoder_kv is not None:
             assert encoder_kv.shape[1] == self.n_heads * ch * 2
             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = th.cat([ek, k], dim=-1)
-            v = th.cat([ev, v], dim=-1)
+            k = torch.cat([ek, k], dim=-1)
+            v = torch.cat([ev, v], dim=-1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = torch.einsum("bts,bcs->bct", weight, v)
         return a.reshape(bs, -1, length)
 
 
-class UNetGLIDEModel(nn.Module):
+class UNetGLIDEModel(PreTrainedModel, Config):
     """
     The full UNet model with attention and timestep embedding.
 
@@ -435,6 +438,25 @@ def __init__(
         encoder_channels=None,
     ):
         super().__init__()
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            encoder_channels=encoder_channels,
+        )
 
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
@@ -448,7 +470,7 @@ def __init__(
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
@@ -637,7 +659,7 @@ def forward(self, x, timesteps, transformer_out):
             hs.append(h)
         h = self.middle_block(h, emb)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
         h = h.type(x.dtype)
         return self.out(h)
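The new self.register(...) call above hands every __init__ argument to the Config mixin (defined in src/diffusers/configuration_utils.py, which is not part of this diff), so the model can later be re-created from a saved config, e.g. via UNetGLIDEModel.from_pretrained("fusing/glide-base") in run_glide.py. A minimal, hypothetical sketch of that register-the-kwargs pattern — not the actual diffusers Config implementation — could look like:

import json

class ConfigSketch:
    # Hypothetical stand-in for the Config mixin; all names are illustrative.
    def register(self, **kwargs):
        # Remember every constructor argument so it can be serialized later.
        self._config = dict(kwargs)

    def save_config(self, path):
        with open(path, "w") as f:
            json.dump(self._config, f)

    @classmethod
    def from_config(cls, path):
        with open(path) as f:
            config = json.load(f)
        # Rebuild the object with exactly the registered arguments.
        return cls(**config)

class TinyModel(ConfigSketch):
    def __init__(self, channels=3, depth=2):
        self.register(channels=channels, depth=depth)
        self.channels, self.depth = channels, depth

model = TinyModel(channels=8, depth=4)
model.save_config("tiny_model_config.json")
restored = TinyModel.from_config("tiny_model_config.json")  # channels=8, depth=4 again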

src/diffusers/schedulers/gaussian_ddpm.py (+27)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import math
 from torch import nn
 
 from ..configuration_utils import ConfigMixin
@@ -24,6 +25,26 @@ def linear_beta_schedule(timesteps, beta_start, beta_end):
     return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
 
 
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    """
+    Create a beta schedule that discretizes the given alpha_t_bar function,
+    which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+    :param num_diffusion_timesteps: the number of betas to produce.
+    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+                      produces the cumulative product of (1-beta) up to that
+                      part of the diffusion process.
+    :param max_beta: the maximum beta to use; use values lower than 1 to
+                     prevent singularities.
+    """
+    betas = []
+    for i in range(num_diffusion_timesteps):
+        t1 = i / num_diffusion_timesteps
+        t2 = (i + 1) / num_diffusion_timesteps
+        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    return torch.tensor(betas, dtype=torch.float64)
+
+
 class GaussianDDPMScheduler(nn.Module, ConfigMixin):
 
     config_name = SAMPLING_CONFIG_NAME
@@ -48,6 +69,12 @@ def __init__(
 
         if beta_schedule == "linear":
             betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
+        elif beta_schedule == "squaredcos_cap_v2":
+            # GLIDE cosine schedule
+            betas = betas_for_alpha_bar(
+                timesteps,
+                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+            )
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
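For a concrete feel of the new "squaredcos_cap_v2" option: betas_for_alpha_bar chooses each beta_i so that the cumulative product of (1 - beta) follows the target alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2, clipped at max_beta. The standalone snippet below (reusing the function added above; the 1000-step count is just an example) checks that numerically:

import math
import torch

def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    # Same discretization as in the scheduler above.
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float64)

def alpha_bar(t):
    # GLIDE cosine-schedule target for the cumulative product of (1 - beta).
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

betas = betas_for_alpha_bar(1000, alpha_bar)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Up to the max_beta clipping near t = 1 (and the alpha_bar(0) != 1 offset of
# roughly 0.015%), alphas_cumprod[i] tracks alpha_bar((i + 1) / 1000).
print(alphas_cumprod[499].item(), alpha_bar(0.5))
print(betas[0].item(), betas[-1].item())  # tiny at t = 0, clipped to 0.999 at t = 1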
