
Commit b94cfd7

[Training] QoL improvements in the Flux Control training scripts (#10461)
* QoL improvements to the Flux script.
* Propagate the dataloader changes.
1 parent 661bde0 commit b94cfd7

3 files changed: +93 −20

Diff for: examples/flux-control/README.md (+3 −3)
@@ -121,7 +121,7 @@ prompt = "A couple, 4k photo, highly detailed"

gen_images = pipe(
    prompt=prompt,
-   condition_image=image,
+   control_image=image,
    num_inference_steps=50,
    joint_attention_kwargs={"scale": 0.9},
    guidance_scale=25.,
@@ -190,7 +190,7 @@ prompt = "A couple, 4k photo, highly detailed"

gen_images = pipe(
    prompt=prompt,
-   condition_image=image,
+   control_image=image,
    num_inference_steps=50,
    guidance_scale=25.,
).images[0]
@@ -200,5 +200,5 @@ gen_images.save("output.png")
## Things to note

* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
-* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
+* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used if `--offload` is specified.
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides similar functionality.
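
To make that last point concrete, here is a minimal sketch of such an extraction (not part of this commit; `rank` and the toy shapes are arbitrary choices). It low-rank-factorizes the difference between a fine-tuned weight matrix and its base counterpart with a truncated SVD, which is roughly the idea behind the linked script:

# Hypothetical sketch: extract a rank-r LoRA from a fully fine-tuned weight matrix.
# `base_weight` / `tuned_weight` stand in for matching weights from the base and
# fine-tuned Flux transformers; the rank of 64 is an arbitrary choice.
import torch

def extract_lora(base_weight, tuned_weight, rank=64):
    delta = (tuned_weight - base_weight).float()  # weight difference to factorize
    U, S, V = torch.svd_lowrank(delta, q=rank)    # truncated SVD of the delta
    lora_up = U * S.sqrt()                        # (out_features, rank)
    lora_down = (V * S.sqrt()).T                  # (rank, in_features)
    return lora_up, lora_down

# Quick check on toy weights: the product should approximate the delta.
base = torch.randn(512, 512)
tuned = base + 0.01 * torch.randn(512, 512)
up, down = extract_lora(base, tuned)
print(((tuned - base) - up @ down).norm() / (tuned - base).norm())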

Diff for: examples/flux-control/train_control_flux.py (+51 −9)
@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f

    for _ in range(args.num_validation_images):
        with autocast_ctx:
-           # need to fix in pipeline_flux_controlnet
            image = pipeline(
                prompt=validation_prompt,
                control_image=validation_image,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
            images = log["images"]
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]
-           formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+           formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
            for image in images:
                image = wandb.Image(image, caption=validation_prompt)
                formatted_images.append(image)
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
        img_str += f"![images_{i})](./images_{i}.png)\n"

    model_description = f"""
-# control-lora-{repo_id}
+# flux-control-{repo_id}

These are Control weights trained on {base_model} with new type of conditioning.
{img_str}
@@ -434,14 +433,15 @@ def parse_args(input_args=None):
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
-       help="The column of the dataset containing the controlnet conditioning image.",
+       help="The column of the dataset containing the control conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
+   parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
    parser.add_argument(
        "--max_train_samples",
        type=int,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
        default=None,
        nargs="+",
        help=(
-           "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
+           "A set of paths to the control conditioning images to be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
        default=None,
        help="Path to the jsonl file containing the training data.",
    )
-
+   parser.add_argument(
+       "--only_target_transformer_blocks",
+       action="store_true",
+       help="Whether to only target the transformer blocks for training, along with the input layer (`x_embedder`).",
+   )
    parser.add_argument(
        "--guidance_scale",
        type=float,
@@ -581,7 +585,7 @@ def parse_args(input_args=None):

    if args.resolution % 8 != 0:
        raise ValueError(
-           "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+           "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
        )

    return args
@@ -665,7 +669,12 @@ def preprocess_train(examples):
    conditioning_images = [image_transforms(image) for image in conditioning_images]
    examples["pixel_values"] = images
    examples["conditioning_pixel_values"] = conditioning_images
-   examples["captions"] = list(examples[args.caption_column])
+
+   is_caption_list = isinstance(examples[args.caption_column][0], list)
+   if is_caption_list:
+       examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+   else:
+       examples["captions"] = list(examples[args.caption_column])

    return examples
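
In other words, when the caption column holds a list of captions per example, the longest one is now kept. A standalone illustration of the selection logic, with made-up captions:

# Made-up data illustrating the new caption handling: lists keep their longest entry.
captions = [["a cat", "a 4k, highly detailed photo of a cat"], ["a dog", "a dog on grass"]]
print([max(example, key=len) for example in captions])
# ['a 4k, highly detailed photo of a cat', 'a dog on grass']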

@@ -765,7 +774,8 @@ def main(args):
        subfolder="scheduler",
    )
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
-   flux_transformer.requires_grad_(True)
+   if not args.only_target_transformer_blocks:
+       flux_transformer.requires_grad_(True)
    vae.requires_grad_(False)

    # cast down and move to the CPU
@@ -797,6 +807,12 @@
    assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

+   if args.only_target_transformer_blocks:
+       flux_transformer.x_embedder.requires_grad_(True)
+       for name, module in flux_transformer.named_modules():
+           if "transformer_blocks" in name:
+               module.requires_grad_(True)
+
    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
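
With `--only_target_transformer_blocks`, everything except `x_embedder` and the transformer blocks stays frozen. A quick sanity check one could run after this point (a sketch, assuming `flux_transformer` is set up as above):

# Sketch: count how many parameters will actually receive gradients.
trainable = sum(p.numel() for p in flux_transformer.parameters() if p.requires_grad)
total = sum(p.numel() for p in flux_transformer.parameters())
print(f"trainable: {trainable / 1e6:.1f}M of {total / 1e6:.1f}M parameters")
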
@@ -974,6 +990,32 @@ def load_model_hook(models, input_dir):
    else:
        initial_global_step = 0

+   if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+       logger.info("Logging some dataset samples.")
+       formatted_images = []
+       formatted_control_images = []
+       all_prompts = []
+       for i, batch in enumerate(train_dataloader):
+           images = (batch["pixel_values"] + 1) / 2
+           control_images = (batch["conditioning_pixel_values"] + 1) / 2
+           prompts = batch["captions"]
+
+           if len(formatted_images) > 10:
+               break
+
+           for img, control_img, prompt in zip(images, control_images, prompts):
+               formatted_images.append(img)
+               formatted_control_images.append(control_img)
+               all_prompts.append(prompt)
+
+       logged_artifacts = []
+       for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+           logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+           logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+       wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+       wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
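
Note the `(x + 1) / 2` in the added block: the training transforms normalize images to `[-1, 1]`, so they are mapped back to `[0, 1]` before being wrapped in `wandb.Image`. A standalone illustration:

# The training transforms normalize to [-1, 1]; logging expects [0, 1].
import torch
x = torch.tensor([-1.0, 0.0, 1.0])
print((x + 1) / 2)  # tensor([0.0000, 0.5000, 1.0000])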

Diff for: examples/flux-control/train_control_lora_flux.py (+39 −8)
@@ -132,7 +132,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f

    for _ in range(args.num_validation_images):
        with autocast_ctx:
-           # need to fix in pipeline_flux_controlnet
            image = pipeline(
                prompt=validation_prompt,
                control_image=validation_image,
@@ -169,7 +168,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
            images = log["images"]
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]
-           formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+           formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
            for image in images:
                image = wandb.Image(image, caption=validation_prompt)
                formatted_images.append(image)
@@ -198,7 +197,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
        img_str += f"![images_{i})](./images_{i}.png)\n"

    model_description = f"""
-# controlnet-lora-{repo_id}
+# control-lora-{repo_id}

These are Control LoRA weights trained on {base_model} with new type of conditioning.
{img_str}
@@ -256,7 +255,7 @@ def parse_args(input_args=None):
    parser.add_argument(
        "--output_dir",
        type=str,
-       default="controlnet-lora",
+       default="control-lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
@@ -466,14 +465,15 @@ def parse_args(input_args=None):
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
-       help="The column of the dataset containing the controlnet conditioning image.",
+       help="The column of the dataset containing the control conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
+   parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log some dataset samples.")
    parser.add_argument(
        "--max_train_samples",
        type=int,
@@ -500,7 +500,7 @@ def parse_args(input_args=None):
        default=None,
        nargs="+",
        help=(
-           "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
+           "A set of paths to the control conditioning images to be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
            " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
@@ -613,7 +613,7 @@ def parse_args(input_args=None):

    if args.resolution % 8 != 0:
        raise ValueError(
-           "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+           "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
        )

    return args
@@ -697,7 +697,12 @@ def preprocess_train(examples):
    conditioning_images = [image_transforms(image) for image in conditioning_images]
    examples["pixel_values"] = images
    examples["conditioning_pixel_values"] = conditioning_images
-   examples["captions"] = list(examples[args.caption_column])
+
+   is_caption_list = isinstance(examples[args.caption_column][0], list)
+   if is_caption_list:
+       examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+   else:
+       examples["captions"] = list(examples[args.caption_column])

    return examples

@@ -1132,6 +1137,32 @@ def load_model_hook(models, input_dir):
    else:
        initial_global_step = 0

+   if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+       logger.info("Logging some dataset samples.")
+       formatted_images = []
+       formatted_control_images = []
+       all_prompts = []
+       for i, batch in enumerate(train_dataloader):
+           images = (batch["pixel_values"] + 1) / 2
+           control_images = (batch["conditioning_pixel_values"] + 1) / 2
+           prompts = batch["captions"]
+
+           if len(formatted_images) > 10:
+               break
+
+           for img, control_img, prompt in zip(images, control_images, prompts):
+               formatted_images.append(img)
+               formatted_control_images.append(control_img)
+               all_prompts.append(prompt)
+
+       logged_artifacts = []
+       for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+           logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+           logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+       wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+       wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
