23
23
import unittest
24
24
from typing import List
25
25
26
- import torch
26
+ import safetensors
27
27
from accelerate .utils import write_basic_config
28
28
29
29
from diffusers import DiffusionPipeline , UNet2DConditionModel
@@ -93,7 +93,7 @@ def test_train_unconditional(self):
93
93
94
94
run_command (self ._launch_args + test_args , return_stdout = True )
95
95
# save_pretrained smoke test
96
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
96
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
97
97
self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
98
98
99
99
def test_textual_inversion (self ):
@@ -144,7 +144,7 @@ def test_dreambooth(self):
144
144
145
145
run_command (self ._launch_args + test_args )
146
146
# save_pretrained smoke test
147
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
147
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
148
148
self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
149
149
150
150
def test_dreambooth_if (self ):
@@ -170,7 +170,7 @@ def test_dreambooth_if(self):
170
170
171
171
run_command (self ._launch_args + test_args )
172
172
# save_pretrained smoke test
173
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
173
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
174
174
self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
175
175
176
176
def test_dreambooth_checkpointing (self ):
@@ -272,10 +272,10 @@ def test_dreambooth_lora(self):
272
272
273
273
run_command (self ._launch_args + test_args )
274
274
# save_pretrained smoke test
275
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
275
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
276
276
277
277
# make sure the state_dict has the correct naming in the parameters.
278
- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
278
+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
279
279
is_lora = all ("lora" in k for k in lora_state_dict .keys ())
280
280
self .assertTrue (is_lora )
281
281
@@ -305,10 +305,10 @@ def test_dreambooth_lora_with_text_encoder(self):
305
305
306
306
run_command (self ._launch_args + test_args )
307
307
# save_pretrained smoke test
308
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
308
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
309
309
310
310
# check `text_encoder` is present at all.
311
- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
311
+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
312
312
keys = lora_state_dict .keys ()
313
313
is_text_encoder_present = any (k .startswith ("text_encoder" ) for k in keys )
314
314
self .assertTrue (is_text_encoder_present )
@@ -341,10 +341,10 @@ def test_dreambooth_lora_if_model(self):
341
341
342
342
run_command (self ._launch_args + test_args )
343
343
# save_pretrained smoke test
344
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
344
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
345
345
346
346
# make sure the state_dict has the correct naming in the parameters.
347
- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
347
+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
348
348
is_lora = all ("lora" in k for k in lora_state_dict .keys ())
349
349
self .assertTrue (is_lora )
350
350
@@ -373,10 +373,10 @@ def test_dreambooth_lora_sdxl(self):
373
373
374
374
run_command (self ._launch_args + test_args )
375
375
# save_pretrained smoke test
376
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
376
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
377
377
378
378
# make sure the state_dict has the correct naming in the parameters.
379
- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
379
+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
380
380
is_lora = all ("lora" in k for k in lora_state_dict .keys ())
381
381
self .assertTrue (is_lora )
382
382
@@ -406,10 +406,10 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self):
406
406
407
407
run_command (self ._launch_args + test_args )
408
408
# save_pretrained smoke test
409
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
409
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
410
410
411
411
# make sure the state_dict has the correct naming in the parameters.
412
- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
412
+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
413
413
is_lora = all ("lora" in k for k in lora_state_dict .keys ())
414
414
self .assertTrue (is_lora )
415
415
@@ -437,6 +437,7 @@ def test_custom_diffusion(self):
437
437
--lr_scheduler constant
438
438
--lr_warmup_steps 0
439
439
--modifier_token <new1>
440
+ --no_safe_serialization
440
441
--output_dir { tmpdir }
441
442
""" .split ()
442
443
@@ -466,7 +467,7 @@ def test_text_to_image(self):
466
467
467
468
run_command (self ._launch_args + test_args )
468
469
# save_pretrained smoke test
469
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
470
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
470
471
self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
471
472
472
473
def test_text_to_image_checkpointing (self ):
@@ -778,7 +779,7 @@ def test_text_to_image_sdxl(self):
778
779
779
780
run_command (self ._launch_args + test_args )
780
781
# save_pretrained smoke test
781
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.bin " )))
782
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "unet" , "diffusion_pytorch_model.safetensors " )))
782
783
self .assertTrue (os .path .isfile (os .path .join (tmpdir , "scheduler" , "scheduler_config.json" )))
783
784
784
785
def test_text_to_image_lora_checkpointing_checkpoints_total_limit (self ):
@@ -1373,7 +1374,7 @@ def test_controlnet_sdxl(self):
1373
1374
1374
1375
run_command (self ._launch_args + test_args )
1375
1376
1376
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "diffusion_pytorch_model.bin " )))
1377
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "diffusion_pytorch_model.safetensors " )))
1377
1378
1378
1379
def test_custom_diffusion_checkpointing_checkpoints_total_limit (self ):
1379
1380
with tempfile .TemporaryDirectory () as tmpdir :
@@ -1390,6 +1391,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
1390
1391
--max_train_steps=6
1391
1392
--checkpoints_total_limit=2
1392
1393
--checkpointing_steps=2
1394
+ --no_safe_serialization
1393
1395
""" .split ()
1394
1396
1395
1397
run_command (self ._launch_args + test_args )
@@ -1413,6 +1415,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
1413
1415
--dataloader_num_workers=0
1414
1416
--max_train_steps=9
1415
1417
--checkpointing_steps=2
1418
+ --no_safe_serialization
1416
1419
""" .split ()
1417
1420
1418
1421
run_command (self ._launch_args + test_args )
@@ -1436,6 +1439,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
1436
1439
--checkpointing_steps=2
1437
1440
--resume_from_checkpoint=checkpoint-8
1438
1441
--checkpoints_total_limit=3
1442
+ --no_safe_serialization
1439
1443
""" .split ()
1440
1444
1441
1445
run_command (self ._launch_args + resume_run_args )
@@ -1464,10 +1468,10 @@ def test_text_to_image_lora_sdxl(self):
1464
1468
1465
1469
run_command (self ._launch_args + test_args )
1466
1470
# save_pretrained smoke test
1467
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
1471
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
1468
1472
1469
1473
# make sure the state_dict has the correct naming in the parameters.
1470
- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
1474
+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
1471
1475
is_lora = all ("lora" in k for k in lora_state_dict .keys ())
1472
1476
self .assertTrue (is_lora )
1473
1477
@@ -1491,10 +1495,10 @@ def test_text_to_image_lora_sdxl_with_text_encoder(self):
1491
1495
1492
1496
run_command (self ._launch_args + test_args )
1493
1497
# save_pretrained smoke test
1494
- self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.bin " )))
1498
+ self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " )))
1495
1499
1496
1500
# make sure the state_dict has the correct naming in the parameters.
1497
- lora_state_dict = torch .load (os .path .join (tmpdir , "pytorch_lora_weights.bin " ))
1501
+ lora_state_dict = safetensors . torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors " ))
1498
1502
is_lora = all ("lora" in k for k in lora_state_dict .keys ())
1499
1503
self .assertTrue (is_lora )
1500
1504
0 commit comments