We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 95f4c52 commit 2be64b3Copy full SHA for 2be64b3
train_network.py
@@ -474,7 +474,8 @@ def train(self, args):
474
# before resuming make hook for saving/loading to save/load the network weights only
475
def save_model_hook(models, weights, output_dir):
476
# pop weights of other models than network to save only network weights
477
- if accelerator.is_main_process:
+ # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
478
+ if accelerator.is_main_process or args.deepspeed:
479
remove_indices = []
480
for i, model in enumerate(models):
481
if not isinstance(model, type(accelerator.unwrap_model(network))):
0 commit comments