Skip to content

Commit fbff43a

Browse files
SunMarcsayakpaulWauplinstevhliuyiyixuxu
authored
[FEAT] DDUF format (#10037)
* load and save dduf archive * style * switch to zip uncompressed * updates * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Sayak Paul <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Sayak Paul <[email protected]> * first draft * remove print * switch to dduf_file for consistency * switch to huggingface hub api * fix log * add a basic test * Update src/diffusers/configuration_utils.py Co-authored-by: Sayak Paul <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Sayak Paul <[email protected]> * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Sayak Paul <[email protected]> * fix * fix variant * change saving logic * DDUF - Load transformers components manually (#10171) * update hfh version * Load transformers components manually * load encoder from_pretrained with state_dict * working version with transformers and tokenizer ! * add generation_config case * fix tests * remove saving for now * typing * need next version from transformers * Update src/diffusers/configuration_utils.py Co-authored-by: Lucain <[email protected]> * check path corectly * Apply suggestions from code review Co-authored-by: Lucain <[email protected]> * udapte * typing * remove check for subfolder * quality * revert setup changes * oups * more readable condition * add loading from the hub test * add basic docs. * Apply suggestions from code review Co-authored-by: Lucain <[email protected]> * add example * add * make functions private * Apply suggestions from code review Co-authored-by: Steven Liu <[email protected]> * minor. * fixes * fix * change the precdence of parameterized. * error out when custom pipeline is passed with dduf_file. * updates * fix * updates * fixes * updates * fix xfail condition. * fix xfail * fixes * sharded checkpoint compat * add test for sharded checkpoint * add suggestions * Update src/diffusers/models/model_loading_utils.py Co-authored-by: YiYi Xu <[email protected]> * from suggestions * add class attributes to flag dduf tests * last one * fix logic * remove comment * revert changes --------- Co-authored-by: Sayak Paul <[email protected]> Co-authored-by: Lucain <[email protected]> Co-authored-by: Steven Liu <[email protected]> Co-authored-by: YiYi Xu <[email protected]>
1 parent 3279751 commit fbff43a

File tree

62 files changed

+750
-45
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+750
-45
lines changed

Diff for: docs/source/en/using-diffusers/other-formats.md

+40
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,46 @@ Benefits of using a single-file layout include:
240240
1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
241241
2. Easier to manage (download and share) a single file.
242242

243+
### DDUF
244+
245+
> [!WARNING]
246+
> DDUF is an experimental file format and APIs related to it can change in the future.
247+
248+
DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format.
249+
250+
Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).
251+
252+
Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].
253+
254+
```py
255+
from diffusers import DiffusionPipeline
256+
import torch
257+
258+
pipe = DiffusionPipeline.from_pretrained(
259+
"DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
260+
).to("cuda")
261+
image = pipe(
262+
"photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
263+
).images[0]
264+
image.save("cat.png")
265+
```
266+
267+
To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.
268+
269+
```py
270+
from huggingface_hub import export_folder_as_dduf
271+
from diffusers import DiffusionPipeline
272+
import torch
273+
274+
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
275+
276+
save_folder = "flux-dev"
277+
pipe.save_pretrained("flux-dev")
278+
export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)
279+
280+
> [!TIP]
281+
> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.
282+
243283
## Convert layout and files
244284

245285
Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
"filelock",
102102
"flax>=0.4.1",
103103
"hf-doc-builder>=0.3.0",
104-
"huggingface-hub>=0.23.2",
104+
"huggingface-hub>=0.27.0",
105105
"requests-mock==1.10.0",
106106
"importlib_metadata",
107107
"invisible-watermark>=0.2.0",

Diff for: src/diffusers/configuration_utils.py

+35-10
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
import re
2525
from collections import OrderedDict
2626
from pathlib import Path
27-
from typing import Any, Dict, Tuple, Union
27+
from typing import Any, Dict, Optional, Tuple, Union
2828

2929
import numpy as np
30-
from huggingface_hub import create_repo, hf_hub_download
30+
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
3131
from huggingface_hub.utils import (
3232
EntryNotFoundError,
3333
RepositoryNotFoundError,
@@ -347,6 +347,7 @@ def load_config(
347347
_ = kwargs.pop("mirror", None)
348348
subfolder = kwargs.pop("subfolder", None)
349349
user_agent = kwargs.pop("user_agent", {})
350+
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
350351

351352
user_agent = {**user_agent, "file_type": "config"}
352353
user_agent = http_user_agent(user_agent)
@@ -358,8 +359,15 @@ def load_config(
358359
"`self.config_name` is not defined. Note that one should not load a config from "
359360
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
360361
)
361-
362-
if os.path.isfile(pretrained_model_name_or_path):
362+
# Custom path for now
363+
if dduf_entries:
364+
if subfolder is not None:
365+
raise ValueError(
366+
"DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
367+
"Please check the DDUF structure"
368+
)
369+
config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
370+
elif os.path.isfile(pretrained_model_name_or_path):
363371
config_file = pretrained_model_name_or_path
364372
elif os.path.isdir(pretrained_model_name_or_path):
365373
if subfolder is not None and os.path.isfile(
@@ -426,10 +434,8 @@ def load_config(
426434
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
427435
f"containing a {cls.config_name} file"
428436
)
429-
430437
try:
431-
# Load config dict
432-
config_dict = cls._dict_from_json_file(config_file)
438+
config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)
433439

434440
commit_hash = extract_commit_hash(config_file)
435441
except (json.JSONDecodeError, UnicodeDecodeError):
@@ -552,9 +558,14 @@ def extract_init_dict(cls, config_dict, **kwargs):
552558
return init_dict, unused_kwargs, hidden_config_dict
553559

554560
@classmethod
555-
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
556-
with open(json_file, "r", encoding="utf-8") as reader:
557-
text = reader.read()
561+
def _dict_from_json_file(
562+
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
563+
):
564+
if dduf_entries:
565+
text = dduf_entries[json_file].read_text()
566+
else:
567+
with open(json_file, "r", encoding="utf-8") as reader:
568+
text = reader.read()
558569
return json.loads(text)
559570

560571
def __repr__(self):
@@ -616,6 +627,20 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
616627
with open(json_file_path, "w", encoding="utf-8") as writer:
617628
writer.write(self.to_json_string())
618629

630+
@classmethod
631+
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
632+
# paths inside a DDUF file must always be "/"
633+
config_file = (
634+
cls.config_name
635+
if pretrained_model_name_or_path == ""
636+
else "/".join([pretrained_model_name_or_path, cls.config_name])
637+
)
638+
if config_file not in dduf_entries:
639+
raise ValueError(
640+
f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
641+
)
642+
return config_file
643+
619644

620645
def register_to_config(init):
621646
r"""

Diff for: src/diffusers/dependency_versions_table.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"filelock": "filelock",
1010
"flax": "flax>=0.4.1",
1111
"hf-doc-builder": "hf-doc-builder>=0.3.0",
12-
"huggingface-hub": "huggingface-hub>=0.23.2",
12+
"huggingface-hub": "huggingface-hub>=0.27.0",
1313
"requests-mock": "requests-mock==1.10.0",
1414
"importlib_metadata": "importlib_metadata",
1515
"invisible-watermark": "invisible-watermark>=0.2.0",

Diff for: src/diffusers/models/model_loading_utils.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
from array import array
2121
from collections import OrderedDict
2222
from pathlib import Path
23-
from typing import List, Optional, Union
23+
from typing import Dict, List, Optional, Union
2424

2525
import safetensors
2626
import torch
27+
from huggingface_hub import DDUFEntry
2728
from huggingface_hub.utils import EntryNotFoundError
2829

2930
from ..utils import (
@@ -132,7 +133,10 @@ def _fetch_remapped_cls_from_config(config, old_class):
132133

133134

134135
def load_state_dict(
135-
checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
136+
checkpoint_file: Union[str, os.PathLike],
137+
variant: Optional[str] = None,
138+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
139+
disable_mmap: bool = False,
136140
):
137141
"""
138142
Reads a checkpoint file, returning properly formatted errors if they arise.
@@ -144,6 +148,10 @@ def load_state_dict(
144148
try:
145149
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
146150
if file_extension == SAFETENSORS_FILE_EXTENSION:
151+
if dduf_entries:
152+
# tensors are loaded on cpu
153+
with dduf_entries[checkpoint_file].as_mmap() as mm:
154+
return safetensors.torch.load(mm)
147155
if disable_mmap:
148156
return safetensors.torch.load(open(checkpoint_file, "rb").read())
149157
else:
@@ -284,6 +292,7 @@ def _fetch_index_file(
284292
revision,
285293
user_agent,
286294
commit_hash,
295+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
287296
):
288297
if is_local:
289298
index_file = Path(
@@ -309,8 +318,10 @@ def _fetch_index_file(
309318
subfolder=None,
310319
user_agent=user_agent,
311320
commit_hash=commit_hash,
321+
dduf_entries=dduf_entries,
312322
)
313-
index_file = Path(index_file)
323+
if not dduf_entries:
324+
index_file = Path(index_file)
314325
except (EntryNotFoundError, EnvironmentError):
315326
index_file = None
316327

@@ -319,7 +330,9 @@ def _fetch_index_file(
319330

320331
# Adapted from
321332
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
322-
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
333+
def _merge_sharded_checkpoints(
334+
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
335+
):
323336
weight_map = sharded_metadata.get("weight_map", None)
324337
if weight_map is None:
325338
raise KeyError("'weight_map' key not found in the shard index file.")
@@ -332,14 +345,23 @@ def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
332345
# Load tensors from each unique file
333346
for file_name in files_to_load:
334347
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
335-
if not os.path.exists(part_file_path):
336-
raise FileNotFoundError(f"Part file {file_name} not found.")
348+
if dduf_entries:
349+
if part_file_path not in dduf_entries:
350+
raise FileNotFoundError(f"Part file {file_name} not found.")
351+
else:
352+
if not os.path.exists(part_file_path):
353+
raise FileNotFoundError(f"Part file {file_name} not found.")
337354

338355
if is_safetensors:
339-
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
340-
for tensor_key in f.keys():
341-
if tensor_key in weight_map:
342-
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
356+
if dduf_entries:
357+
with dduf_entries[part_file_path].as_mmap() as mm:
358+
tensors = safetensors.torch.load(mm)
359+
merged_state_dict.update(tensors)
360+
else:
361+
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
362+
for tensor_key in f.keys():
363+
if tensor_key in weight_map:
364+
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
343365
else:
344366
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
345367

@@ -360,6 +382,7 @@ def _fetch_index_file_legacy(
360382
revision,
361383
user_agent,
362384
commit_hash,
385+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
363386
):
364387
if is_local:
365388
index_file = Path(
@@ -400,6 +423,7 @@ def _fetch_index_file_legacy(
400423
subfolder=None,
401424
user_agent=user_agent,
402425
commit_hash=commit_hash,
426+
dduf_entries=dduf_entries,
403427
)
404428
index_file = Path(index_file)
405429
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."

Diff for: src/diffusers/models/modeling_utils.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
from collections import OrderedDict
2424
from functools import partial, wraps
2525
from pathlib import Path
26-
from typing import Any, Callable, List, Optional, Tuple, Union
26+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2727

2828
import safetensors
2929
import torch
30-
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
30+
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
3131
from huggingface_hub.utils import validate_hf_hub_args
3232
from torch import Tensor, nn
3333

@@ -607,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
607607
variant = kwargs.pop("variant", None)
608608
use_safetensors = kwargs.pop("use_safetensors", None)
609609
quantization_config = kwargs.pop("quantization_config", None)
610+
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
610611
disable_mmap = kwargs.pop("disable_mmap", False)
611612

612613
allow_pickle = False
@@ -700,6 +701,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
700701
revision=revision,
701702
subfolder=subfolder,
702703
user_agent=user_agent,
704+
dduf_entries=dduf_entries,
703705
**kwargs,
704706
)
705707
# no in-place modification of the original config.
@@ -776,13 +778,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
776778
"revision": revision,
777779
"user_agent": user_agent,
778780
"commit_hash": commit_hash,
781+
"dduf_entries": dduf_entries,
779782
}
780783
index_file = _fetch_index_file(**index_file_kwargs)
781784
# In case the index file was not found we still have to consider the legacy format.
782785
# this becomes applicable when the variant is not None.
783786
if variant is not None and (index_file is None or not os.path.exists(index_file)):
784787
index_file = _fetch_index_file_legacy(**index_file_kwargs)
785-
if index_file is not None and index_file.is_file():
788+
if index_file is not None and (dduf_entries or index_file.is_file()):
786789
is_sharded = True
787790

788791
if is_sharded and from_flax:
@@ -811,6 +814,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
811814

812815
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
813816
else:
817+
# in the case it is sharded, we have already the index
814818
if is_sharded:
815819
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
816820
pretrained_model_name_or_path,
@@ -822,10 +826,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
822826
user_agent=user_agent,
823827
revision=revision,
824828
subfolder=subfolder or "",
829+
dduf_entries=dduf_entries,
825830
)
826831
# TODO: https://github.com/huggingface/diffusers/issues/10013
827-
if hf_quantizer is not None:
828-
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
832+
if hf_quantizer is not None or dduf_entries:
833+
model_file = _merge_sharded_checkpoints(
834+
sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries
835+
)
829836
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
830837
is_sharded = False
831838

@@ -843,6 +850,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
843850
subfolder=subfolder,
844851
user_agent=user_agent,
845852
commit_hash=commit_hash,
853+
dduf_entries=dduf_entries,
846854
)
847855

848856
except IOError as e:
@@ -866,6 +874,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
866874
subfolder=subfolder,
867875
user_agent=user_agent,
868876
commit_hash=commit_hash,
877+
dduf_entries=dduf_entries,
869878
)
870879

871880
if low_cpu_mem_usage:
@@ -887,7 +896,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
887896
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
888897
else:
889898
param_device = torch.device(torch.cuda.current_device())
890-
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
899+
state_dict = load_state_dict(
900+
model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
901+
)
891902
model._convert_deprecated_attention_blocks(state_dict)
892903

893904
# move the params from meta device to cpu
@@ -983,7 +994,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
983994
else:
984995
model = cls.from_config(config, **unused_kwargs)
985996

986-
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
997+
state_dict = load_state_dict(
998+
model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
999+
)
9871000
model._convert_deprecated_attention_blocks(state_dict)
9881001

9891002
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(

0 commit comments

Comments
 (0)