Skip to content

Commit 58fb648

Browse files
committed
set static graph flag when DDP ref #1363
1 parent e5bab69 commit 58fb648

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

sdxl_train_control_net_lllite.py

+3
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ def train(args):
289289
# acceleratorがなんかよろしくやってくれるらしい
290290
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
291291

292+
if isinstance(unet, DDP):
293+
unet._set_static_graph() # avoid error for multiple use of the parameter
294+
292295
if args.gradient_checkpointing:
293296
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
294297
else:

0 commit comments

Comments
 (0)