
Commit c6fc48d

First draft
1 parent 2f2820a commit c6fc48d

File tree

8 files changed (+1534 -511 lines)


setup.py (+11 -9)
@@ -1,10 +1,12 @@
 from setuptools import find_packages, setup
-setup(name="stable_diffusion_tf",
-      version="0.1",
-      description="Stable Diffusion in Tensorflow / Keras",
-      author="Divam Gupta",
-      author_email='[email protected]',
-      platforms=["any"], # or more specific, e.g. "win32", "cygwin", "osx"
-      url="https://github.com/divamgupta/stable-diffusion-tensorflow",
-      packages=find_packages(),
-)
+
+
+setup(
+    name="stable_diffusion_tf",
+    version="0.1",
+    description="Stable Diffusion in Tensorflow / Keras",
+    author="Divam Gupta",
+    author_email="[email protected]",
+    platforms=["any"],  # or more specific, e.g. "win32", "cygwin", "osx"
+    url="https://github.com/divamgupta/stable-diffusion-tensorflow",
+    packages=find_packages(),
+)
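
For reference, the reorganized setup() call carries the same metadata as before, so installation is unchanged. A minimal usage sketch, assuming the repository root is the current directory (neither the command nor the import appears in this commit):

    # Editable install from the repository root:
    #   pip install -e .
    # find_packages() then exposes the top-level package:
    import stable_diffusion_tf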

stable_diffusion_tf/autoencoder_kl.py (+71 -70)
@@ -1,113 +1,114 @@
-
 import tensorflow as tf
+from tensorflow import keras
 import tensorflow_addons as tfa
 
-from .layers import quick_gelu, get_conv2d, apply_seq, td_dot, gelu, GEGLU
-
+from .layers import quick_gelu, apply_seq, td_dot, gelu, GEGLU, PaddedConv2D
 
 
-class AttnBlock(tf.keras.layers.Layer):
-    def __init__(self , in_channels):
+class AttnBlock(keras.layers.Layer):
+    def __init__(self, in_channels):
         super(AttnBlock, self).__init__()
-        self.norm = tfa.layers.GroupNormalization(epsilon=1e-5 )
-        self.q = get_conv2d(in_channels, in_channels, 1)
-        self.k = get_conv2d( in_channels, in_channels, 1)
-        self.v = get_conv2d( in_channels, in_channels, 1)
-        self.proj_out = get_conv2d(in_channels, in_channels, 1)
+        self.norm = tfa.layers.GroupNormalization(epsilon=1e-5)
+        self.q = PaddedConv2D(in_channels, in_channels, 1)
+        self.k = PaddedConv2D(in_channels, in_channels, 1)
+        self.v = PaddedConv2D(in_channels, in_channels, 1)
+        self.proj_out = PaddedConv2D(in_channels, in_channels, 1)
 
     # copied from AttnBlock in ldm repo
-    def __call__(self, x):
+    def call(self, x):
         h_ = self.norm(x)
-        q,k,v = self.q(h_), self.k(h_), self.v(h_)
+        q, k, v = self.q(h_), self.k(h_), self.v(h_)
 
         # compute attention
-        b, h,w, c = q.shape
-        q = tf.reshape(q , (-1 ,h*w , c ))# b,hw,c
-        k = tf.keras.layers.Permute((3,1,2))(k)
-        k = tf.reshape(k , (-1,c,h*w)) # b,c,hw
+        b, h, w, c = q.shape
+        q = tf.reshape(q, (-1, h * w, c))  # b,hw,c
+        k = keras.layers.Permute((3, 1, 2))(k)
+        k = tf.reshape(k, (-1, c, h * w))  # b,c,hw
         w_ = q @ k
-        w_ = w_ * (c**(-0.5))
-        w_ = tf.keras.activations.softmax (w_)
-
-
+        w_ = w_ * (c ** (-0.5))
+        w_ = keras.activations.softmax(w_)
 
         # attend to values
-        v = tf.keras.layers.Permute((3,1,2))(v)
-        v = tf.reshape(v , (-1,c,h*w))
-        w_ = tf.keras.layers.Permute((2,1))( w_)
+        v = keras.layers.Permute((3, 1, 2))(v)
+        v = tf.reshape(v, (-1, c, h * w))
+        w_ = keras.layers.Permute((2, 1))(w_)
         h_ = v @ w_
-        h_ = tf.keras.layers.Permute((2,1))(h_)
-        h_ = tf.reshape(h_, (-1,h,w,c))
+        h_ = keras.layers.Permute((2, 1))(h_)
+        h_ = tf.reshape(h_, (-1, h, w, c))
         return x + self.proj_out(h_)
 
 
-
-class ResnetBlock(tf.keras.layers.Layer):
-    def __init__(self , in_channels, out_channels=None):
+class ResnetBlock(keras.layers.Layer):
+    def __init__(self, in_channels, out_channels=None):
         super(ResnetBlock, self).__init__()
-        self.norm1 = tfa.layers.GroupNormalization(epsilon=1e-5 )
-        self.conv1 = get_conv2d(in_channels, out_channels, 3, padding=1)
-        self.norm2 = tfa.layers.GroupNormalization(epsilon=1e-5 )
-        self.conv2 = get_conv2d(out_channels, out_channels, 3, padding=1)
-        self.nin_shortcut = get_conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
-
-    def __call__(self, x):
-        h = self.conv1(tf.keras.activations.swish(self.norm1(x)) )
-        h = self.conv2(tf.keras.activations.swish(self.norm2(h)) )
+        self.norm1 = tfa.layers.GroupNormalization(epsilon=1e-5)
+        self.conv1 = PaddedConv2D(in_channels, out_channels, 3, padding=1)
+        self.norm2 = tfa.layers.GroupNormalization(epsilon=1e-5)
+        self.conv2 = PaddedConv2D(out_channels, out_channels, 3, padding=1)
+        self.nin_shortcut = (
+            PaddedConv2D(in_channels, out_channels, 1)
+            if in_channels != out_channels
+            else lambda x: x
+        )
+
+    def call(self, x):
+        h = self.conv1(keras.activations.swish(self.norm1(x)))
+        h = self.conv2(keras.activations.swish(self.norm2(h)))
         return self.nin_shortcut(x) + h
 
 
-
-
-class Mid(tf.keras.layers.Layer):
-    def __init__(self , block_in):
+class Mid(keras.layers.Layer):
+    def __init__(self, block_in):
         super(Mid, self).__init__()
-        self.block_1 = ResnetBlock( block_in, block_in)
+        self.block_1 = ResnetBlock(block_in, block_in)
         self.attn_1 = AttnBlock(block_in)
-        self.block_2 = ResnetBlock( block_in, block_in)
-
-    def __call__(self, x):
-        return apply_seq(x , [self.block_1, self.attn_1, self.block_2])
-
+        self.block_2 = ResnetBlock(block_in, block_in)
 
+    def call(self, x):
+        return apply_seq(x, [self.block_1, self.attn_1, self.block_2])
 
 
-class Decoder(tf.keras.models.Model):
-    def __init__(self ):
+class Decoder(keras.models.Model):
+    def __init__(self):
         super(Decoder, self).__init__()
         sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
-        self.conv_in = get_conv2d( 4,512,3, padding=1)
-        self.mid = Mid( 512)
-        self.upp = tf.keras.layers.UpSampling2D(size=(2, 2))
-
-        self.post_quant_conv = get_conv2d( 4, 4, 1)
+
+        self.post_quant_conv = PaddedConv2D(4, 4, 1)
+        self.conv_in = PaddedConv2D(4, 512, 3, padding=1)
+        self.mid = Mid(512)
+        self.upp = keras.layers.UpSampling2D(size=(2, 2))
 
         arr = []
-        for i,s in enumerate(sz):
-            arr.append({"block":
-                [ResnetBlock( s[1], s[0]),
-                 ResnetBlock( s[0], s[0]),
-                 ResnetBlock( s[0], s[0])]})
-            if i != 0: arr[-1]['upsample'] = {"conv": get_conv2d(s[0], s[0], 3, padding=1)}
+        for i, s in enumerate(sz):
+            arr.append(
+                {
+                    "block": [
+                        ResnetBlock(s[1], s[0]),
+                        ResnetBlock(s[0], s[0]),
+                        ResnetBlock(s[0], s[0]),
+                    ]
+                }
+            )
+            if i != 0:
+                arr[-1]["upsample"] = {"conv": PaddedConv2D(s[0], s[0], 3, padding=1)}
         self.up = arr
 
         self.norm_out = tfa.layers.GroupNormalization(epsilon=1e-5)
-        self.conv_out = get_conv2d(128, 3, 3, padding=1)
+        self.conv_out = PaddedConv2D(128, 3, 3, padding=1)
 
-    def __call__(self, x, training=False):
-
-        x = self.post_quant_conv(1/0.18215 * x)
+    def call(self, x, training=False):
+        x = self.post_quant_conv(1 / 0.18215 * x)
 
         x = self.conv_in(x)
         x = self.mid(x)
 
         for l in self.up[::-1]:
-            for b in l['block']: x = b(x)
-            if 'upsample' in l:
+            for b in l["block"]:
+                x = b(x)
+            if "upsample" in l:
                 # https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html ?
-                bs,c,py,px = x.shape
+                bs, c, py, px = x.shape
                 x = self.upp(x)
-                x = l['upsample']['conv'](x)
-
-        return self.conv_out(tf.keras.activations.swish (self.norm_out(x)) )
 
+        return self.conv_out(keras.activations.swish(self.norm_out(x)))
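
The definition of PaddedConv2D is not shown in this excerpt; it is imported from stable_diffusion_tf/layers.py along with the other helpers. A minimal sketch consistent with the call sites above, assuming it wraps explicit zero padding around a "valid" Keras convolution to reproduce PyTorch's nn.Conv2d padding semantics; the signature and the in_channels argument are inferred from this diff, not confirmed by it:

    from tensorflow import keras

    class PaddedConv2D(keras.layers.Layer):
        # Hypothetical reconstruction. Signature inferred from calls such as
        # PaddedConv2D(in_channels, out_channels, kernel_size, padding=0).
        def __init__(self, in_channels, out_channels, kernel_size, padding=0):
            super().__init__()
            # Keras infers the input channel count from the input tensor, so
            # in_channels would be accepted only for parity with the
            # PyTorch-style get_conv2d helper this commit replaces.
            self.padding2d = keras.layers.ZeroPadding2D(padding)
            self.conv2d = keras.layers.Conv2D(out_channels, kernel_size, padding="valid")

        def call(self, x):
            # Explicit zero padding followed by a "valid" convolution matches
            # PyTorch's nn.Conv2d(..., padding=padding) output size.
            return self.conv2d(self.padding2d(x))

Under that assumption, self.conv_out = PaddedConv2D(128, 3, 3, padding=1) above maps a (batch, h, w, 128) activation to a (batch, h, w, 3) image tensor with spatial size preserved.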
