Skip to content

Commit 7b5d888

Browse files
committed
set new variance type in schedulers
1 parent 0401277 commit 7b5d888

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,17 @@ def compute_text_embeddings(prompt):
10921092
revision=args.revision,
10931093
torch_dtype=weight_dtype,
10941094
)
1095-
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
1095+
1096+
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1097+
variance_type = pipeline.scheduler.config.variance_type
1098+
1099+
if variance_type in ["learned", "learned_range"]:
1100+
variance_type = "fixed_small"
1101+
1102+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1103+
pipeline.scheduler.config, variance_type=variance_type
1104+
)
1105+
10961106
pipeline = pipeline.to(accelerator.device)
10971107
pipeline.set_progress_bar_config(disable=True)
10981108

@@ -1143,7 +1153,17 @@ def compute_text_embeddings(prompt):
11431153
pipeline = DiffusionPipeline.from_pretrained(
11441154
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
11451155
)
1146-
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
1156+
1157+
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1158+
variance_type = pipeline.scheduler.config.variance_type
1159+
1160+
if variance_type in ["learned", "learned_range"]:
1161+
variance_type = "fixed_small"
1162+
1163+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1164+
pipeline.scheduler.config, variance_type=variance_type
1165+
)
1166+
11471167
pipeline = pipeline.to(accelerator.device)
11481168

11491169
# load attention processors

0 commit comments

Comments
 (0)