Skip to content

Commit e2ead7c

Browse files
leisuzz蒋硕sayakpaul
authored
Fix the issue on sd3 dreambooth w./w.t. lora training (#9419)
* Fix dtype error * [bugfix] Fixed the issue on sd3 dreambooth training * [bugfix] Fixed the issue on sd3 dreambooth training --------- Co-authored-by: 蒋硕 <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 48e3635 commit e2ead7c

6 files changed

+24
-6
lines changed

Diff for: examples/dreambooth/train_dreambooth_flux.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,14 @@ def log_validation(
154154
accelerator,
155155
pipeline_args,
156156
epoch,
157+
torch_dtype,
157158
is_final_validation=False,
158159
):
159160
logger.info(
160161
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
161162
f" {args.validation_prompt}."
162163
)
163-
pipeline = pipeline.to(accelerator.device)
164+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
164165
pipeline.set_progress_bar_config(disable=True)
165166

166167
# run inference
@@ -1717,6 +1718,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17171718
accelerator=accelerator,
17181719
pipeline_args=pipeline_args,
17191720
epoch=epoch,
1721+
torch_dtype=weight_dtype,
17201722
)
17211723
if not args.train_text_encoder:
17221724
del text_encoder_one, text_encoder_two
@@ -1761,6 +1763,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17611763
pipeline_args=pipeline_args,
17621764
epoch=epoch,
17631765
is_final_validation=True,
1766+
torch_dtype=weight_dtype,
17641767
)
17651768

17661769
if args.push_to_hub:

Diff for: examples/dreambooth/train_dreambooth_lora.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ def log_validation(
122122
accelerator,
123123
pipeline_args,
124124
epoch,
125+
torch_dtype,
125126
is_final_validation=False,
126127
):
127128
logger.info(
@@ -141,7 +142,7 @@ def log_validation(
141142

142143
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
143144

144-
pipeline = pipeline.to(accelerator.device)
145+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
145146
pipeline.set_progress_bar_config(disable=True)
146147

147148
# run inference
@@ -1360,6 +1361,7 @@ def compute_text_embeddings(prompt):
13601361
accelerator,
13611362
pipeline_args,
13621363
epoch,
1364+
torch_dtype=weight_dtype,
13631365
)
13641366

13651367
# Save the lora layers
@@ -1402,6 +1404,7 @@ def compute_text_embeddings(prompt):
14021404
pipeline_args,
14031405
epoch,
14041406
is_final_validation=True,
1407+
torch_dtype=weight_dtype,
14051408
)
14061409

14071410
if args.push_to_hub:

Diff for: examples/dreambooth/train_dreambooth_lora_flux.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,14 @@ def log_validation(
170170
accelerator,
171171
pipeline_args,
172172
epoch,
173+
torch_dtype,
173174
is_final_validation=False,
174175
):
175176
logger.info(
176177
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
177178
f" {args.validation_prompt}."
178179
)
179-
pipeline = pipeline.to(accelerator.device)
180+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
180181
pipeline.set_progress_bar_config(disable=True)
181182

182183
# run inference
@@ -1785,6 +1786,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17851786
accelerator=accelerator,
17861787
pipeline_args=pipeline_args,
17871788
epoch=epoch,
1789+
torch_dtype=weight_dtype,
17881790
)
17891791
if not args.train_text_encoder:
17901792
del text_encoder_one, text_encoder_two
@@ -1832,6 +1834,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18321834
pipeline_args=pipeline_args,
18331835
epoch=epoch,
18341836
is_final_validation=True,
1837+
torch_dtype=weight_dtype,
18351838
)
18361839

18371840
if args.push_to_hub:

Diff for: examples/dreambooth/train_dreambooth_lora_sd3.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,14 @@ def log_validation(
179179
accelerator,
180180
pipeline_args,
181181
epoch,
182+
torch_dtype,
182183
is_final_validation=False,
183184
):
184185
logger.info(
185186
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
186187
f" {args.validation_prompt}."
187188
)
188-
pipeline = pipeline.to(accelerator.device)
189+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
189190
pipeline.set_progress_bar_config(disable=True)
190191

191192
# run inference
@@ -1788,6 +1789,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17881789
accelerator=accelerator,
17891790
pipeline_args=pipeline_args,
17901791
epoch=epoch,
1792+
torch_dtype=weight_dtype,
17911793
)
17921794
objs = []
17931795
if not args.train_text_encoder:
@@ -1840,6 +1842,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18401842
pipeline_args=pipeline_args,
18411843
epoch=epoch,
18421844
is_final_validation=True,
1845+
torch_dtype=weight_dtype,
18431846
)
18441847

18451848
if args.push_to_hub:

Diff for: examples/dreambooth/train_dreambooth_lora_sdxl.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def log_validation(
180180
accelerator,
181181
pipeline_args,
182182
epoch,
183+
torch_dtype,
183184
is_final_validation=False,
184185
):
185186
logger.info(
@@ -201,7 +202,7 @@ def log_validation(
201202

202203
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
203204

204-
pipeline = pipeline.to(accelerator.device)
205+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
205206
pipeline.set_progress_bar_config(disable=True)
206207

207208
# run inference
@@ -1890,6 +1891,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18901891
accelerator,
18911892
pipeline_args,
18921893
epoch,
1894+
torch_dtype=weight_dtype,
18931895
)
18941896

18951897
# Save the lora layers
@@ -1955,6 +1957,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
19551957
pipeline_args,
19561958
epoch,
19571959
is_final_validation=True,
1960+
torch_dtype=weight_dtype,
19581961
)
19591962

19601963
if args.push_to_hub:

Diff for: examples/dreambooth/train_dreambooth_sd3.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,14 @@ def log_validation(
157157
accelerator,
158158
pipeline_args,
159159
epoch,
160+
torch_dtype,
160161
is_final_validation=False,
161162
):
162163
logger.info(
163164
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
164165
f" {args.validation_prompt}."
165166
)
166-
pipeline = pipeline.to(accelerator.device)
167+
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
167168
pipeline.set_progress_bar_config(disable=True)
168169

169170
# run inference
@@ -1725,6 +1726,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17251726
accelerator=accelerator,
17261727
pipeline_args=pipeline_args,
17271728
epoch=epoch,
1729+
torch_dtype=weight_dtype,
17281730
)
17291731
if not args.train_text_encoder:
17301732
del text_encoder_one, text_encoder_two, text_encoder_three
@@ -1775,6 +1777,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17751777
pipeline_args=pipeline_args,
17761778
epoch=epoch,
17771779
is_final_validation=True,
1780+
torch_dtype=weight_dtype,
17781781
)
17791782

17801783
if args.push_to_hub:

0 commit comments

Comments
 (0)