
Commit 5f72473

sayakpaul, leisuzz, and linoytsaban authored
[training] add ds support to lora sd3. (#10378)
* add ds support to lora sd3.

Co-authored-by: leisuzz <[email protected]>

* style.

---------

Co-authored-by: leisuzz <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
1 parent 01780c3 commit 5f72473
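The support lands through Accelerate rather than through direct DeepSpeed calls: the script keeps building a plain Accelerator and, in the hooks patched below, branches wherever DeepSpeed behaves differently. For orientation, a minimal sketch of putting a run onto the DeepSpeed backend programmatically; the DeepSpeedPlugin values are illustrative, not from this commit, and real runs usually select the backend via accelerate config / accelerate launch:

# minimal sketch: requires `accelerate` and `deepspeed` to be installed
from accelerate import Accelerator, DistributedType
from accelerate.utils import DeepSpeedPlugin

# illustrative settings, normally chosen interactively with `accelerate config`
ds_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=1)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)

# the condition the patched hooks test for
print(accelerator.distributed_type == DistributedType.DEEPSPEED)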

File tree

1 file changed: +37 −16 lines changed

examples/dreambooth/train_dreambooth_lora_sd3.py (+37 −16)
@@ -29,7 +29,7 @@
 import torch
 import torch.utils.checkpoint
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
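DistributedType, the only new import, is the enum Accelerate uses to report which backend is active; the rest of the diff branches on it. A quick sketch of what it exposes:

from accelerate import Accelerator, DistributedType

accelerator = Accelerator()                # plain single-process run
print(accelerator.distributed_type)        # DistributedType.NO on a plain run
print([t.name for t in DistributedType])   # includes "MULTI_GPU" and "DEEPSPEED"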
@@ -1292,11 +1292,17 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
+                    if args.upcast_before_saving:
+                        model = model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                elif args.train_text_encoder and isinstance(
+                    unwrap_model(model), type(unwrap_model(text_encoder_one))
+                ):  # or text_encoder_two
                     # both text encoders are of the same class, so we check hidden size to distinguish between the two
-                    hidden_size = unwrap_model(model).config.hidden_size
+                    model = unwrap_model(model)
+                    hidden_size = model.config.hidden_size
                     if hidden_size == 768:
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     elif hidden_size == 1280:
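Under DeepSpeed, the objects Accelerate hands to the save hook are engine-wrapped, so the old bare isinstance(model, ...) test no longer matches the unwrapped transformer class; unwrapping each candidate before the type check (and before extracting the state dict) makes the hook wrapper-agnostic. For reference, the script's unwrap_model helper follows the usual diffusers training-script pattern, roughly:

from diffusers.utils.torch_utils import is_compiled_module

def unwrap_model(model):
    # `accelerator` is the script-level Accelerator this helper closes over
    model = accelerator.unwrap_model(model)  # strip the DDP/DeepSpeed wrapper
    # strip the torch.compile wrapper, if the module was compiled
    model = model._orig_mod if is_compiled_module(model) else model
    return model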
@@ -1305,7 +1311,8 @@ def save_model_hook(models, weights, output_dir):
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()
 
             StableDiffusion3Pipeline.save_lora_weights(
                 output_dir,
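The guard matters because under DeepSpeed, Accelerate can invoke the hook with an empty weights list (DeepSpeed persists the actual tensors itself), and the old unconditional pop() would raise. A tiny runnable illustration of why the one-line guard is enough:

weights = []  # what the hook can receive under DeepSpeed

try:
    weights.pop()  # the old unconditional pop
except IndexError as err:
    print(f"old behaviour: {err}")  # pop from empty list

if weights:  # the patched guard: skip when there is nothing to pop
    weights.pop()
print("patched behaviour: no error")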
@@ -1319,17 +1326,31 @@ def load_model_hook(models, input_dir):
         text_encoder_one_ = None
         text_encoder_two_ = None
 
-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()
 
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                text_encoder_two_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+                    text_encoder_two_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = SD3Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            if args.train_text_encoder:
+                text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder"
+                )
+                text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+                )
 
         lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
 
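On the DeepSpeed path the models list handed to the load hook is empty (there are no live unwrapped modules to recover), so the hook rebuilds the base models from the pretrained checkpoint and re-attaches the LoRA adapter before the saved state dict is applied. transformer_lora_config is created earlier in the script; a hedged sketch of its construction (the values mirror the script's adapter setup as I recall it, and rank stands in for args.rank):

from peft import LoraConfig

rank = 4  # stands in for args.rank
transformer_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    init_lora_weights="gaussian",
    # SD3 attention projections targeted by the LoRA adapter
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
# transformer_.add_adapter(transformer_lora_config) then lets the hook
# load the saved LoRA weights into the freshly built model.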
@@ -1829,7 +1850,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
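The widened gate exists because DeepSpeed checkpointing is a collective operation: ZeRO shards optimizer (and, at stage 3, parameter) state across ranks, so every process must enter accelerator.save_state, not just the main one. A sketch of the resulting pattern (maybe_save_checkpoint and save_path are illustrative names, not from the script):

import os
from accelerate import DistributedType

def maybe_save_checkpoint(accelerator, args, global_step):
    # under DeepSpeed every rank must participate; otherwise only main writes
    if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
        if global_step % args.checkpointing_steps == 0:
            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
            accelerator.save_state(save_path)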
