# - `clip_l.safetensors` (OpenAI CLIP-L, same as SDXL)
# - `t5xxl.safetensors` (google T5-v1.1-XXL)
# - `sd3_beta.safetensors`
# Also can have
# - `sd3_vae.safetensors` (holds the VAE separately if needed)
9
import torch , fire , math
8
10
from safetensors import safe_open
@@ -103,7 +105,10 @@ class VAE:
103
105
def __init__ (self , model ):
104
106
with safe_open (model , framework = "pt" , device = "cpu" ) as f :
105
107
self .model = SDVAE (device = "cpu" , dtype = torch .float16 ).eval ().cpu ()
106
- load_into (f , self .model , "first_stage_model." , "cpu" , torch .float16 )
108
+ prefix = ""
109
+ if any (k .startswith ("first_stage_model." ) for k in f .keys ()):
110
+ prefix = "first_stage_model."
111
+ load_into (f , self .model , prefix , "cpu" , torch .float16 )
107
112
108
113
109
114
#################################################################################################
@@ -126,11 +131,13 @@ def __init__(self, model):
126
131
SEED = 1
127
132
# Actual model file path
128
133
MODEL = "models/sd3_beta.safetensors"
134
+ # VAE model file path, or set "None" to use the same model file
135
+ VAEFile = "models/sd3_vae.safetensors"
129
136
# Output file path
130
137
OUTPUT = "output.png"
131
138
132
139
class SD3Inferencer :
133
- def load (self , model = MODEL , shift = SHIFT ):
140
+ def load (self , model = MODEL , vae = VAEFile , shift = SHIFT ):
134
141
print ("Loading tokenizers..." )
135
142
# NOTE: if you need a reference impl for a high performance CLIP tokenizer instead of just using the HF transformers one,
136
143
# check https://github.com/Stability-AI/StableSwarmUI/blob/master/src/Utils/CliplikeTokenizer.cs
@@ -145,7 +152,7 @@ def load(self, model=MODEL, shift=SHIFT):
145
152
print ("Loading SD3 model..." )
146
153
self .sd3 = SD3 (model , shift )
147
154
print ("Loading VAE model..." )
148
- self .vae = VAE (model )
155
+ self .vae = VAE (vae or model )
149
156
print ("Models loaded." )
150
157
151
158
def get_empty_latent (self , width , height ):
@@ -228,9 +235,9 @@ def gen_image(self, prompt=PROMPT, width=WIDTH, height=HEIGHT, steps=STEPS, cfg_
228
235
print ("Done" )
229
236
230
237
@torch .no_grad ()
231
- def main (prompt = PROMPT , width = WIDTH , height = HEIGHT , steps = STEPS , cfg_scale = CFG_SCALE , shift = SHIFT , model = MODEL , seed = SEED , output = OUTPUT ):
238
+ def main (prompt = PROMPT , width = WIDTH , height = HEIGHT , steps = STEPS , cfg_scale = CFG_SCALE , shift = SHIFT , model = MODEL , vae = VAEFile , seed = SEED , output = OUTPUT ):
232
239
inferencer = SD3Inferencer ()
233
- inferencer .load (model , shift )
240
+ inferencer .load (model , vae , shift )
234
241
inferencer .gen_image (prompt , width , height , steps , cfg_scale , seed , output )
235
242
236
243
fire .Fire (main )
0 commit comments