Skip to content

Commit 2c9db5d

Browse files
committed
passing filtered hyperparameters to accelerate
1 parent 71e2c91 commit 2c9db5d

10 files changed

+23
-9
lines changed

fine_tune.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
310310
init_kwargs["wandb"] = {"name": args.wandb_run_name}
311311
if args.log_tracker_config is not None:
312312
init_kwargs = toml.load(args.log_tracker_config)
313-
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
313+
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
314314

315315
# For --sample_at_first
316316
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

library/train_util.py

+14
Original file line numberDiff line numberDiff line change
@@ -3378,6 +3378,20 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser):
33783378
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
33793379
)
33803380

3381+
def filter_sensitive_args(args: argparse.Namespace):
    """Return a dict of training arguments with sensitive entries removed.

    Used to build the ``config`` payload passed to
    ``accelerator.init_trackers`` so that credentials and local filesystem
    paths are not uploaded to experiment trackers such as wandb.

    Args:
        args: Parsed command-line arguments for a training run.

    Returns:
        dict: ``vars(args)`` minus credential and path-like entries.
    """
    # Credentials that must never leave the local machine.
    sensitive_args = ["wandb_api_key", "huggingface_token"]
    # Local paths that may reveal usernames / directory layout of the host.
    sensitive_path_args = [
        "pretrained_model_name_or_path",
        "vae",
        "tokenizer_cache_dir",
        "train_data_dir",
        "conditioning_data_dir",
        "reg_data_dir",
        "output_dir",
        "logging_dir",
    ]
    # Build the exclusion set once: avoids re-concatenating the two lists
    # and doing an O(n) list scan for every key in the comprehension below.
    excluded = set(sensitive_args) | set(sensitive_path_args)
    filtered_args = {k: v for k, v in vars(args).items() if k not in excluded}
    return filtered_args
33813395

33823396
# verify command line args for training
33833397
def verify_command_line_training_args(args: argparse.Namespace):

sdxl_train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
487487
init_kwargs["wandb"] = {"name": args.wandb_run_name}
488488
if args.log_tracker_config is not None:
489489
init_kwargs = toml.load(args.log_tracker_config)
490-
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
490+
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
491491

492492
# For --sample_at_first
493493
sdxl_train_util.sample_images(

sdxl_train_control_net_lllite.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def train(args):
353353
if args.log_tracker_config is not None:
354354
init_kwargs = toml.load(args.log_tracker_config)
355355
accelerator.init_trackers(
356-
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
356+
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
357357
)
358358

359359
loss_recorder = train_util.LossRecorder()

sdxl_train_control_net_lllite_old.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def train(args):
324324
if args.log_tracker_config is not None:
325325
init_kwargs = toml.load(args.log_tracker_config)
326326
accelerator.init_trackers(
327-
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
327+
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
328328
)
329329

330330
loss_recorder = train_util.LossRecorder()

train_controlnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def train(args):
344344
if args.log_tracker_config is not None:
345345
init_kwargs = toml.load(args.log_tracker_config)
346346
accelerator.init_trackers(
347-
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
347+
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
348348
)
349349

350350
loss_recorder = train_util.LossRecorder()

train_db.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def train(args):
290290
init_kwargs["wandb"] = {"name": args.wandb_run_name}
291291
if args.log_tracker_config is not None:
292292
init_kwargs = toml.load(args.log_tracker_config)
293-
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
293+
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs)
294294

295295
# For --sample_at_first
296296
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

train_network.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ def load_model_hook(models, input_dir):
753753
if args.log_tracker_config is not None:
754754
init_kwargs = toml.load(args.log_tracker_config)
755755
accelerator.init_trackers(
756-
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
756+
"network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
757757
)
758758

759759
loss_recorder = train_util.LossRecorder()

train_textual_inversion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,7 @@ def train(self, args):
510510
if args.log_tracker_config is not None:
511511
init_kwargs = toml.load(args.log_tracker_config)
512512
accelerator.init_trackers(
513-
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
513+
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
514514
)
515515

516516
# function for saving/removing

train_textual_inversion_XTI.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def train(args):
407407
if args.log_tracker_config is not None:
408408
init_kwargs = toml.load(args.log_tracker_config)
409409
accelerator.init_trackers(
410-
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
410+
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs
411411
)
412412

413413
# function for saving/removing

0 commit comments

Comments (0)