Skip to content

Commit 921036d

Browse files
authored
Merge pull request #1240 from kohya-ss/verify-command-line-args
verify command line args if wandb is enabled
2 parents b748b48 + cd587ce commit 921036d

12 files changed

+66
-1
lines changed

Diff for: fine_tune.py

+1
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser:
520520
parser = setup_parser()
521521

522522
args = parser.parse_args()
523+
train_util.verify_command_line_training_args(args)
523524
args = train_util.read_config_from_file(args, parser)
524525

525526
train(args)

Diff for: library/train_util.py

+55-1
Original file line numberDiff line numberDiff line change
@@ -1890,7 +1890,7 @@ def __init__(
18901890
subset.image_dir,
18911891
False,
18921892
None,
1893-
subset.caption_extension,
1893+
subset.caption_extension,
18941894
subset.cache_info,
18951895
subset.num_repeats,
18961896
subset.shuffle_caption,
@@ -3358,6 +3358,60 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
33583358
)
33593359

33603360

3361+
# verify command line args for training
3362+
def verify_command_line_training_args(args: argparse.Namespace):
3363+
# if wandb is enabled, the command line is exposed to the public
3364+
# check whether sensitive options are included in the command line arguments
3365+
# if so, warn or inform the user to move them to the configuration file
3366+
# wandbが有効な場合、コマンドラインが公開される
3367+
# 学習用のコマンドライン引数に敏感なオプションが含まれているかどうかを確認し、
3368+
# 含まれている場合は設定ファイルに移動するようにユーザーに警告または通知する
3369+
3370+
wandb_enabled = args.log_with is not None and args.log_with != "tensorboard" # "all" or "wandb"
3371+
if not wandb_enabled:
3372+
return
3373+
3374+
sensitive_args = ["wandb_api_key", "huggingface_token"]
3375+
sensitive_path_args = [
3376+
"pretrained_model_name_or_path",
3377+
"vae",
3378+
"tokenizer_cache_dir",
3379+
"train_data_dir",
3380+
"conditioning_data_dir",
3381+
"reg_data_dir",
3382+
"output_dir",
3383+
"logging_dir",
3384+
]
3385+
3386+
for arg in sensitive_args:
3387+
if getattr(args, arg, None) is not None:
3388+
logger.warning(
3389+
f"wandb is enabled, but option `{arg}` is included in the command line. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
3390+
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれています。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
3391+
)
3392+
3393+
# if path is absolute, it may include sensitive information
3394+
for arg in sensitive_path_args:
3395+
if getattr(args, arg, None) is not None and os.path.isabs(getattr(args, arg)):
3396+
logger.info(
3397+
f"wandb is enabled, but option `{arg}` is included in the command line and it is an absolute path. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file or use relative path."
3398+
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれており、絶対パスです。コマンドラインは公開されるため、`.toml`ファイルに移動するか、相対パスを使用することをお勧めします。"
3399+
)
3400+
3401+
if getattr(args, "config_file", None) is not None:
3402+
logger.info(
3403+
f"wandb is enabled, but option `config_file` is included in the command line. Because the command line is exposed to the public, please be careful about the information included in the path."
3404+
+ f" / wandbが有効で、かつオプション `config_file` がコマンドラインに含まれています。コマンドラインは公開されるため、パスに含まれる情報にご注意ください。"
3405+
)
3406+
3407+
# other sensitive options
3408+
if args.huggingface_repo_id is not None and args.huggingface_repo_visibility != "public":
3409+
logger.info(
3410+
f"wandb is enabled, but option huggingface_repo_id is included in the command line and huggingface_repo_visibility is not 'public'. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
3411+
+ f" / wandbが有効で、かつオプション huggingface_repo_id がコマンドラインに含まれており、huggingface_repo_visibility が 'public' ではありません。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
3412+
)
3413+
3414+
33613415
def verify_training_args(args: argparse.Namespace):
33623416
r"""
33633417
Verify training arguments. Also reflect highvram option to global variable

Diff for: sdxl_train.py

+1
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,7 @@ def setup_parser() -> argparse.ArgumentParser:
812812
parser = setup_parser()
813813

814814
args = parser.parse_args()
815+
train_util.verify_command_line_training_args(args)
815816
args = train_util.read_config_from_file(args, parser)
816817

817818
train(args)

Diff for: sdxl_train_control_net_lllite.py

+1
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ def setup_parser() -> argparse.ArgumentParser:
612612
parser = setup_parser()
613613

614614
args = parser.parse_args()
615+
train_util.verify_command_line_training_args(args)
615616
args = train_util.read_config_from_file(args, parser)
616617

617618
train(args)

Diff for: sdxl_train_control_net_lllite_old.py

+1
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,7 @@ def setup_parser() -> argparse.ArgumentParser:
580580
parser = setup_parser()
581581

582582
args = parser.parse_args()
583+
train_util.verify_command_line_training_args(args)
583584
args = train_util.read_config_from_file(args, parser)
584585

585586
train(args)

Diff for: sdxl_train_network.py

+1
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def setup_parser() -> argparse.ArgumentParser:
178178
parser = setup_parser()
179179

180180
args = parser.parse_args()
181+
train_util.verify_command_line_training_args(args)
181182
args = train_util.read_config_from_file(args, parser)
182183

183184
trainer = SdxlNetworkTrainer()

Diff for: sdxl_train_textual_inversion.py

+1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def setup_parser() -> argparse.ArgumentParser:
131131
parser = setup_parser()
132132

133133
args = parser.parse_args()
134+
train_util.verify_command_line_training_args(args)
134135
args = train_util.read_config_from_file(args, parser)
135136

136137
trainer = SdxlTextualInversionTrainer()

Diff for: train_controlnet.py

+1
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ def setup_parser() -> argparse.ArgumentParser:
617617
parser = setup_parser()
618618

619619
args = parser.parse_args()
620+
train_util.verify_command_line_training_args(args)
620621
args = train_util.read_config_from_file(args, parser)
621622

622623
train(args)

Diff for: train_db.py

+1
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,7 @@ def setup_parser() -> argparse.ArgumentParser:
523523
parser = setup_parser()
524524

525525
args = parser.parse_args()
526+
train_util.verify_command_line_training_args(args)
526527
args = train_util.read_config_from_file(args, parser)
527528

528529
train(args)

Diff for: train_network.py

+1
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,7 @@ def setup_parser() -> argparse.ArgumentParser:
11011101
parser = setup_parser()
11021102

11031103
args = parser.parse_args()
1104+
train_util.verify_command_line_training_args(args)
11041105
args = train_util.read_config_from_file(args, parser)
11051106

11061107
trainer = NetworkTrainer()

Diff for: train_textual_inversion.py

+1
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,7 @@ def setup_parser() -> argparse.ArgumentParser:
806806
parser = setup_parser()
807807

808808
args = parser.parse_args()
809+
train_util.verify_command_line_training_args(args)
809810
args = train_util.read_config_from_file(args, parser)
810811

811812
trainer = TextualInversionTrainer()

Diff for: train_textual_inversion_XTI.py

+1
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,7 @@ def setup_parser() -> argparse.ArgumentParser:
714714
parser = setup_parser()
715715

716716
args = parser.parse_args()
717+
train_util.verify_command_line_training_args(args)
717718
args = train_util.read_config_from_file(args, parser)
718719

719720
train(args)

0 commit comments

Comments
 (0)