Skip to content

Commit bba59fb

Browse files
sayakpaulDN6
andauthored
[Tests] add: test to check 8bit bnb quantized models work with lora loading. (#10576)
* add: test to check 8bit bnb quantized models work with lora loading. * Update tests/quantization/bnb/test_mixed_int8.py Co-authored-by: Dhruv Nair <[email protected]> --------- Co-authored-by: Dhruv Nair <[email protected]>
1 parent 2432f80 commit bba59fb

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/quantization/bnb/test_mixed_int8.py

+25
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import numpy as np
2020
import pytest
21+
from huggingface_hub import hf_hub_download
2122

2223
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
2324
from diffusers.utils import is_accelerate_version
@@ -30,6 +31,7 @@
3031
numpy_cosine_similarity_distance,
3132
require_accelerate,
3233
require_bitsandbytes_version_greater,
34+
require_peft_version_greater,
3335
require_torch,
3436
require_torch_gpu,
3537
require_transformers_version_greater,
@@ -509,6 +511,29 @@ def test_quality(self):
509511
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
510512
self.assertTrue(max_diff < 1e-3)
511513

514+
@require_peft_version_greater("0.14.0")
515+
def test_lora_loading(self):
516+
self.pipeline_8bit.load_lora_weights(
517+
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
518+
)
519+
self.pipeline_8bit.set_adapters("hyper-sd", adapter_weights=0.125)
520+
521+
output = self.pipeline_8bit(
522+
prompt=self.prompt,
523+
height=256,
524+
width=256,
525+
max_sequence_length=64,
526+
output_type="np",
527+
num_inference_steps=8,
528+
generator=torch.manual_seed(42),
529+
).images
530+
out_slice = output[0, -3:, -3:, -1].flatten()
531+
532+
expected_slice = np.array([0.3916, 0.3916, 0.3887, 0.4243, 0.4155, 0.4233, 0.4570, 0.4531, 0.4248])
533+
534+
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
535+
self.assertTrue(max_diff < 1e-3)
536+
512537

513538
@slow
514539
class BaseBnb8bitSerializationTests(Base8bitTests):

0 commit comments

Comments
 (0)