@@ -44,6 +44,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
                 weight_dtype,
                 accelerator.device if args.lowram else "cpu",
                 model_dtype,
+                args.disable_mmap_load_safetensors
             )

             # work on low-ram device
@@ -60,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):


 def _load_target_model(
-    name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
+    name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
 ):
     # model_dtype only work with full fp16/bf16
     name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
@@ -75,7 +76,7 @@ def _load_target_model(
             unet,
             logit_scale,
             ckpt_info,
-        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
+        ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
     else:
         # Diffusers model is loaded to CPU
         from diffusers import StableDiffusionXLPipeline
@@ -332,6 +333,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
         action="store_true",
         help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
     )
+    parser.add_argument(
+        "--disable_mmap_load_safetensors",
+        action="store_true",
+    )


 def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
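The diff only threads the new flag from the argument parser through `load_target_model` and `_load_target_model`; the change inside `sdxl_model_util.load_models_from_sdxl_checkpoint` that actually honors `disable_mmap` is not shown here. Below is a minimal sketch of how such a flag is commonly applied when reading a safetensors state dict. The helper name `load_safetensors_state_dict` and the filesystem rationale in the comments are assumptions for illustration, not taken from this diff.

```python
# Hypothetical stand-in for the loading step inside load_models_from_sdxl_checkpoint.
from safetensors.torch import load, load_file


def load_safetensors_state_dict(path: str, device: str = "cpu", disable_mmap: bool = False):
    if disable_mmap:
        # Read the whole file into memory and deserialize from bytes, avoiding mmap,
        # which can be slow on some setups (assumed rationale, e.g. network mounts or WSL).
        with open(path, "rb") as f:
            state_dict = load(f.read())
        # safetensors.torch.load returns CPU tensors; move them if another device was requested.
        if device != "cpu":
            state_dict = {k: v.to(device) for k, v in state_dict.items()}
        return state_dict
    # Default path: memory-mapped load directly onto the target device.
    return load_file(path, device=device)
```

With the parser change above, training scripts that call `add_sdxl_training_arguments` would accept `--disable_mmap_load_safetensors` on the command line and expose it as `args.disable_mmap_load_safetensors`.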