
Commit 526858c

a-r-r-o-w authored and sayakpaul committed
[LoRA] Support original format loras for HunyuanVideo (#10376)
* update
* fix make copies
* update
* add relevant markers to the integration test suite.
* add copied.
* fix-copies
* temporarily add print.
* directly place on CUDA as CPU isn't that big on the CI.
* fixes to fuse_lora, aryan was right.
* fixes

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent b3d2dd3 commit 526858c

File tree

3 files changed: +256 −6 lines changed
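
In practice, this change lets HunyuanVideoPipeline load LoRAs that were saved in the original HunyuanVideo key format. Below is a minimal usage sketch, assuming a CUDA GPU with enough memory; the repository, weight name, and generation settings mirror the integration test added in this commit and are only illustrative.

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
).to("cuda")

# An original-format checkpoint (keys like "double_blocks.*.img_attn_qkv.*") now loads directly.
pipe.load_lora_weights(
    "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
)
pipe.vae.enable_tiling()

video = pipe(
    prompt="CSETIARCANE. A cat walks on the grass, realistic",
    height=320,
    width=512,
    num_frames=9,
    num_inference_steps=10,
    generator=torch.manual_seed(0),
).frames[0]
export_to_video(video, "output.mp4", fps=15)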

Diff for: src/diffusers/loaders/lora_conversion_utils.py

+175
@@ -973,3 +973,178 @@ def swap_scale_shift(weight):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

     return converted_state_dict
+
+
+def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
+    def remap_norm_scale_shift_(key, state_dict):
+        weight = state_dict.pop(key)
+        shift, scale = weight.chunk(2, dim=0)
+        new_weight = torch.cat([scale, shift], dim=0)
+        state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+    def remap_txt_in_(key, state_dict):
+        def rename_key(key):
+            new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+            new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+            new_key = new_key.replace("txt_in", "context_embedder")
+            new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+            new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+            new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+            new_key = new_key.replace("mlp", "ff")
+            return new_key
+
+        if "self_attn_qkv" in key:
+            weight = state_dict.pop(key)
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+            state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+        else:
+            state_dict[rename_key(key)] = state_dict.pop(key)
+
+    def remap_img_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+            state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+            state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+    def remap_txt_attn_qkv_(key, state_dict):
+        weight = state_dict.pop(key)
+        if "lora_A" in key:
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
+        else:
+            to_q, to_k, to_v = weight.chunk(3, dim=0)
+            state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+            state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+            state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+    def remap_single_transformer_blocks_(key, state_dict):
+        hidden_size = 3072
+
+        if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
+            linear1_weight = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
+                state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+                q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.weight"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
+                state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
+                state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
+                state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
+
+        elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
+            linear1_bias = state_dict.pop(key)
+            if "lora_A" in key:
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_A.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
+                state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
+            else:
+                split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+                q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+                new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+                    ".linear1.lora_B.bias"
+                )
+                state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
+                state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
+                state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
+                state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
+
+        else:
+            new_key = key.replace("single_blocks", "single_transformer_blocks")
+            new_key = new_key.replace("linear2", "proj_out")
+            new_key = new_key.replace("q_norm", "attn.norm_q")
+            new_key = new_key.replace("k_norm", "attn.norm_k")
+            state_dict[new_key] = state_dict.pop(key)
+
+    TRANSFORMER_KEYS_RENAME_DICT = {
+        "img_in": "x_embedder",
+        "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+        "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+        "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+        "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+        "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+        "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+        "double_blocks": "transformer_blocks",
+        "img_attn_q_norm": "attn.norm_q",
+        "img_attn_k_norm": "attn.norm_k",
+        "img_attn_proj": "attn.to_out.0",
+        "txt_attn_q_norm": "attn.norm_added_q",
+        "txt_attn_k_norm": "attn.norm_added_k",
+        "txt_attn_proj": "attn.to_add_out",
+        "img_mod.linear": "norm1.linear",
+        "img_norm1": "norm1.norm",
+        "img_norm2": "norm2",
+        "img_mlp": "ff",
+        "txt_mod.linear": "norm1_context.linear",
+        "txt_norm1": "norm1.norm",
+        "txt_norm2": "norm2_context",
+        "txt_mlp": "ff_context",
+        "self_attn_proj": "attn.to_out.0",
+        "modulation.linear": "norm.linear",
+        "pre_norm": "norm.norm",
+        "final_layer.norm_final": "norm_out.norm",
+        "final_layer.linear": "proj_out",
+        "fc1": "net.0.proj",
+        "fc2": "net.2",
+        "input_embedder": "proj_in",
+    }
+
+    TRANSFORMER_SPECIAL_KEYS_REMAP = {
+        "txt_in": remap_txt_in_,
+        "img_attn_qkv": remap_img_attn_qkv_,
+        "txt_attn_qkv": remap_txt_attn_qkv_,
+        "single_blocks": remap_single_transformer_blocks_,
+        "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    }
+
+    # Some folks attempt to make their state dict compatible with diffusers by adding "transformer." prefix to all keys
+    # and use their custom code. To make sure both "original" and "attempted diffusers" loras work as expected, we make
+    # sure that both follow the same initial format by stripping off the "transformer." prefix.
+    for key in list(converted_state_dict.keys()):
+        if key.startswith("transformer."):
+            converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+        if key.startswith("diffusion_model."):
+            converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+    # Rename and remap the state dict keys
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    # Add back the "transformer." prefix
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
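
To make the remapping above concrete, here is a small sketch (not part of the commit) that runs the converter on a toy original-format LoRA for one double-stream attention block. The hidden size of 3072 matches the constant used in the function; the rank of 4 and the "diffusion_model." prefix are arbitrary choices for illustration, and _convert_hunyuan_video_lora_to_diffusers is a private helper whose import path may change between releases.

import torch
from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers

rank, hidden_size = 4, 3072
original = {
    # lora_A only projects the input down to the rank dimension, so it is reused as-is for to_q/to_k/to_v.
    "diffusion_model.double_blocks.0.img_attn_qkv.lora_A.weight": torch.randn(rank, hidden_size),
    # lora_B produces the fused QKV output, so it is chunked into three equal parts.
    "diffusion_model.double_blocks.0.img_attn_qkv.lora_B.weight": torch.randn(3 * hidden_size, rank),
}

converted = _convert_hunyuan_video_lora_to_diffusers(original)
for key, value in sorted(converted.items()):
    print(key, tuple(value.shape))
# transformer.transformer_blocks.0.attn.to_k.lora_A.weight (4, 3072)
# transformer.transformer_blocks.0.attn.to_k.lora_B.weight (3072, 4)
# transformer.transformer_blocks.0.attn.to_q.lora_A.weight (4, 3072)
# transformer.transformer_blocks.0.attn.to_q.lora_B.weight (3072, 4)
# transformer.transformer_blocks.0.attn.to_v.lora_A.weight (4, 3072)
# transformer.transformer_blocks.0.attn.to_v.lora_B.weight (3072, 4)

The same split-or-share rule is what the single_blocks handler applies to the fused linear1 projection, with the extra MLP slice carved off the end of lora_B.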

Diff for: src/diffusers/loaders/lora_pipeline.py

+8 −6

@@ -36,6 +36,7 @@
 from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
@@ -4007,7 +4008,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4018,7 +4018,7 @@ def lora_state_dict(

         <Tip warning={true}>

-        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+        We support loading original format HunyuanVideo LoRA checkpoints.

         This function is experimental and might change in the future.

@@ -4101,6 +4101,10 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

+        is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+        if is_original_hunyuan_video:
+            state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
         return state_dict

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
@@ -4239,10 +4243,9 @@ def save_lora_weights(
             safe_serialization=safe_serialization,
         )

-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
     def fuse_lora(
         self,
-        components: List[str] = ["transformer", "text_encoder"],
+        components: List[str] = ["transformer"],
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
         adapter_names: Optional[List[str]] = None,
@@ -4283,8 +4286,7 @@ def fuse_lora(
             components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
         )

-    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
-    def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
         r"""
         Reverses the effect of
         [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
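
The loader detects original-format checkpoints by looking for the fused-attention key name ("img_attn_qkv"), so conversion only triggers when it is actually needed. A quick way to see the effect is to call the classmethod directly and inspect the returned keys; this sketch assumes the checkpoint from the integration test below is in the original format and requires no GPU.

from diffusers import HunyuanVideoPipeline

# lora_state_dict is a classmethod, so no pipeline instance is needed to inspect a checkpoint.
state_dict = HunyuanVideoPipeline.lora_state_dict(
    "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1",
    weight_name="csetiarcane-nfjinx-v1-6000.safetensors",
)

# With this change the returned keys already follow the diffusers naming scheme, e.g.
# "transformer.transformer_blocks.0.attn.to_q.lora_A.weight" rather than
# "double_blocks.0.img_attn_qkv.lora_A.weight".
print(sorted(state_dict)[:5])

Note also that fuse_lora() and unfuse_lora() now default to components=["transformer"], since text encoder LoRA is not supported in HunyuanVideo.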

Diff for: tests/lora/test_lora_layers_hunyuanvideo.py

+73
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import sys
 import unittest

+import numpy as np
+import pytest
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast

@@ -26,7 +29,11 @@
 )
 from diffusers.utils.testing_utils import (
     floats_tensor,
+    nightly,
+    numpy_cosine_similarity_distance,
+    require_big_gpu_with_torch_cuda,
     require_peft_backend,
+    require_torch_gpu,
     skip_mps,
 )

@@ -182,3 +189,69 @@ def test_simple_inference_with_text_lora_fused(self):
     @unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+
+@nightly
+@require_torch_gpu
+@require_peft_backend
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
+class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
+    """internal note: The integration slices were obtained on DGX.
+
+    torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
+    assertions to pass.
+    """
+
+    num_inference_steps = 10
+    seed = 0
+
+    def setUp(self):
+        super().setUp()
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        model_id = "hunyuanvideo-community/HunyuanVideo"
+        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+            model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+        )
+        self.pipeline = HunyuanVideoPipeline.from_pretrained(
+            model_id, transformer=transformer, torch_dtype=torch.float16
+        ).to("cuda")
+
+    def tearDown(self):
+        super().tearDown()
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_original_format_cseti(self):
+        self.pipeline.load_lora_weights(
+            "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
+        )
+        self.pipeline.fuse_lora()
+        self.pipeline.unload_lora_weights()
+        self.pipeline.vae.enable_tiling()
+
+        prompt = "CSETIARCANE. A cat walks on the grass, realistic"
+
+        out = self.pipeline(
+            prompt=prompt,
+            height=320,
+            width=512,
+            num_frames=9,
+            num_inference_steps=self.num_inference_steps,
+            output_type="np",
+            generator=torch.manual_seed(self.seed),
+        ).frames[0]
+        out = out.flatten()
+        out_slice = np.concatenate((out[:8], out[-8:]))
+
+        # fmt: off
+        expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
+        # fmt: on
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+        assert max_diff < 1e-3
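
The assertion compares a 16-value slice of the output using numpy_cosine_similarity_distance rather than an element-wise tolerance, so it is sensitive to the pattern of values more than to small uniform shifts. A minimal sketch of that style of check, assuming the helper computes one minus the cosine similarity of the flattened inputs (the actual diffusers utility may differ in detail):

import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cos(theta) between the two flattened vectors; 0.0 means identical direction.
    a, b = a.flatten().astype(np.float64), b.flatten().astype(np.float64)
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

expected = np.array([0.1013, 0.1924, 0.0078, 0.1021])
observed = np.array([0.1015, 0.1920, 0.0080, 0.1019])
assert cosine_similarity_distance(expected, observed) < 1e-3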

0 commit comments
