 import torch
 import torch.utils.checkpoint
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
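
Note: the new DistributedType import is what the rest of the diff uses to branch on a DeepSpeed launch. A minimal, illustrative sketch of that check (not part of this commit):

    from accelerate import Accelerator, DistributedType

    accelerator = Accelerator()
    if accelerator.distributed_type == DistributedType.DEEPSPEED:
        # DeepSpeed wraps models in its engine and manages checkpoint state itself
        print("running under DeepSpeed")
    else:
        print("running without DeepSpeed")
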
@@ -1292,11 +1292,17 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
+                    if args.upcast_before_saving:
+                        model = model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                elif args.train_text_encoder and isinstance(
+                    unwrap_model(model), type(unwrap_model(text_encoder_one))
+                ):  # or text_encoder_two
                     # both text encoders are of the same class, so we check hidden size to distinguish between the two
-                    hidden_size = unwrap_model(model).config.hidden_size
+                    model = unwrap_model(model)
+                    hidden_size = model.config.hidden_size
                     if hidden_size == 768:
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     elif hidden_size == 1280:
@@ -1305,7 +1311,8 @@ def save_model_hook(models, weights, output_dir):
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
             StableDiffusion3Pipeline.save_lora_weights(
                 output_dir,
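
The save hook now unwraps each entry of models before the isinstance check: under DDP or DeepSpeed the objects handed to the hook are wrapper modules, so a type check against the bare transformer class would never match. The weights.pop() guard matters for the same reason, since DeepSpeed can invoke the hook with an empty weights list. A self-contained, hypothetical sketch of the unwrap-before-check pattern (toy classes, not the training script itself):

    import torch

    class TinyTransformer(torch.nn.Module):  # stand-in for the real transformer
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(4, 4)

    class EngineWrapper(torch.nn.Module):  # stand-in for a DDP/DeepSpeed wrapper
        def __init__(self, module):
            super().__init__()
            self.module = module

    def unwrap_model(model):
        # simplified version of the unwrapping the script does via Accelerate
        return model.module if hasattr(model, "module") else model

    wrapped = EngineWrapper(TinyTransformer())
    print(isinstance(wrapped, TinyTransformer))                # False: the wrapper hides the type
    print(isinstance(unwrap_model(wrapped), TinyTransformer))  # True: check the unwrapped module
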
@@ -1319,17 +1326,31 @@ def load_model_hook(models, input_dir):
         text_encoder_one_ = None
         text_encoder_two_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                text_encoder_two_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+                    text_encoder_two_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = SD3Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            if args.train_text_encoder:
+                text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder"
+                )
+                text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+                )
 
         lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
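
Under DeepSpeed the load hook cannot pop already-instantiated models off the list, so this branch rebuilds the transformer (and, when training them, the text encoders) from the pretrained checkpoint and re-attaches the LoRA adapter before the LoRA state dict is loaded further down. For context, a minimal sketch of how such hooks get attached to an Accelerator (real accelerate API, but with simplified, illustrative hook bodies):

    from accelerate import Accelerator

    accelerator = Accelerator()

    def save_model_hook(models, weights, output_dir):
        # called from accelerator.save_state(); pop handled weights so the
        # full models are not serialized a second time
        while weights:
            weights.pop()

    def load_model_hook(models, input_dir):
        # called from accelerator.load_state(); restore adapter weights here
        pass

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
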
@@ -1829,7 +1850,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
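
The checkpointing guard is widened because DeepSpeed's state saving is a collective operation: every rank has to enter accelerator.save_state(), not only the main process. A simplified, illustrative helper showing the same condition (assumes an accelerator and script-style args; not code from this commit):

    import os
    from accelerate import DistributedType

    def maybe_save_checkpoint(accelerator, args, global_step):
        # all ranks must participate when DeepSpeed shards model/optimizer state
        if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
            if global_step % args.checkpointing_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)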