Commit 4267d8f

[Single File] GGUF/Single File Support for HiDream (#11550)
* update
* update
* update
* update
* update
* update
* update
1 parent f4fa3be · commit 4267d8f

File tree: 6 files changed, +67 −5 lines

- docs/source/en/api/models/hidream_image_transformer.md
- src/diffusers/loaders/single_file_model.py
- src/diffusers/loaders/single_file_utils.py
- src/diffusers/models/transformers/transformer_hidream_image.py
- src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
- tests/quantization/gguf/test_gguf.py

docs/source/en/api/models/hidream_image_transformer.md

Lines changed: 16 additions & 0 deletions
@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
 transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
 ```
 
+## Loading GGUF quantized checkpoints for HiDream-I1
+
+GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using `~FromOriginalModelMixin.from_single_file`
+
+```python
+import torch
+from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel
+
+ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
+transformer = HiDreamImageTransformer2DModel.from_single_file(
+    ckpt_path,
+    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
+    torch_dtype=torch.bfloat16
+)
+```
+
 ## HiDreamImageTransformer2DModel
 
 [[autodoc]] HiDreamImageTransformer2DModel
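
A minimal end-to-end sketch of what this enables, combining the docs example above with the pipeline docstring updated later in this commit. The Llama loading kwargs, the generation settings, and the assumption that the `HiDream-ai/HiDream-I1-Dev` repo accepts an externally supplied `transformer` are illustrative, not part of the diff:

```python
import torch
from transformers import AutoTokenizer, LlamaForCausalLM
from diffusers import GGUFQuantizationConfig, HiDreamImagePipeline, HiDreamImageTransformer2DModel

# GGUF-quantized transformer, loaded through the new single-file path.
ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
transformer = HiDreamImageTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

# The Llama text encoder is loaded separately, as in the pipeline docstring.
tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    torch_dtype=torch.bfloat16,
)

pipe = HiDreamImagePipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Dev",
    transformer=transformer,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # keeps peak VRAM low alongside the Q2_K weights
image = pipe("a cat holding a sign that says hi", num_inference_steps=28).images[0]
image.save("hidream_gguf.png")
```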

src/diffusers/loaders/single_file_model.py

Lines changed: 5 additions & 0 deletions
@@ -31,6 +31,7 @@
     convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
+    convert_hidream_transformer_to_diffusers,
     convert_hunyuan_video_transformer_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
@@ -133,6 +134,10 @@
         "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
         "default_subfolder": "vae",
     },
+    "HiDreamImageTransformer2DModel": {
+        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
 
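
For context, this registry is what `FromOriginalModelMixin.from_single_file` consults: the loading class's name keys into `SINGLE_FILE_LOADABLE_CLASSES`, and the registered `checkpoint_mapping_fn` rewrites the raw state dict before the model is instantiated. A simplified sketch of the lookup (the helper function is illustrative, not the real loader code):

```python
from diffusers.loaders.single_file_model import SINGLE_FILE_LOADABLE_CLASSES

def map_checkpoint_for(cls_name: str, checkpoint: dict) -> dict:
    # Illustrative stand-in for the dispatch inside from_single_file;
    # the real loader also resolves configs, dtype, and quantization.
    entry = SINGLE_FILE_LOADABLE_CLASSES[cls_name]
    return entry["checkpoint_mapping_fn"](checkpoint)

# After this commit, HiDream participates in the dispatch:
assert "HiDreamImageTransformer2DModel" in SINGLE_FILE_LOADABLE_CLASSES
```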

src/diffusers/loaders/single_file_utils.py

Lines changed: 13 additions & 0 deletions
@@ -126,6 +126,7 @@
     ],
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
+    "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -190,6 +191,7 @@
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+    "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
 }
 
 # Use to configure model sample size when original config is provided
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
     elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
         # All Wan models use the same VAE so we can use the same default model repo to fetch the config
         model_type = "wan-t2v-14B"
+    elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
+        model_type = "hidream"
     else:
         model_type = "v1"
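
The two table entries and the new `infer_diffusers_model_type` branch work together: the marker key identifies a raw HiDream checkpoint, and the resulting `"hidream"` type resolves to `HiDream-ai/HiDream-I1-Dev` as the default config source. A toy illustration with local copies of the two tables (the tensor is a placeholder; only the key name matters):

```python
import torch

# Local copies of the entries added above.
CHECKPOINT_KEY_NAMES = {"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias"}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"}
}

checkpoint = {"double_stream_blocks.0.block.adaLN_modulation.1.bias": torch.zeros(1)}

if CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
    model_type = "hidream"
repo = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]["pretrained_model_name_or_path"]
print(repo)  # HiDream-ai/HiDream-I1-Dev
```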

@@ -3293,3 +3297,12 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
         converted_state_dict[key] = value
 
     return converted_state_dict
+
+
+def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    return checkpoint
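
The conversion itself is just a namespace strip: original HiDream checkpoints prefix every weight with `model.diffusion_model.`, while diffusers expects bare keys. A quick check on a toy state dict:

```python
import torch

from diffusers.loaders.single_file_utils import convert_hidream_transformer_to_diffusers

state_dict = {
    "model.diffusion_model.double_stream_blocks.0.block.adaLN_modulation.1.bias": torch.zeros(1),
    "already.bare.key": torch.zeros(1),  # keys without the prefix pass through untouched
}
converted = convert_hidream_transformer_to_diffusers(state_dict)
assert "double_stream_blocks.0.block.adaLN_modulation.1.bias" in converted
assert "already.bare.key" in converted
```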

src/diffusers/models/transformers/transformer_hidream_image.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -602,7 +602,7 @@ def forward(
         )
 
 
-class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
 
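
The functional change: `FromOriginalModelMixin` is what contributes the `from_single_file` classmethod to the transformer; the checkpoint rewriting comes from the registry in `single_file_model.py`. A trivial smoke check (assumes diffusers at this commit):

```python
from diffusers import HiDreamImageTransformer2DModel

# The mixin contributes the single-file entry point.
assert hasattr(HiDreamImageTransformer2DModel, "from_single_file")
```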

src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py

Lines changed: 3 additions & 3 deletions
@@ -36,11 +36,11 @@
 Examples:
     ```py
     >>> import torch
-    >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
-    >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
+    >>> from transformers import AutoTokenizer, LlamaForCausalLM
+    >>> from diffusers import HiDreamImagePipeline
 
 
-    >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+    >>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
     >>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
     ...     "meta-llama/Meta-Llama-3.1-8B-Instruct",
     ...     output_hidden_states=True,
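
The docstring swap from `PreTrainedTokenizerFast` to `AutoTokenizer` (and dropping the unused `UniPCMultistepScheduler` import) makes the example class-agnostic: `AutoTokenizer` resolves the concrete tokenizer class from the repo's configuration. A quick illustration; the resolved class depends on the repo, so treat the printed name as an example:

```python
from transformers import AutoTokenizer

# AutoTokenizer reads tokenizer_config.json from the hub repo and returns
# the appropriate concrete tokenizer class, so callers need not hard-code it.
tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
print(type(tok).__name__)  # e.g. PreTrainedTokenizerFast
```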

tests/quantization/gguf/test_gguf.py

Lines changed: 28 additions & 0 deletions
@@ -12,6 +12,7 @@
     FluxPipeline,
     FluxTransformer2DModel,
     GGUFQuantizationConfig,
+    HiDreamImageTransformer2DModel,
     SD3Transformer2DModel,
     StableDiffusion3Pipeline,
 )
@@ -549,3 +550,30 @@ def test_lora_loading(self):
 
         max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
         self.assertTrue(max_diff < 1e-3)
+
+
+class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+    ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
+    torch_dtype = torch.bfloat16
+    model_cls = HiDreamImageTransformer2DModel
+    expected_memory_use_in_gb = 8
+
+    def get_dummy_inputs(self):
+        return {
+            "hidden_states": torch.randn((1, 16, 128, 128), generator=torch.Generator("cpu").manual_seed(0)).to(
+                torch_device, self.torch_dtype
+            ),
+            "encoder_hidden_states_t5": torch.randn(
+                (1, 128, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "encoder_hidden_states_llama3": torch.randn(
+                (32, 1, 128, 4096),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "pooled_embeds": torch.randn(
+                (1, 2048),
+                generator=torch.Generator("cpu").manual_seed(0),
+            ).to(torch_device, self.torch_dtype),
+            "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
+        }
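
The dummy inputs double as documentation of the transformer's forward signature: latents `(1, 16, 128, 128)`, T5 embeddings `(1, 128, 4096)`, per-layer Llama3 embeddings `(32, 1, 128, 4096)`, pooled embeddings `(1, 2048)`, and `timesteps`. A standalone sketch of the smoke test this class boils down to; the harness is simplified, not the actual `GGUFSingleFileTesterMixin` internals, and CUDA is assumed:

```python
import torch
from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel

ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
model = HiDreamImageTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
).to("cuda")

# Shapes mirror get_dummy_inputs() above.
inputs = {
    "hidden_states": torch.randn(1, 16, 128, 128, dtype=torch.bfloat16, device="cuda"),
    "encoder_hidden_states_t5": torch.randn(1, 128, 4096, dtype=torch.bfloat16, device="cuda"),
    "encoder_hidden_states_llama3": torch.randn(32, 1, 128, 4096, dtype=torch.bfloat16, device="cuda"),
    "pooled_embeds": torch.randn(1, 2048, dtype=torch.bfloat16, device="cuda"),
    "timesteps": torch.tensor([1], dtype=torch.bfloat16, device="cuda"),
}
with torch.no_grad():
    out = model(**inputs)[0]
assert out.shape[0] == 1  # batch dimension round-trips
```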
