Skip to content

Commit 13e8fde

Browse files
[feat] add load_lora_adapter() for compatible models (#9712)
* add first draft.
* fix
* updates.
* updates.
* updates
* updates
* updates.
* fix-copies
* lora constants.
* add tests
* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <[email protected]>

* docstrings.

---------

Co-authored-by: Benjamin Bossan <[email protected]>
1 parent c10f875 commit 13e8fde

File tree

5 files changed

+515
-491
lines changed

5 files changed

+515
-491
lines changed

src/diffusers/loaders/lora_base.py

+124-118
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@
5151

5252
logger = logging.get_logger(__name__)
5353

54+
# Canonical filenames for serialized LoRA weights: legacy PyTorch pickle (.bin)
# and the preferred safetensors format. Used as download fallbacks below.
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
56+
5457

5558
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
5659
"""
@@ -181,6 +184,119 @@ def _remove_text_encoder_monkey_patch(text_encoder):
181184
text_encoder._hf_peft_config_loaded = None
182185

183186

187+
def _fetch_state_dict(
188+
pretrained_model_name_or_path_or_dict,
189+
weight_name,
190+
use_safetensors,
191+
local_files_only,
192+
cache_dir,
193+
force_download,
194+
proxies,
195+
token,
196+
revision,
197+
subfolder,
198+
user_agent,
199+
allow_pickle,
200+
):
201+
model_file = None
202+
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
203+
# Let's first try to load .safetensors weights
204+
if (use_safetensors and weight_name is None) or (
205+
weight_name is not None and weight_name.endswith(".safetensors")
206+
):
207+
try:
208+
# Here we're relaxing the loading check to enable more Inference API
209+
# friendliness where sometimes, it's not at all possible to automatically
210+
# determine `weight_name`.
211+
if weight_name is None:
212+
weight_name = _best_guess_weight_name(
213+
pretrained_model_name_or_path_or_dict,
214+
file_extension=".safetensors",
215+
local_files_only=local_files_only,
216+
)
217+
model_file = _get_model_file(
218+
pretrained_model_name_or_path_or_dict,
219+
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
220+
cache_dir=cache_dir,
221+
force_download=force_download,
222+
proxies=proxies,
223+
local_files_only=local_files_only,
224+
token=token,
225+
revision=revision,
226+
subfolder=subfolder,
227+
user_agent=user_agent,
228+
)
229+
state_dict = safetensors.torch.load_file(model_file, device="cpu")
230+
except (IOError, safetensors.SafetensorError) as e:
231+
if not allow_pickle:
232+
raise e
233+
# try loading non-safetensors weights
234+
model_file = None
235+
pass
236+
237+
if model_file is None:
238+
if weight_name is None:
239+
weight_name = _best_guess_weight_name(
240+
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
241+
)
242+
model_file = _get_model_file(
243+
pretrained_model_name_or_path_or_dict,
244+
weights_name=weight_name or LORA_WEIGHT_NAME,
245+
cache_dir=cache_dir,
246+
force_download=force_download,
247+
proxies=proxies,
248+
local_files_only=local_files_only,
249+
token=token,
250+
revision=revision,
251+
subfolder=subfolder,
252+
user_agent=user_agent,
253+
)
254+
state_dict = load_state_dict(model_file)
255+
else:
256+
state_dict = pretrained_model_name_or_path_or_dict
257+
258+
return state_dict
259+
260+
261+
def _best_guess_weight_name(
262+
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
263+
):
264+
if local_files_only or HF_HUB_OFFLINE:
265+
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
266+
267+
targeted_files = []
268+
269+
if os.path.isfile(pretrained_model_name_or_path_or_dict):
270+
return
271+
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
272+
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
273+
else:
274+
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
275+
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
276+
if len(targeted_files) == 0:
277+
return
278+
279+
# "scheduler" does not correspond to a LoRA checkpoint.
280+
# "optimizer" does not correspond to a LoRA checkpoint
281+
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
282+
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
283+
targeted_files = list(
284+
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
285+
)
286+
287+
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
288+
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
289+
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
290+
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
291+
292+
if len(targeted_files) > 1:
293+
raise ValueError(
294+
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
295+
)
296+
weight_name = targeted_files[0]
297+
return weight_name
298+
299+
184300
class LoraBaseMixin:
185301
"""Utility class for handling LoRAs."""
186302

@@ -234,124 +350,16 @@ def _optionally_disable_offloading(cls, _pipeline):
234350
return (is_model_cpu_offload, is_sequential_cpu_offload)
235351

236352
@classmethod
237-
def _fetch_state_dict(
238-
cls,
239-
pretrained_model_name_or_path_or_dict,
240-
weight_name,
241-
use_safetensors,
242-
local_files_only,
243-
cache_dir,
244-
force_download,
245-
proxies,
246-
token,
247-
revision,
248-
subfolder,
249-
user_agent,
250-
allow_pickle,
251-
):
252-
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
253-
254-
model_file = None
255-
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
256-
# Let's first try to load .safetensors weights
257-
if (use_safetensors and weight_name is None) or (
258-
weight_name is not None and weight_name.endswith(".safetensors")
259-
):
260-
try:
261-
# Here we're relaxing the loading check to enable more Inference API
262-
# friendliness where sometimes, it's not at all possible to automatically
263-
# determine `weight_name`.
264-
if weight_name is None:
265-
weight_name = cls._best_guess_weight_name(
266-
pretrained_model_name_or_path_or_dict,
267-
file_extension=".safetensors",
268-
local_files_only=local_files_only,
269-
)
270-
model_file = _get_model_file(
271-
pretrained_model_name_or_path_or_dict,
272-
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
273-
cache_dir=cache_dir,
274-
force_download=force_download,
275-
proxies=proxies,
276-
local_files_only=local_files_only,
277-
token=token,
278-
revision=revision,
279-
subfolder=subfolder,
280-
user_agent=user_agent,
281-
)
282-
state_dict = safetensors.torch.load_file(model_file, device="cpu")
283-
except (IOError, safetensors.SafetensorError) as e:
284-
if not allow_pickle:
285-
raise e
286-
# try loading non-safetensors weights
287-
model_file = None
288-
pass
289-
290-
if model_file is None:
291-
if weight_name is None:
292-
weight_name = cls._best_guess_weight_name(
293-
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
294-
)
295-
model_file = _get_model_file(
296-
pretrained_model_name_or_path_or_dict,
297-
weights_name=weight_name or LORA_WEIGHT_NAME,
298-
cache_dir=cache_dir,
299-
force_download=force_download,
300-
proxies=proxies,
301-
local_files_only=local_files_only,
302-
token=token,
303-
revision=revision,
304-
subfolder=subfolder,
305-
user_agent=user_agent,
306-
)
307-
state_dict = load_state_dict(model_file)
308-
else:
309-
state_dict = pretrained_model_name_or_path_or_dict
310-
311-
return state_dict
353+
def _fetch_state_dict(cls, *args, **kwargs):
    """Deprecated shim: forwards to the module-level `_fetch_state_dict` helper."""
    # Function-body name lookup skips the class namespace, so the call below
    # resolves to the module-level function, not this method (no recursion).
    deprecate(
        "_fetch_state_dict",
        "0.35.0",
        f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`.",
    )
    return _fetch_state_dict(*args, **kwargs)
312357

313358
@classmethod
314-
def _best_guess_weight_name(
315-
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
316-
):
317-
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
318-
319-
if local_files_only or HF_HUB_OFFLINE:
320-
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
321-
322-
targeted_files = []
323-
324-
if os.path.isfile(pretrained_model_name_or_path_or_dict):
325-
return
326-
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
327-
targeted_files = [
328-
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
329-
]
330-
else:
331-
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
332-
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
333-
if len(targeted_files) == 0:
334-
return
335-
336-
# "scheduler" does not correspond to a LoRA checkpoint.
337-
# "optimizer" does not correspond to a LoRA checkpoint
338-
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
339-
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
340-
targeted_files = list(
341-
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
342-
)
343-
344-
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
345-
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
346-
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
347-
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
348-
349-
if len(targeted_files) > 1:
350-
raise ValueError(
351-
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
352-
)
353-
weight_name = targeted_files[0]
354-
return weight_name
359+
def _best_guess_weight_name(cls, *args, **kwargs):
    """Deprecated shim: forwards to the module-level `_best_guess_weight_name` helper."""
    # Function-body name lookup skips the class namespace, so the call below
    # resolves to the module-level function, not this method (no recursion).
    deprecate(
        "_best_guess_weight_name",
        "0.35.0",
        f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`.",
    )
    return _best_guess_weight_name(*args, **kwargs)
355363

356364
def unload_lora_weights(self):
357365
"""
@@ -725,8 +733,6 @@ def write_lora_layers(
725733
save_function: Callable,
726734
safe_serialization: bool,
727735
):
728-
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
729-
730736
if os.path.isfile(save_directory):
731737
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
732738
return

0 commit comments

Comments
 (0)