Skip to content

Commit 32c3b9c

Browse files
committed
optional separate vae
1 parent 1d21853 commit 32c3b9c

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ models/
33
__pycache__/
44
*.png
55
*.latent
6+
tests.py

Diff for: sd3_infer.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# - `clip_l.safetensors` (OpenAI CLIP-L, same as SDXL)
44
# - `t5xxl.safetensors` (google T5-v1.1-XXL)
55
# - `sd3_beta.safetensors`
6+
# Also can have
7+
# - `sd3_vae.safetensors` (holds the VAE separately if needed)
68

79
import torch, fire, math
810
from safetensors import safe_open
@@ -103,7 +105,10 @@ class VAE:
103105
def __init__(self, model):
104106
with safe_open(model, framework="pt", device="cpu") as f:
105107
self.model = SDVAE(device="cpu", dtype=torch.float16).eval().cpu()
106-
load_into(f, self.model, "first_stage_model.", "cpu", torch.float16)
108+
prefix = ""
109+
if any(k.startswith("first_stage_model.") for k in f.keys()):
110+
prefix = "first_stage_model."
111+
load_into(f, self.model, prefix, "cpu", torch.float16)
107112

108113

109114
#################################################################################################
@@ -126,11 +131,13 @@ def __init__(self, model):
126131
SEED = 1
127132
# Actual model file path
128133
MODEL = "models/sd3_beta.safetensors"
134+
# VAE model file path, or set "None" to use the same model file
135+
VAEFile = "models/sd3_vae.safetensors"
129136
# Output file path
130137
OUTPUT = "output.png"
131138

132139
class SD3Inferencer:
133-
def load(self, model=MODEL, shift=SHIFT):
140+
def load(self, model=MODEL, vae=VAEFile, shift=SHIFT):
134141
print("Loading tokenizers...")
135142
# NOTE: if you need a reference impl for a high performance CLIP tokenizer instead of just using the HF transformers one,
136143
# check https://github.com/Stability-AI/StableSwarmUI/blob/master/src/Utils/CliplikeTokenizer.cs
@@ -145,7 +152,7 @@ def load(self, model=MODEL, shift=SHIFT):
145152
print("Loading SD3 model...")
146153
self.sd3 = SD3(model, shift)
147154
print("Loading VAE model...")
148-
self.vae = VAE(model)
155+
self.vae = VAE(vae or model)
149156
print("Models loaded.")
150157

151158
def get_empty_latent(self, width, height):
@@ -228,9 +235,9 @@ def gen_image(self, prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_
228235
print("Done")
229236

230237
@torch.no_grad()
231-
def main(prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, shift=SHIFT, model=MODEL, seed=SEED, output=OUTPUT):
238+
def main(prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_scale=CFG_SCALE, shift=SHIFT, model=MODEL, vae=VAEFile, seed=SEED, output=OUTPUT):
232239
inferencer = SD3Inferencer()
233-
inferencer.load(model, shift)
240+
inferencer.load(model, vae, shift)
234241
inferencer.gen_image(prompt, width, height, steps, cfg_scale, seed, output)
235242

236243
fire.Fire(main)

0 commit comments

Comments (0)