Commit c68baae

add --log_config option to enable/disable output training config

1 parent 47187f7

11 files changed: +42 -16 lines

Diff for: README.md (+6)

@@ -165,6 +165,9 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
   - Specify the learning rate and dim (rank) for each block.
   - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).
 
+- Training scripts can now output training settings to wandb or TensorBoard logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
+  - Some settings, such as API keys and directory specifications, are not output due to security issues.
+
 - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
   - It seems that the model file loading is faster in the WSL environment etc.
   - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
@@ -209,6 +212,9 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
   - The learning rate and dim (rank) can be specified for each block.
   - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details.
 
+- Each training script can now output its training settings to logs such as wandb and TensorBoard. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
+  - Some settings, such as API keys and directory specifications, are not output due to security concerns.
+
 - An option `--disable_mmap_load_safetensors` has been added to disable memory mapping when loading a model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
   - Model file loading seems to be faster in WSL environments, etc.
   - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
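
A quick sketch of what the new option does at the argparse level. It mirrors the `parser.add_argument` line added in library/train_util.py below; the bare parser here is toy scaffolding, not the scripts' real one:

    # --log_config is a plain store_true flag: off by default, opt-in per run.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")

    print(parser.parse_args([]).log_config)                 # False -> settings are not logged
    print(parser.parse_args(["--log_config"]).log_config)   # True  -> settings are sent to the trackers

In a real run the flag is simply appended to an existing training command; when it is absent, the scripts behave exactly as before.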

Diff for: fine_tune.py (+15 -5)

@@ -310,7 +310,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers(
+            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
+            config=train_util.get_sanitized_config_or_none(args),
+            init_kwargs=init_kwargs,
+        )
 
     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

@@ -354,7 +358,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 # Predict the noise residual
                 with accelerator.autocast():

@@ -368,7 +374,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                 if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     loss = loss.mean([1, 2, 3])
 
                     if args.min_snr_gamma:

@@ -380,7 +388,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:

@@ -471,7 +481,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     accelerator.end_training()
 
-    if is_main_process and (args.save_state or args.save_state_on_train_end):
+    if is_main_process and (args.save_state or args.save_state_on_train_end):
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # delete this because memory is used after this point
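
The tracker wiring above recurs in every script below, changing only the default run name. A self-contained sketch of the pattern, run from the sd-scripts repository root; the `Accelerator` construction and the toy `args` namespace are assumptions for illustration, while the `init_trackers` call shape follows the diff above:

    # Sketch of the tracker-initialization pattern with a toy args namespace.
    import argparse

    import toml
    from accelerate import Accelerator
    from library import train_util  # sd-scripts' shared utilities

    args = argparse.Namespace(
        log_config=True,          # set by the new --log_config flag
        log_tracker_name=None,    # None -> fall back to the script default ("finetuning")
        log_tracker_config=None,  # optional TOML file of tracker init kwargs
        wandb_run_name=None,
    )

    accelerator = Accelerator(log_with="tensorboard", project_dir="logs")  # hypothetical setup

    init_kwargs = {}
    if args.wandb_run_name:
        init_kwargs["wandb"] = {"name": args.wandb_run_name}
    if args.log_tracker_config is not None:
        init_kwargs = toml.load(args.log_tracker_config)

    accelerator.init_trackers(
        "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
        config=train_util.get_sanitized_config_or_none(args),  # None unless --log_config is set
        init_kwargs=init_kwargs,
    )

Because `get_sanitized_config_or_none` returns None when `--log_config` is off, the trackers receive no config in that case and the previous behavior is preserved.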

Diff for: library/train_util.py (+13 -3)

@@ -3180,6 +3180,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
     )
+    parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する")
 
     parser.add_argument(
         "--noise_offset",

@@ -3388,7 +3389,15 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
         help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
     )
 
-def filter_sensitive_args(args: argparse.Namespace):
+
+def get_sanitized_config_or_none(args: argparse.Namespace):
+    # if `--log_config` is enabled, return args for logging; if not, return None
+    # when `--log_config` is enabled, filter sensitive values out of args
+    # if wandb is not enabled the log is not exposed to the public, but it is safer to filter out sensitive values anyway
+
+    if not args.log_config:
+        return None
+
     sensitive_args = ["wandb_api_key", "huggingface_token"]
     sensitive_path_args = [
         "pretrained_model_name_or_path",

@@ -3402,9 +3411,9 @@ def filter_sensitive_args(args: argparse.Namespace):
     ]
     filtered_args = {}
     for k, v in vars(args).items():
-        # filter out sensitive values
+        # filter out sensitive values and convert to string if necessary
        if k not in sensitive_args + sensitive_path_args:
-            #Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`.
+            # Accelerate values need to have type `bool`, `str`, `float`, `int`, or `None`.
             if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int):
                 filtered_args[k] = v
             # accelerate does not support lists

@@ -3416,6 +3425,7 @@ def filter_sensitive_args(args: argparse.Namespace):
 
     return filtered_args
 
+
 # verify command line args for training
 def verify_command_line_training_args(args: argparse.Namespace):
     # if wandb is enabled, the command line is exposed to the public
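
To make the filtering concrete, a small usage sketch, again run from the repository root; the `Namespace` values are hypothetical, and the dropped keys come from the `sensitive_args` and `sensitive_path_args` lists above:

    # What get_sanitized_config_or_none keeps and drops, on a toy namespace.
    import argparse

    from library import train_util

    args = argparse.Namespace(
        log_config=True,
        learning_rate=1e-4,      # plain float: kept
        max_train_steps=1000,    # plain int: kept
        wandb_api_key="secret",  # listed in sensitive_args: dropped
        pretrained_model_name_or_path="/models/sd.safetensors",  # path arg: dropped
    )

    print(train_util.get_sanitized_config_or_none(args))
    # {'log_config': True, 'learning_rate': 0.0001, 'max_train_steps': 1000}

    args.log_config = False
    print(train_util.get_sanitized_config_or_none(args))  # None -> nothing is logged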

Diff for: sdxl_train.py (+1 -1)

@@ -589,7 +589,7 @@ def optimizer_hook(parameter: torch.Tensor):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
 
     # For --sample_at_first
     sdxl_train_util.sample_images(

Diff for: sdxl_train_control_net_lllite.py (+1 -1)

@@ -354,7 +354,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()

Diff for: sdxl_train_control_net_lllite_old.py (+1 -1)

@@ -324,7 +324,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()

Diff for: train_controlnet.py (+1 -1)

@@ -344,7 +344,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()

Diff for: train_db.py (+1 -1)

@@ -290,7 +290,7 @@ def train(args):
             init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
-        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
+        accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
 
     # For --sample_at_first
     train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

Diff for: train_network.py (+1 -1)

@@ -774,7 +774,7 @@ def load_model_hook(models, input_dir):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     loss_recorder = train_util.LossRecorder()

Diff for: train_textual_inversion.py (+1 -1)

@@ -510,7 +510,7 @@ def train(self, args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     # function for saving/removing

Diff for: train_textual_inversion_XTI.py (+1 -1)

@@ -407,7 +407,7 @@ def train(args):
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers(
-            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
+            "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
         )
 
     # function for saving/removing
