
Commit c98de3c: vae encoder
1 parent 32c3b9c

3 files changed, +108 -7 lines

.gitignore (+1)

```diff
@@ -2,5 +2,6 @@ venv/
 models/
 __pycache__/
 *.png
+*.jpg
 *.latent
 tests.py
```

sd3_impls.py (+76 -1)
```diff
@@ -221,6 +221,18 @@ def forward(self, x):
         return x + hidden


+class Downsample(torch.nn.Module):
+    def __init__(self, in_channels, dtype=torch.float32, device=None):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)
+
+    def forward(self, x):
+        pad = (0,1,0,1)
+        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+        x = self.conv(x)
+        return x
+
+
 class Upsample(torch.nn.Module):
     def __init__(self, in_channels, dtype=torch.float32, device=None):
         super().__init__()
```
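A quick shape check of the new Downsample (a hypothetical snippet, not part of the commit): the asymmetric (0,1,0,1) pad plus the stride-2, padding-0 conv halves the spatial dimensions.

```python
import torch

down = Downsample(in_channels=8)
x = torch.randn(1, 8, 64, 64)
# pad makes it 65x65, then the 3x3 stride-2 conv yields (65 - 3) // 2 + 1 = 32
print(down(x).shape)  # torch.Size([1, 8, 32, 32])
```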
```diff
@@ -232,6 +244,61 @@ def forward(self, x):
         return x


+class VAEEncoder(torch.nn.Module):
+    def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None):
+        super().__init__()
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        self.down = torch.nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = torch.nn.ModuleList()
+            attn = torch.nn.ModuleList()
+            block_in = ch*in_ch_mult[i_level]
+            block_out = ch*ch_mult[i_level]
+            for i_block in range(num_res_blocks):
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
+                block_in = block_out
+            down = torch.nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Downsample(block_in, dtype=dtype, device=device)
+            self.down.append(down)
+        # middle
+        self.mid = torch.nn.Module()
+        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
+        # end
+        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
+        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
+        self.swish = torch.nn.SiLU(inplace=True)
+
+    def forward(self, x):
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1])
+                hs.append(h)
+            if i_level != self.num_resolutions-1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+        # end
+        h = self.norm_out(h)
+        h = self.swish(h)
+        h = self.conv_out(h)
+        return h
+
+
 class VAEDecoder(torch.nn.Module):
     def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None):
         super().__init__()
```
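For orientation, a hypothetical shape walk-through (assumes the ResnetBlock, AttnBlock, and Normalize classes defined earlier in sd3_impls.py):

```python
import torch

# The defaults give three Downsample stages (ch_mult has 4 levels, with no
# downsample on the last), so H and W shrink by 8x; conv_out emits
# 2 * z_channels = 32 channels: 16 for the mean and 16 for the logvar.
enc = VAEEncoder()
x = torch.randn(1, 3, 512, 512)
print(enc(x).shape)  # torch.Size([1, 32, 64, 64])
```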
```diff
@@ -286,11 +353,19 @@ def forward(self, z):


 class SDVAE(torch.nn.Module):
-    """Note that the VAE Encoder is not included in our current reference SD3 models. Might be added on release. Not needed for most gens anyway, only for img2img (Init Image), so for this codebase we'll just ignore it, and implement only the decoder."""
     def __init__(self, dtype=torch.float32, device=None):
         super().__init__()
+        self.encoder = VAEEncoder(dtype=dtype, device=device)
         self.decoder = VAEDecoder(dtype=dtype, device=device)

     @torch.autocast("cuda", dtype=torch.float16)
     def decode(self, latent):
         return self.decoder(latent)
+
+    @torch.autocast("cuda", dtype=torch.float16)
+    def encode(self, image):
+        hidden = self.encoder(image)
+        mean, logvar = torch.chunk(hidden, 2, dim=1)
+        logvar = torch.clamp(logvar, -30.0, 20.0)
+        std = torch.exp(0.5 * logvar)
+        return mean + std * torch.randn_like(mean)
```
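SDVAE.encode samples the latent via the reparameterization trick: the 32 encoder channels split into mean and logvar, logvar is clamped for numerical stability, and a sample mean + exp(0.5 * logvar) * eps is drawn. A minimal round-trip sketch (hypothetical; assumes loaded weights and a CUDA device, since both methods autocast to float16 on cuda):

```python
import torch

vae = SDVAE().cuda().eval()
img = torch.rand(1, 3, 512, 512).cuda() * 2.0 - 1.0  # encoder expects inputs in [-1, 1]
with torch.no_grad():
    z = vae.encode(img)    # [1, 16, 64, 64], one sample from the predicted Gaussian
    out = vae.decode(z)    # [1, 3, 512, 512]
```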

sd3_infer.py (+31 -6)
```diff
@@ -22,7 +22,7 @@
 def load_into(f, model, prefix, device, dtype=None):
     """Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module."""
     for key in f.keys():
-        if key.startswith(prefix):
+        if key.startswith(prefix) and not key.startswith("loss."):
             path = key[len(prefix):].split(".")
             obj = model
             for p in path:
```
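The new `loss.` filter presumably skips training-only keys: a VAE checkpoint can carry loss/discriminator weights that this inference-only module never defines, so loading no longer trips over them.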
```diff
@@ -133,6 +133,10 @@ def __init__(self, model):
 MODEL = "models/sd3_beta.safetensors"
 # VAE model file path, or set "None" to use the same model file
 VAEFile = "models/sd3_vae.safetensors"
+# Optional init image file path
+INIT_IMAGE = None
+# If init_image is given, this is the percentage of denoising steps to run (1.0 = full denoise, 0.0 = no denoise at all)
+DENOISE = 0.6
 # Output file path
 OUTPUT = "output.png"
```
```diff
@@ -194,12 +198,13 @@ def fix_cond(self, cond):
         cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda())
         return { "c_crossattn": cond, "y": pooled }

-    def do_sampling(self, latent, seed, conditioning, neg_cond, steps, cfg_scale) -> torch.Tensor:
+    def do_sampling(self, latent, seed, conditioning, neg_cond, steps, cfg_scale, denoise=1.0) -> torch.Tensor:
         print("Sampling...")
         latent = latent.half().cuda()
         self.sd3.model = self.sd3.model.cuda()
         noise = self.get_noise(seed, latent).cuda()
         sigmas = self.get_sigmas(self.sd3.model.model_sampling, steps).cuda()
+        sigmas = sigmas[int(steps * (1 - denoise)):]
         conditioning = self.fix_cond(conditioning)
         neg_cond = self.fix_cond(neg_cond)
         extra_args = { "cond": conditioning, "uncond": neg_cond, "cond_scale": cfg_scale }
```
```diff
@@ -210,6 +215,21 @@ def do_sampling(self, latent, seed, conditioning, neg_cond, steps, cfg_scale) ->
         print("Sampling done")
         return latent

+    def vae_encode(self, image) -> torch.Tensor:
+        print("Encoding image to latent...")
+        image = image.convert("RGB")
+        image_np = np.array(image).astype(np.float32) / 255.0
+        image_np = np.moveaxis(image_np, 2, 0)
+        batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
+        image_torch = torch.from_numpy(batch_images)
+        image_torch = 2.0 * image_torch - 1.0
+        image_torch = image_torch.cuda()
+        self.vae.model = self.vae.model.cuda()
+        latent = self.vae.model.encode(image_torch).cpu()
+        self.vae.model = self.vae.model.cpu()
+        print("Encoded")
+        return latent
+
     def vae_decode(self, latent) -> Image.Image:
         print("Decoding latent to image...")
         latent = latent.cuda()
```
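Note that `repeat(1, axis=0)` is a no-op at batch size 1, presumably left as a hook for batching. A hypothetical check of the preprocessing, showing a PIL image becoming a [1, 3, H, W] tensor in [-1, 1]:

```python
import numpy as np
import torch
from PIL import Image

img = Image.new("RGB", (64, 64), color=(255, 128, 0))
arr = np.moveaxis(np.array(img).astype(np.float32) / 255.0, 2, 0)
t = torch.from_numpy(arr)[None] * 2.0 - 1.0
print(t.shape, t.min().item(), t.max().item())  # torch.Size([1, 3, 64, 64]) -1.0 1.0
```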
```diff
@@ -224,20 +244,25 @@ def vae_decode(self, latent) -> Image.Image:
         print("Decoded")
         return out_image

-    def gen_image(self, prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, seed=SEED, output=OUTPUT):
+    def gen_image(self, prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, seed=SEED, output=OUTPUT, init_image=INIT_IMAGE, denoise=DENOISE):
         latent = self.get_empty_latent(width, height)
+        if init_image:
+            image_data = Image.open(init_image)
+            image_data = image_data.resize((width, height), Image.LANCZOS)
+            latent = self.vae_encode(image_data)
+            latent = SD3LatentFormat().process_in(latent)
         conditioning = self.get_cond(prompt)
         neg_cond = self.get_cond("")
-        sampled_latent = self.do_sampling(latent, seed, conditioning, neg_cond, steps, cfg_scale)
+        sampled_latent = self.do_sampling(latent, seed, conditioning, neg_cond, steps, cfg_scale, denoise if init_image else 1.0)
         image = self.vae_decode(sampled_latent)
         print(f"Will save to {output}")
         image.save(output)
         print("Done")

 @torch.no_grad()
-def main(prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, shift=SHIFT, model=MODEL, vae=VAEFile, seed=SEED, output=OUTPUT):
+def main(prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, shift=SHIFT, model=MODEL, vae=VAEFile, seed=SEED, output=OUTPUT, init_image=INIT_IMAGE, denoise=DENOISE):
     inferencer = SD3Inferencer()
     inferencer.load(model, vae, shift)
-    inferencer.gen_image(prompt, width, height, steps, cfg_scale, seed, output)
+    inferencer.gen_image(prompt, width, height, steps, cfg_scale, seed, output, init_image, denoise)

 fire.Fire(main)
```
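Since fire.Fire(main) exposes main()'s keyword arguments as flags, an img2img run presumably looks like `python sd3_infer.py --init_image input.png --denoise 0.6` (hypothetical file name); denoise below 1.0 skips the noisiest steps, so the output keeps part of the init image's structure.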
