-
 import tensorflow as tf
+from tensorflow import keras
 import tensorflow_addons as tfa

-from .layers import quick_gelu, get_conv2d, apply_seq, td_dot, gelu, GEGLU
-
+from .layers import quick_gelu, apply_seq, td_dot, gelu, GEGLU, PaddedConv2D

-class AttnBlock(tf.keras.layers.Layer):
-    def __init__(self , in_channels):
+class AttnBlock(keras.layers.Layer):
+    def __init__(self, in_channels):
         super(AttnBlock, self).__init__()
-        self.norm = tfa.layers.GroupNormalization(epsilon=1e-5 )
-        self.q = get_conv2d(in_channels, in_channels, 1)
-        self.k = get_conv2d( in_channels, in_channels, 1)
-        self.v = get_conv2d( in_channels, in_channels, 1)
-        self.proj_out = get_conv2d(in_channels, in_channels, 1)
+        self.norm = tfa.layers.GroupNormalization(epsilon=1e-5)
+        self.q = PaddedConv2D(in_channels, in_channels, 1)
+        self.k = PaddedConv2D(in_channels, in_channels, 1)
+        self.v = PaddedConv2D(in_channels, in_channels, 1)
+        self.proj_out = PaddedConv2D(in_channels, in_channels, 1)

     # copied from AttnBlock in ldm repo
-    def __call__(self, x):
+    def call(self, x):
         h_ = self.norm(x)
-        q,k,v = self.q(h_), self.k(h_), self.v(h_)
+        q, k, v = self.q(h_), self.k(h_), self.v(h_)

         # compute attention
-        b, h,w, c = q.shape
-        q = tf.reshape(q , (-1 ,h*w , c ))# b,hw,c
-        k = tf.keras.layers.Permute((3,1,2))(k)
-        k = tf.reshape(k , (-1,c,h*w)) # b,c,hw
+        b, h, w, c = q.shape
+        q = tf.reshape(q, (-1, h * w, c))  # b,hw,c
+        k = keras.layers.Permute((3, 1, 2))(k)
+        k = tf.reshape(k, (-1, c, h * w))  # b,c,hw
         w_ = q @ k
-        w_ = w_ * (c**(-0.5))
-        w_ = tf.keras.activations.softmax (w_)
-
-
+        w_ = w_ * (c ** (-0.5))
+        w_ = keras.activations.softmax(w_)

         # attend to values
-        v = tf.keras.layers.Permute((3,1,2))(v)
-        v = tf.reshape(v , (-1,c,h*w))
-        w_ = tf.keras.layers.Permute((2,1))( w_)
+        v = keras.layers.Permute((3, 1, 2))(v)
+        v = tf.reshape(v, (-1, c, h * w))
+        w_ = keras.layers.Permute((2, 1))(w_)
         h_ = v @ w_
-        h_ = tf.keras.layers.Permute((2,1))(h_)
-        h_ = tf.reshape(h_, (-1,h,w,c))
+        h_ = keras.layers.Permute((2, 1))(h_)
+        h_ = tf.reshape(h_, (-1, h, w, c))
         return x + self.proj_out(h_)
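Aside, not part of the diff: the Permute/reshape pairs above flatten the h x w grid into h*w tokens so each spatial position attends to every other one. A minimal standalone shape check of the same sequence of ops, with arbitrary dummy sizes:

    import tensorflow as tf
    from tensorflow import keras

    b, h, w, c = 1, 8, 8, 64
    q = tf.random.normal((b, h, w, c))
    k = tf.random.normal((b, h, w, c))
    v = tf.random.normal((b, h, w, c))

    q = tf.reshape(q, (-1, h * w, c))                                   # (b, hw, c)
    k = tf.reshape(keras.layers.Permute((3, 1, 2))(k), (-1, c, h * w))  # (b, c, hw)
    w_ = keras.activations.softmax((q @ k) * (c ** -0.5))               # (b, hw, hw)
    v = tf.reshape(keras.layers.Permute((3, 1, 2))(v), (-1, c, h * w))  # (b, c, hw)
    h_ = v @ keras.layers.Permute((2, 1))(w_)                           # (b, c, hw)
    out = tf.reshape(keras.layers.Permute((2, 1))(h_), (-1, h, w, c))   # (b, h, w, c)
    print(out.shape)  # TensorShape([1, 8, 8, 64])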


-
-class ResnetBlock(tf.keras.layers.Layer):
-    def __init__(self , in_channels, out_channels=None):
+class ResnetBlock(keras.layers.Layer):
+    def __init__(self, in_channels, out_channels=None):
         super(ResnetBlock, self).__init__()
-        self.norm1 = tfa.layers.GroupNormalization(epsilon=1e-5 )
-        self.conv1 = get_conv2d(in_channels, out_channels, 3, padding=1)
-        self.norm2 = tfa.layers.GroupNormalization(epsilon=1e-5 )
-        self.conv2 = get_conv2d(out_channels, out_channels, 3, padding=1)
-        self.nin_shortcut = get_conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
-
-    def __call__(self, x):
-        h = self.conv1(tf.keras.activations.swish(self.norm1(x)) )
-        h = self.conv2(tf.keras.activations.swish(self.norm2(h)) )
+        self.norm1 = tfa.layers.GroupNormalization(epsilon=1e-5)
+        self.conv1 = PaddedConv2D(in_channels, out_channels, 3, padding=1)
+        self.norm2 = tfa.layers.GroupNormalization(epsilon=1e-5)
+        self.conv2 = PaddedConv2D(out_channels, out_channels, 3, padding=1)
+        self.nin_shortcut = (
+            PaddedConv2D(in_channels, out_channels, 1)
+            if in_channels != out_channels
+            else lambda x: x
+        )
+
+    def call(self, x):
+        h = self.conv1(keras.activations.swish(self.norm1(x)))
+        h = self.conv2(keras.activations.swish(self.norm2(h)))
         return self.nin_shortcut(x) + h

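Note, not part of the diff: ResnetBlock adds the 1x1 nin_shortcut only when the input and output channel counts differ; otherwise the residual path is the identity. With the class as defined above:

    block_a = ResnetBlock(512, 256)  # nin_shortcut is PaddedConv2D(512, 256, 1)
    block_b = ResnetBlock(512, 512)  # nin_shortcut is the identity lambda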
-
-
-class Mid(tf.keras.layers.Layer):
-    def __init__(self , block_in):
+class Mid(keras.layers.Layer):
+    def __init__(self, block_in):
         super(Mid, self).__init__()
-        self.block_1 = ResnetBlock( block_in, block_in)
+        self.block_1 = ResnetBlock(block_in, block_in)
         self.attn_1 = AttnBlock(block_in)
-        self.block_2 = ResnetBlock( block_in, block_in)
-
-    def __call__(self, x):
-        return apply_seq(x , [self.block_1, self.attn_1, self.block_2])
-
+        self.block_2 = ResnetBlock(block_in, block_in)

+    def call(self, x):
+        return apply_seq(x, [self.block_1, self.attn_1, self.block_2])

-class Decoder(tf.keras.models.Model):
-    def __init__(self ):
+class Decoder(keras.models.Model):
+    def __init__(self):
         super(Decoder, self).__init__()
         sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
-        self.conv_in = get_conv2d( 4,512,3, padding=1)
-        self.mid = Mid( 512)
-        self.upp = tf.keras.layers.UpSampling2D(size=(2, 2))
-
-        self.post_quant_conv = get_conv2d( 4, 4, 1)
+
+        self.post_quant_conv = PaddedConv2D(4, 4, 1)
+        self.conv_in = PaddedConv2D(4, 512, 3, padding=1)
+        self.mid = Mid(512)
+        self.upp = keras.layers.UpSampling2D(size=(2, 2))

         arr = []
-        for i,s in enumerate(sz):
-            arr.append({"block":
-                [ResnetBlock( s[1], s[0]),
-                 ResnetBlock( s[0], s[0]),
-                 ResnetBlock( s[0], s[0])]})
-            if i != 0: arr[-1]['upsample'] = {"conv": get_conv2d(s[0], s[0], 3, padding=1)}
+        for i, s in enumerate(sz):
+            arr.append(
+                {
+                    "block": [
+                        ResnetBlock(s[1], s[0]),
+                        ResnetBlock(s[0], s[0]),
+                        ResnetBlock(s[0], s[0]),
+                    ]
+                }
+            )
+            if i != 0:
+                arr[-1]["upsample"] = {"conv": PaddedConv2D(s[0], s[0], 3, padding=1)}
         self.up = arr

         self.norm_out = tfa.layers.GroupNormalization(epsilon=1e-5)
-        self.conv_out = get_conv2d(128, 3, 3, padding=1)
+        self.conv_out = PaddedConv2D(128, 3, 3, padding=1)

-    def __call__(self, x, training=False):
-
-        x = self.post_quant_conv(1/0.18215 * x)
+    def call(self, x, training=False):
+        x = self.post_quant_conv(1 / 0.18215 * x)

         x = self.conv_in(x)
         x = self.mid(x)

         for l in self.up[::-1]:
-            for b in l['block']: x = b(x)
-            if 'upsample' in l:
+            for b in l["block"]:
+                x = b(x)
+            if "upsample" in l:
                 # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
-                bs,c,py,px = x.shape
+                bs, c, py, px = x.shape
                 x = self.upp(x)
-                x = l['upsample']['conv'](x)
-
-        return self.conv_out(tf.keras.activations.swish (self.norm_out(x)) )
+                x = l["upsample"]["conv"](x)

+        return self.conv_out(keras.activations.swish(self.norm_out(x)))
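For context, not part of the diff: a minimal usage sketch of the reworked Decoder. The import path below is hypothetical (it depends on how the package around this module is laid out), and the input is a Stable Diffusion latent in the model's scaled space; call() divides it by 0.18215 itself:

    import tensorflow as tf
    from stable_diffusion_tf.autoencoder_kl import Decoder  # hypothetical module path

    decoder = Decoder()
    latent = tf.random.normal((1, 64, 64, 4))  # (batch, h/8, w/8, 4) latent
    image = decoder(latent)                    # -> (1, 512, 512, 3)

Only the stages with i != 0 carry an "upsample" entry, so the loop over self.up[::-1] upsamples three times and a 64x64 latent decodes to a 512x512, 3-channel image.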