@@ -53,7 +53,15 @@ def __init__(self):
 
     # TODO: unify this with other scripts
     def generate_step_logs(
-        self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
+        self,
+        args: argparse.Namespace,
+        current_loss,
+        avr_loss,
+        lr_scheduler,
+        lr_descriptions,
+        keys_scaled=None,
+        mean_norm=None,
+        maximum_norm=None,
     ):
         logs = {"loss/current": current_loss, "loss/average": avr_loss}
 
@@ -63,68 +71,25 @@ def generate_step_logs(
             logs["max_norm/max_key_norm"] = maximum_norm
 
         lrs = lr_scheduler.get_last_lr()
-
-        if len(lrs) > 4:
-            idx = 0
-            if not args.network_train_unet_only:
-                logs["lr/textencoder"] = float(lrs[0])
-                idx = 1
-
-            for i in range(idx, len(lrs)):
-                lora_plus = ""
-                group_id = i
-
-                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-                    lora_plus = '_lora+' if i % 2 == 1 else ''
-                    group_id = int((i / 2) + (i % 2 + 0.5))
-
-                logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i])
-                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
-                    logs[f"lr/d*lr/group{group_id}{lora_plus}"] = (
-                        lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
-                    )
-
-        else:
-            if args.network_train_text_encoder_only:
-                if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None:
-                    logs["lr/textencoder"] = float(lrs[0])
-                    logs["lr/textencoder_lora+"] = float(lrs[1])
-                else:
-                    logs["lr/textencoder"] = float(lrs[0])
-
-            elif args.network_train_unet_only:
-                if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-                    logs["lr/unet"] = float(lrs[0])
-                    logs["lr/unet_lora+"] = float(lrs[1])
-                else:
-                    logs["lr/unet"] = float(lrs[0])
+        for i, lr in enumerate(lrs):
+            if lr_descriptions is not None:
+                lr_desc = lr_descriptions[i]
             else:
-                if len(lrs) == 2:
-                    if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None:
-                        logs["lr/textencoder"] = float(lrs[0])
-                        logs["lr/textencoder_lora+"] = float(lrs[1])
-                    elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None:
-                        logs["lr/unet"] = float(lrs[0])
-                        logs["lr/unet_lora+"] = float(lrs[1])
-                    elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None:
-                        logs["lr/all"] = float(lrs[0])
-                        logs["lr/all_lora+"] = float(lrs[1])
-                    else:
-                        logs["lr/textencoder"] = float(lrs[0])
-                        logs["lr/unet"] = float(lrs[-1])
-                elif len(lrs) == 4:
-                    logs["lr/textencoder"] = float(lrs[0])
-                    logs["lr/textencoder_lora+"] = float(lrs[1])
-                    logs["lr/unet"] = float(lrs[2])
-                    logs["lr/unet_lora+"] = float(lrs[3])
+                idx = i - (0 if args.network_train_unet_only else -1)
+                if idx == -1:
+                    lr_desc = "textencoder"
                 else:
-                    logs["lr/all"] = float(lrs[0])
+                    if len(lrs) > 2:
+                        lr_desc = f"group{idx}"
+                    else:
+                        lr_desc = "unet"
+
+            logs[f"lr/{lr_desc}"] = lr
 
-        if (
-            args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
-        ):  # tracking d*lr value of unet.
-            logs["lr/d*lr"] = (
-                lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
+            if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
+                # tracking d*lr value
+                logs[f"lr/d*lr/{lr_desc}"] = (
+                    lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
            )
 
        return logs
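
For context (not part of the commit): when the network supplies lr_descriptions, the loop above simply turns each optimizer param group's learning rate into an lr/<description> log entry. A minimal standalone sketch with made-up values and labels:

# Hypothetical illustration; the learning rates and group labels below are invented, not taken from the commit.
lrs = [1e-4, 5e-5, 4e-4]                                   # one learning rate per optimizer param group
lr_descriptions = ["textencoder", "unet", "unet_lora+"]    # labels a network module might report

logs = {}
for i, lr in enumerate(lrs):
    logs[f"lr/{lr_descriptions[i]}"] = lr

print(logs)  # {'lr/textencoder': 0.0001, 'lr/unet': 5e-05, 'lr/unet_lora+': 0.0004}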
@@ -358,6 +323,7 @@ def train(self, args):
         network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
 
         if args.network_weights is not None:
+            # FIXME consider alpha of weights
             info = network.load_weights(args.network_weights)
             accelerator.print(f"load network weights from {args.network_weights}: {info}")
 
@@ -373,20 +339,23 @@ def train(self, args):
 
         # maintain backward compatibility
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio)
+            results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            if type(results) is tuple:
+                trainable_params = results[0]
+                lr_descriptions = results[1]
+            else:
+                trainable_params = results
+                lr_descriptions = None
         except TypeError:
-            accelerator.print(
-                "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
-            )
+            # accelerator.print(
+            #     "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
+            # )
             trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
+            lr_descriptions = None
+        print(lr_descriptions)
 
         optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
-        if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None:
-            assert (
-                (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name)
-            ), "LoRA+ and Prodigy/DAdaptation is not supported"
-
         # prepare dataloader
         # note: persistent_workers cannot be used when the number of DataLoader processes is 0
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
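
The try/except above keeps older network modules working: prepare_optimizer_params may return either just the param-group list (legacy signature) or a (param_groups, lr_descriptions) tuple. A hypothetical caller-side helper that mirrors the same unpacking, with invented names and sample data, not part of the commit:

# Hypothetical sketch; mirrors the unpacking logic in train() above.
def unpack_optimizer_params(results):
    # accept either a param-group list or a (param_groups, lr_descriptions) tuple
    if isinstance(results, tuple):
        trainable_params, lr_descriptions = results
    else:
        trainable_params, lr_descriptions = results, None
    return trainable_params, lr_descriptions

# legacy networks return only the list, so descriptions default to None
print(unpack_optimizer_params([{"params": [], "lr": 1e-4}]))
# newer networks return a tuple with one label per param group
print(unpack_optimizer_params(([{"params": [], "lr": 1e-4}], ["unet"])))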
@@ -992,7 +961,9 @@ def remove_model(old_ckpt_name):
                 progress_bar.set_postfix(**{**max_mean_logs, **logs})
 
                 if args.logging_dir is not None:
-                    logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
+                    logs = self.generate_step_logs(
+                        args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
+                    )
                     accelerator.log(logs, step=global_step)
 
                 if global_step >= args.max_train_steps:
@@ -1143,6 +1114,9 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
+    # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
     return parser
 