-
Notifications
You must be signed in to change notification settings - Fork 6k
[LoRA] Support original format loras for HunyuanVideo #10376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
66fc85e
893b9c0
904e3a4
8be9180
a040c5d
f682d76
63d5e9f
4ac0c12
5fbc59c
95a7e0f
738f50d
23854f2
2cc3683
3b64bd5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa | ||
from .lora_conversion_utils import ( | ||
_convert_bfl_flux_control_lora_to_diffusers, | ||
_convert_hunyuan_video_lora_to_diffusers, | ||
_convert_kohya_flux_lora_to_diffusers, | ||
_convert_non_diffusers_lora_to_diffusers, | ||
_convert_xlabs_flux_lora_to_diffusers, | ||
|
@@ -4007,7 +4008,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): | |
|
||
@classmethod | ||
@validate_hf_hub_args | ||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict | ||
def lora_state_dict( | ||
cls, | ||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], | ||
|
@@ -4018,7 +4018,7 @@ def lora_state_dict( | |
|
||
<Tip warning={true}> | ||
|
||
We support loading A1111 formatted LoRA checkpoints in a limited capacity. | ||
We support loading original format HunyuanVideo LoRA checkpoints. | ||
|
||
This function is experimental and might change in the future. | ||
|
||
|
@@ -4101,6 +4101,10 @@ def lora_state_dict( | |
logger.warning(warn_msg) | ||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} | ||
|
||
is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict) | ||
if is_original_hunyuan_video: | ||
state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) | ||
|
||
return state_dict | ||
|
||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights | ||
|
@@ -4239,10 +4243,9 @@ def save_lora_weights( | |
safe_serialization=safe_serialization, | ||
) | ||
|
||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could leverage the CogVideoX fuse_lora for the "Copy" statement, no? If so, I'd prefer that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not really because we have a hunyuan specific example here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, we follow "Copied from ..." with the same example to play it to our advantage (of maintenance) for the other classes, too. So, let's perhaps maintain that consistency. @stevhliu WDYT about that? |
||
def fuse_lora( | ||
self, | ||
components: List[str] = ["transformer", "text_encoder"], | ||
components: List[str] = ["transformer"], | ||
lora_scale: float = 1.0, | ||
safe_fusing: bool = False, | ||
adapter_names: Optional[List[str]] = None, | ||
|
@@ -4269,14 +4272,16 @@ def fuse_lora( | |
Example: | ||
|
||
```py | ||
from diffusers import DiffusionPipeline | ||
import torch | ||
|
||
pipeline = DiffusionPipeline.from_pretrained( | ||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 | ||
).to("cuda") | ||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") | ||
pipeline.fuse_lora(lora_scale=0.7) | ||
>>> import torch | ||
>>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel | ||
|
||
>>> model_id = "hunyuanvideo-community/HunyuanVideo" | ||
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( | ||
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 | ||
... ) | ||
>>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) | ||
>>> pipe.load_lora_weights("a-r-r-o-w/HunyuanVideo-tuxemons", adapter_name="tuxemons") | ||
>>> pipe.set_adapter("tuxemons", 1.2) | ||
``` | ||
""" | ||
super().fuse_lora( | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -12,9 +12,11 @@ | |||
# See the License for the specific language governing permissions and | ||||
# limitations under the License. | ||||
|
||||
import gc | ||||
import sys | ||||
import unittest | ||||
|
||||
import numpy as np | ||||
import torch | ||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast | ||||
|
||||
|
@@ -26,8 +28,12 @@ | |||
) | ||||
from diffusers.utils.testing_utils import ( | ||||
floats_tensor, | ||||
nightly, | ||||
numpy_cosine_similarity_distance, | ||||
require_peft_backend, | ||||
require_torch_gpu, | ||||
skip_mps, | ||||
slow, | ||||
) | ||||
|
||||
|
||||
|
@@ -182,3 +188,70 @@ def test_simple_inference_with_text_lora_fused(self): | |||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.") | ||||
def test_simple_inference_with_text_lora_save_load(self): | ||||
pass | ||||
|
||||
|
||||
@slow | ||||
@nightly | ||||
@require_torch_gpu | ||||
@require_peft_backend | ||||
# @unittest.skip("We cannot run inference on this model with the current CI hardware") | ||||
# TODO (DN6, sayakpaul): move these tests to a beefier GPU | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can happily go after we add the following two markers:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will put on my todos because seems like Flux also is using the same marker, where I copied from. For the next few days, I have a few other things I'd like to PoC or work on, so will take up the test refactor for this soon otherwise this PR might get stalled more There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not really. We have it in the Flux Control tests, already: diffusers/tests/lora/test_lora_layers_flux.py Line 942 in b94cfd7
Flux LoRA ones will be in after #9845 is merged. Since we already have a test suite for LoRA that uses the big model marker, I think it's fine to utilize that here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be a blocker to do it in separate PR? If not, will revert the copied from changes and proceed to merge as this seems like something folks want without more delay, and I don't really have the bandwidth atm There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fine by me. |
||||
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase): | ||||
"""internal note: The integration slices were obtained on DGX. | ||||
|
||||
torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the | ||||
assertions to pass. | ||||
""" | ||||
|
||||
num_inference_steps = 10 | ||||
seed = 0 | ||||
|
||||
def setUp(self): | ||||
super().setUp() | ||||
|
||||
gc.collect() | ||||
torch.cuda.empty_cache() | ||||
|
||||
model_id = "hunyuanvideo-community/HunyuanVideo" | ||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained( | ||||
model_id, subfolder="transformer", torch_dtype=torch.bfloat16 | ||||
) | ||||
self.pipeline = HunyuanVideoPipeline.from_pretrained( | ||||
model_id, transformer=transformer, torch_dtype=torch.float16 | ||||
) | ||||
|
||||
def tearDown(self): | ||||
super().tearDown() | ||||
|
||||
gc.collect() | ||||
torch.cuda.empty_cache() | ||||
|
||||
def test_original_format_cseti(self): | ||||
self.pipeline.load_lora_weights( | ||||
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors" | ||||
) | ||||
self.pipeline.fuse_lora() | ||||
self.pipeline.unload_lora_weights() | ||||
self.pipeline.vae.enable_tiling() | ||||
self.pipeline.enable_model_cpu_offload() | ||||
|
||||
prompt = "CSETIARCANE. A cat walks on the grass, realistic" | ||||
|
||||
out = self.pipeline( | ||||
prompt=prompt, | ||||
height=320, | ||||
width=512, | ||||
num_frames=9, | ||||
num_inference_steps=self.num_inference_steps, | ||||
output_type="np", | ||||
generator=torch.manual_seed(self.seed), | ||||
).frames[0] | ||||
out = out.flatten() | ||||
out_slice = np.concatenate((out[:8], out[-8:])) | ||||
# fmt: off | ||||
expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]) | ||||
# fmt: on | ||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice) | ||||
|
||||
assert max_diff < 1e-3 |
Uh oh!
There was an error while loading. Please reload this page.