@@ -1,10 +1,13 @@
 import math
 from abc import abstractmethod
 
-import torch as th
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ..configuration_utils import Config
+from ..modeling_utils import PreTrainedModel
+
 
 def convert_module_to_f16(l):
     """
@@ -94,13 +97,13 @@ def timestep_embedding(timesteps, dim, max_period=10000):
     :return: an [N x dim] Tensor of positional embeddings.
     """
     half = dim // 2
-    freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to(
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
         device=timesteps.device
     )
     args = timesteps[:, None].float() * freqs[None]
-    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
     if dim % 2:
-        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
     return embedding
 
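Note: the renamed calls above implement the standard sinusoidal timestep embedding, with frequency exp(-ln(max_period) * i / half) for i in [0, half), cosines in the first half of the output and sines in the second, zero-padded when dim is odd. A minimal usage sketch (tensor values are illustrative):

import torch

timesteps = torch.tensor([0, 10, 500])        # [N] diffusion steps
emb = timestep_embedding(timesteps, dim=128)  # the function defined above
assert emb.shape == (3, 128)                  # [N x dim]: [cos | sin] halves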
@@ -298,7 +301,7 @@ def forward(self, x, emb):
             emb_out = emb_out[..., None]
         if self.use_scale_shift_norm:
             out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
-            scale, shift = th.chunk(emb_out, 2, dim=1)
+            scale, shift = torch.chunk(emb_out, 2, dim=1)
             h = out_norm(h) * (1 + scale) + shift
             h = out_rest(h)
         else:
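Note: this hunk is the scale-shift (FiLM-style) conditioning path: the timestep embedding is projected to 2*C channels, split into a scale and a shift, and applied around the normalization as norm(h) * (1 + scale) + shift. A self-contained sketch with hypothetical tensors, not the model's actual layers:

import torch
import torch.nn as nn

N, C, H, W = 2, 8, 4, 4
h = torch.randn(N, C, H, W)
emb_out = torch.randn(N, 2 * C, 1, 1)              # projected timestep embedding
scale, shift = torch.chunk(emb_out, 2, dim=1)      # [N, C, 1, 1] each
out = nn.GroupNorm(4, C)(h) * (1 + scale) + shift  # GroupNorm stands in for out_norm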
@@ -376,16 +379,16 @@ def forward(self, qkv, encoder_kv=None):
         if encoder_kv is not None:
             assert encoder_kv.shape[1] == self.n_heads * ch * 2
             ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
-            k = th.cat([ek, k], dim=-1)
-            v = th.cat([ev, v], dim=-1)
+            k = torch.cat([ek, k], dim=-1)
+            v = torch.cat([ev, v], dim=-1)
         scale = 1 / math.sqrt(math.sqrt(ch))
-        weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
-        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
-        a = th.einsum("bts,bcs->bct", weight, v)
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = torch.einsum("bts,bcs->bct", weight, v)
         return a.reshape(bs, -1, length)
 
 
-class UNetGLIDEModel(nn.Module):
+class UNetGLIDEModel(PreTrainedModel, Config):
     """
     The full UNet model with attention and timestep embedding.
 
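Note: the scaling in this hunk multiplies q and k each by ch**-0.25, which is algebraically identical to dividing the attention logits by sqrt(ch) afterwards but keeps the intermediates small enough for float16, as the inline comment says. A quick equivalence check in float32:

import math
import torch

bs_heads, ch, length = 4, 64, 16
q = torch.randn(bs_heads, ch, length)
k = torch.randn(bs_heads, ch, length)
scale = 1 / math.sqrt(math.sqrt(ch))
w1 = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # scale q and k first
w2 = torch.einsum("bct,bcs->bts", q, k) / math.sqrt(ch)  # divide logits after
assert torch.allclose(w1, w2, atol=1e-4)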
@@ -435,6 +438,25 @@ def __init__(
         encoder_channels=None,
     ):
         super().__init__()
+        self.register(
+            in_channels=in_channels,
+            model_channels=model_channels,
+            out_channels=out_channels,
+            num_res_blocks=num_res_blocks,
+            attention_resolutions=attention_resolutions,
+            dropout=dropout,
+            channel_mult=channel_mult,
+            conv_resample=conv_resample,
+            dims=dims,
+            use_checkpoint=use_checkpoint,
+            use_fp16=use_fp16,
+            num_heads=num_heads,
+            num_head_channels=num_head_channels,
+            num_heads_upsample=num_heads_upsample,
+            use_scale_shift_norm=use_scale_shift_norm,
+            resblock_updown=resblock_updown,
+            encoder_channels=encoder_channels,
+        )
 
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
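Note: `Config.register` is not part of this diff, so its behavior is an assumption; the call pattern suggests it stores the constructor kwargs so the model can later be serialized and rebuilt from a config. A hypothetical sketch of such a method, not the actual implementation:

class Config:
    def register(self, **kwargs):
        # Hypothetical: stash init kwargs (e.g. {"in_channels": ..., ...})
        # for later serialization / from-config reconstruction.
        if not hasattr(self, "_config"):
            self._config = {}
        self._config.update(kwargs)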
@@ -448,7 +470,7 @@ def __init__(
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
-        self.dtype = th.float16 if use_fp16 else th.float32
+        self.dtype = torch.float16 if use_fp16 else torch.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
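Note: `self.dtype` only records the target precision; in this family of UNets it is typically paired with a converter like the `convert_module_to_f16` helper at the top of this file, whose body the diff does not show. A hedged sketch of what such a helper usually does (an assumption, not the file's actual code):

import torch.nn as nn

def convert_module_to_f16_sketch(l):
    # Cast conv weights/biases to half precision, leaving norm layers in float32.
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        if l.bias is not None:
            l.bias.data = l.bias.data.half()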
@@ -637,7 +659,7 @@ def forward(self, x, timesteps, transformer_out):
             hs.append(h)
         h = self.middle_block(h, emb)
         for module in self.output_blocks:
-            h = th.cat([h, hs.pop()], dim=1)
+            h = torch.cat([h, hs.pop()], dim=1)
             h = module(h, emb)
         h = h.type(x.dtype)
         return self.out(h)
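Note: the final hunk is the classic U-Net skip pattern: every down-block activation is pushed onto `hs`, and each up-block pops one and concatenates it along the channel dimension before processing. A toy, self-contained sketch with hypothetical conv blocks standing in for the model's ResBlocks:

import torch
import torch.nn as nn

down = nn.ModuleList([nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, 16, 3, padding=1)])
mid = nn.Conv2d(16, 16, 3, padding=1)
up = nn.ModuleList([nn.Conv2d(32, 8, 3, padding=1), nn.Conv2d(16, 3, 3, padding=1)])

h, hs = torch.randn(1, 3, 32, 32), []
for block in down:            # downsampling path: record activations
    h = block(h)
    hs.append(h)
h = mid(h)
for block in up:              # upsampling path: consume skips in reverse order
    h = torch.cat([h, hs.pop()], dim=1)  # skip connection on channel dim
    h = block(h)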