@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
122
122
123
123
for _ in range (args .num_validation_images ):
124
124
with autocast_ctx :
125
- # need to fix in pipeline_flux_controlnet
126
125
image = pipeline (
127
126
prompt = validation_prompt ,
128
127
control_image = validation_image ,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
159
158
images = log ["images" ]
160
159
validation_prompt = log ["validation_prompt" ]
161
160
validation_image = log ["validation_image" ]
162
- formatted_images .append (wandb .Image (validation_image , caption = "Controlnet conditioning " ))
161
+ formatted_images .append (wandb .Image (validation_image , caption = "Conditioning " ))
163
162
for image in images :
164
163
image = wandb .Image (image , caption = validation_prompt )
165
164
formatted_images .append (image )
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
188
187
img_str += f"\n "
189
188
190
189
model_description = f"""
191
- # control-lora -{ repo_id }
190
+ # flux-control -{ repo_id }
192
191
193
192
These are Control weights trained on { base_model } with new type of conditioning.
194
193
{ img_str }
@@ -434,14 +433,15 @@ def parse_args(input_args=None):
434
433
"--conditioning_image_column" ,
435
434
type = str ,
436
435
default = "conditioning_image" ,
437
- help = "The column of the dataset containing the controlnet conditioning image." ,
436
+ help = "The column of the dataset containing the control conditioning image." ,
438
437
)
439
438
parser .add_argument (
440
439
"--caption_column" ,
441
440
type = str ,
442
441
default = "text" ,
443
442
help = "The column of the dataset containing a caption or a list of captions." ,
444
443
)
444
+ parser .add_argument ("--log_dataset_samples" , action = "store_true" , help = "Whether to log some dataset samples." )
445
445
parser .add_argument (
446
446
"--max_train_samples" ,
447
447
type = int ,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
468
468
default = None ,
469
469
nargs = "+" ,
470
470
help = (
471
- "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
471
+ "A set of paths to the control conditioning images to be evaluated every `--validation_steps`"
472
472
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
473
473
" single `--validation_prompt` to be used with all `--validation_image`s, or a single"
474
474
" `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
505
505
default = None ,
506
506
help = "Path to the jsonl file containing the training data." ,
507
507
)
508
-
508
+ parser .add_argument (
509
+ "--only_target_transformer_blocks" ,
510
+ action = "store_true" ,
511
+ help = "If we should only target the transformer blocks to train along with the input layer (`x_embedder`)." ,
512
+ )
509
513
parser .add_argument (
510
514
"--guidance_scale" ,
511
515
type = float ,
@@ -581,7 +585,7 @@ def parse_args(input_args=None):
581
585
582
586
if args .resolution % 8 != 0 :
583
587
raise ValueError (
584
- "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder ."
588
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer ."
585
589
)
586
590
587
591
return args
@@ -665,7 +669,12 @@ def preprocess_train(examples):
665
669
conditioning_images = [image_transforms (image ) for image in conditioning_images ]
666
670
examples ["pixel_values" ] = images
667
671
examples ["conditioning_pixel_values" ] = conditioning_images
668
- examples ["captions" ] = list (examples [args .caption_column ])
672
+
673
+ is_caption_list = isinstance (examples [args .caption_column ][0 ], list )
674
+ if is_caption_list :
675
+ examples ["captions" ] = [max (example , key = len ) for example in examples [args .caption_column ]]
676
+ else :
677
+ examples ["captions" ] = list (examples [args .caption_column ])
669
678
670
679
return examples
671
680
@@ -765,7 +774,8 @@ def main(args):
765
774
subfolder = "scheduler" ,
766
775
)
767
776
noise_scheduler_copy = copy .deepcopy (noise_scheduler )
768
- flux_transformer .requires_grad_ (True )
777
+ if not args .only_target_transformer_blocks :
778
+ flux_transformer .requires_grad_ (True )
769
779
vae .requires_grad_ (False )
770
780
771
781
# cast down and move to the CPU
@@ -797,6 +807,12 @@ def main(args):
797
807
assert torch .all (flux_transformer .x_embedder .weight [:, initial_input_channels :].data == 0 )
798
808
flux_transformer .register_to_config (in_channels = initial_input_channels * 2 , out_channels = initial_input_channels )
799
809
810
+ if args .only_target_transformer_blocks :
811
+ flux_transformer .x_embedder .requires_grad_ (True )
812
+ for name , module in flux_transformer .named_modules ():
813
+ if "transformer_blocks" in name :
814
+ module .requires_grad_ (True )
815
+
800
816
def unwrap_model (model ):
801
817
model = accelerator .unwrap_model (model )
802
818
model = model ._orig_mod if is_compiled_module (model ) else model
@@ -974,6 +990,32 @@ def load_model_hook(models, input_dir):
974
990
else :
975
991
initial_global_step = 0
976
992
993
+ if accelerator .is_main_process and args .report_to == "wandb" and args .log_dataset_samples :
994
+ logger .info ("Logging some dataset samples." )
995
+ formatted_images = []
996
+ formatted_control_images = []
997
+ all_prompts = []
998
+ for i , batch in enumerate (train_dataloader ):
999
+ images = (batch ["pixel_values" ] + 1 ) / 2
1000
+ control_images = (batch ["conditioning_pixel_values" ] + 1 ) / 2
1001
+ prompts = batch ["captions" ]
1002
+
1003
+ if len (formatted_images ) > 10 :
1004
+ break
1005
+
1006
+ for img , control_img , prompt in zip (images , control_images , prompts ):
1007
+ formatted_images .append (img )
1008
+ formatted_control_images .append (control_img )
1009
+ all_prompts .append (prompt )
1010
+
1011
+ logged_artifacts = []
1012
+ for img , control_img , prompt in zip (formatted_images , formatted_control_images , all_prompts ):
1013
+ logged_artifacts .append (wandb .Image (control_img , caption = "Conditioning" ))
1014
+ logged_artifacts .append (wandb .Image (img , caption = prompt ))
1015
+
1016
+ wandb_tracker = [tracker for tracker in accelerator .trackers if tracker .name == "wandb" ]
1017
+ wandb_tracker [0 ].log ({"dataset_samples" : logged_artifacts })
1018
+
977
1019
progress_bar = tqdm (
978
1020
range (0 , args .max_train_steps ),
979
1021
initial = initial_global_step ,
0 commit comments