Skip to content

Commit 8d1b1ac

Browse files
authored
Merge pull request kohya-ss#1266 from Zovjsra/feature/disable-mmap
Add "--disable_mmap_load_safetensors" parameter
2 parents 02298e3 + 64916a3 commit 8d1b1ac

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

library/sdxl_model_util.py

+9-5
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import safetensors
23
from accelerate import init_empty_weights
34
from accelerate.utils.modeling import set_module_tensor_to_device
45
from safetensors.torch import load_file, save_file
@@ -163,17 +164,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
163164
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
164165

165166

166-
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
167+
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
167168
# model_version is reserved for future use
168169
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
169170

170171
# Load the state dict
171172
if model_util.is_safetensors(ckpt_path):
172173
checkpoint = None
173-
try:
174-
state_dict = load_file(ckpt_path, device=map_location)
175-
except:
176-
state_dict = load_file(ckpt_path) # prevent device invalid Error
174+
if(disable_mmap):
175+
state_dict = safetensors.torch.load(open(ckpt_path, 'rb').read())
176+
else:
177+
try:
178+
state_dict = load_file(ckpt_path, device=map_location)
179+
except:
180+
state_dict = load_file(ckpt_path) # prevent device invalid Error
177181
epoch = None
178182
global_step = None
179183
else:

library/sdxl_train_util.py

+7-2
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
4444
weight_dtype,
4545
accelerator.device if args.lowram else "cpu",
4646
model_dtype,
47+
args.disable_mmap_load_safetensors
4748
)
4849

4950
# work on low-ram device
@@ -60,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
6061

6162

6263
def _load_target_model(
63-
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
64+
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
6465
):
6566
# model_dtype only work with full fp16/bf16
6667
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
@@ -75,7 +76,7 @@ def _load_target_model(
7576
unet,
7677
logit_scale,
7778
ckpt_info,
78-
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
79+
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
7980
else:
8081
# Diffusers model is loaded to CPU
8182
from diffusers import StableDiffusionXLPipeline
@@ -332,6 +333,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
332333
action="store_true",
333334
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
334335
)
336+
parser.add_argument(
337+
"--disable_mmap_load_safetensors",
338+
action="store_true",
339+
)
335340

336341

337342
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):

0 commit comments

Comments (0)