@@ -1092,7 +1092,17 @@ def compute_text_embeddings(prompt):
1092
1092
revision = args .revision ,
1093
1093
torch_dtype = weight_dtype ,
1094
1094
)
1095
- pipeline .scheduler = DPMSolverMultistepScheduler .from_config (pipeline .scheduler .config )
1095
+
1096
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1097
+ variance_type = pipeline .scheduler .config .variance_type
1098
+
1099
+ if variance_type in ["learned" , "learned_range" ]:
1100
+ variance_type = "fixed_small"
1101
+
1102
+ pipeline .scheduler = DPMSolverMultistepScheduler .from_config (
1103
+ pipeline .scheduler .config , variance_type = variance_type
1104
+ )
1105
+
1096
1106
pipeline = pipeline .to (accelerator .device )
1097
1107
pipeline .set_progress_bar_config (disable = True )
1098
1108
@@ -1143,7 +1153,17 @@ def compute_text_embeddings(prompt):
1143
1153
pipeline = DiffusionPipeline .from_pretrained (
1144
1154
args .pretrained_model_name_or_path , revision = args .revision , torch_dtype = weight_dtype
1145
1155
)
1146
- pipeline .scheduler = DPMSolverMultistepScheduler .from_config (pipeline .scheduler .config )
1156
+
1157
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1158
+ variance_type = pipeline .scheduler .config .variance_type
1159
+
1160
+ if variance_type in ["learned" , "learned_range" ]:
1161
+ variance_type = "fixed_small"
1162
+
1163
+ pipeline .scheduler = DPMSolverMultistepScheduler .from_config (
1164
+ pipeline .scheduler .config , variance_type = variance_type
1165
+ )
1166
+
1147
1167
pipeline = pipeline .to (accelerator .device )
1148
1168
1149
1169
# load attention processors
0 commit comments