
Commit 9167461

enable mllama cases on xpu (#37644)
* enable mllama testing on xpu

Signed-off-by: YAO Matrix <[email protected]>

* more mllama cases enabling

Signed-off-by: YAO Matrix <[email protected]>

* make cases pass on A100

Signed-off-by: N <[email protected]>

---------

Signed-off-by: YAO Matrix <[email protected]>
Signed-off-by: N <[email protected]>
1 parent de182ba commit 9167461
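
The commit applies one pattern throughout the test file: expected token ids, decoded strings, and logits differ between accelerator backends and GPU generations, so each hardcoded value is replaced with an Expectations mapping keyed by (device type, major hardware revision) and resolved at runtime. A minimal usage sketch in the style of the diff below (how get_expectation() selects the matching key is an implementation detail of transformers.testing_utils and is not shown here):

    import torch
    from transformers.testing_utils import Expectations

    expected_logits_all = Expectations(
        {
            ("xpu", 3): torch.tensor([9.1562, 8.9141, 5.0664, 1.6855, 3.2324]),
            ("cuda", 7): torch.tensor([8.3594, 7.7148, 4.7266, 0.7803, 3.1504]),
            ("cuda", 8): torch.tensor([9.0703, 8.8750, 5.0781, 1.6279, 3.2207]),
        }
    )
    # Picks the entry matching the current accelerator, e.g. ("cuda", 8) on an A100.
    expected_logits = expected_logits_all.get_expectation()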

File tree

1 file changed: +69 -16 lines changed


tests/models/mllama/test_modeling_mllama.py

+69-16
@@ -31,11 +31,12 @@
 from transformers.cache_utils import Cache
 from transformers.models.mllama.configuration_mllama import MllamaTextConfig
 from transformers.testing_utils import (
+    Expectations,
     cleanup,
     require_bitsandbytes,
     require_read_token,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
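
The import swap above replaces the CUDA-only require_torch_gpu gate with require_torch_accelerator so the same integration tests can also run on XPU. As a rough, hypothetical sketch of what such a device-agnostic skip decorator can look like (not the actual transformers implementation, which lives in transformers.testing_utils):

    import unittest

    import torch

    def require_torch_accelerator_sketch(test_case):
        # Hypothetical stand-in: run the test when CUDA or XPU is available, skip otherwise.
        has_accelerator = torch.cuda.is_available() or (
            hasattr(torch, "xpu") and torch.xpu.is_available()
        )
        return unittest.skipUnless(has_accelerator, "test requires a torch accelerator")(test_case)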
@@ -524,7 +525,7 @@ def tearDown(self):
         cleanup(torch_device, gc_collect=True)

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_bitsandbytes
     @require_read_token
     def test_11b_model_integration_generate(self):
@@ -537,9 +538,18 @@ def test_11b_model_integration_generate(self):

         inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch_device)

+        input_ids = inputs["input_ids"]
+
         # Check inputs ids
-        expected_input_ids = torch.tensor([[128256, 128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342, 369, 420, 832]], device=torch_device)  # fmt: skip
-        self.assertTrue(torch.equal(inputs["input_ids"], expected_input_ids))
+        expected_input_ids_all = Expectations(
+            {
+                ("xpu", 3): torch.tensor([[128000, 128256, 128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342, 369, 420, 832]], device=torch_device),
+                ("cuda", 7): torch.tensor([[128256, 128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342, 369, 420, 832]], device=torch_device),
+                ("cuda", 8): torch.tensor([[128000, 128256, 128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342, 369, 420, 832]], device=torch_device),
+            }
+        )  # fmt: skip
+        expected_input_ids = expected_input_ids_all.get_expectation()
+        self.assertTrue(torch.equal(input_ids, expected_input_ids))

         # Load model in 4 bit
         quantization_config = BitsAndBytesConfig(load_in_4bit=True)
@@ -551,7 +561,14 @@ def test_11b_model_integration_generate(self):
         output = model.generate(**inputs, do_sample=False, max_new_tokens=25)

         decoded_output = processor.decode(output[0], skip_special_tokens=True)
-        expected_output = "If I had to write a haiku for this one, it would be:.\\nI'm not a poet.\\nBut I'm a photographer.\\nAnd I'm a"  # fmt: skip
+        expected_outputs = Expectations(
+            {
+                ("xpu", 3): "If I had to write a haiku for this one, it would be:.\\nA dock on a lake.\\nA mountain in the distance.\\nA long exposure.",
+                ("cuda", 7): "If I had to write a haiku for this one, it would be:.\\nI'm not a poet.\\nBut I'm a photographer.\\nAnd I'm a",
+                ("cuda", 8): "If I had to write a haiku for this one, it would be:.\\nA dock on a lake.\\nA mountain in the distance.\\nA long exposure.",
+            }
+        )  # fmt: skip
+        expected_output = expected_outputs.get_expectation()

         self.assertEqual(
             decoded_output,
@@ -560,18 +577,26 @@ def test_11b_model_integration_generate(self):
         )

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_bitsandbytes
     @require_read_token
     def test_11b_model_integration_generate_text_only(self):
         # Prepare inputs
         processor = AutoProcessor.from_pretrained(self.base_model_checkpoint)
         prompt = "If I had to write a haiku"
         inputs = processor(text=prompt, return_tensors="pt").to(torch_device)
+        input_ids = inputs["input_ids"].cpu().squeeze().tolist()

         # Check inputs ids
-        expected_input_ids = [128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342]
-        self.assertEqual(inputs["input_ids"].cpu().squeeze().tolist(), expected_input_ids)
+        expected_input_ids_all = Expectations(
+            {
+                ("xpu", 3): [128000, 128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342],
+                ("cuda", 7): [128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342],
+                ("cuda", 8): [128000, 128000, 2746, 358, 1047, 311, 3350, 264, 6520, 39342],
+            }
+        )
+        expected_input_ids = expected_input_ids_all.get_expectation()
+        self.assertEqual(input_ids, expected_input_ids)

         # Load model in 4 bit
         quantization_config = BitsAndBytesConfig(load_in_4bit=True)
@@ -583,16 +608,22 @@ def test_11b_model_integration_generate_text_only(self):
         output = model.generate(**inputs, do_sample=False, max_new_tokens=25)

         decoded_output = processor.decode(output[0], skip_special_tokens=True)
-        expected_output = "If I had to write a haiku about my life, I think it would be something like:\n\"Life is a messy stream\nTwists and turns, ups"  # fmt: skip
-
+        expected_outputs = Expectations(
+            {
+                ("xpu", 3): "If I had to write a haiku about my life, I would write:\nLife is a messy tapestry\n Threads of joy and sorrow\nWeft of memories",
+                ("cuda", 7): "If I had to write a haiku about my life, I think it would be something like:\n\"Life is a messy stream\nTwists and turns, ups",
+                ("cuda", 8): "If I had to write a haiku about my life, I would write:\nLife is a messy stream\nRipples of joy and pain\nFlowing, ever",
+            }
+        )  # fmt: skip
+        expected_output = expected_outputs.get_expectation()
         self.assertEqual(
             decoded_output,
             expected_output,
             f"Decoded output: {decoded_output}\nExpected output: {expected_output}",
         )

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_bitsandbytes
     @require_read_token
     def test_11b_model_integration_forward(self):
@@ -616,7 +647,15 @@ def test_11b_model_integration_forward(self):
         output = model(**inputs)

         actual_logits = output.logits[0, -1, :5].cpu()
-        expected_logits = torch.tensor([8.3594, 7.7148, 4.7266, 0.7803, 3.1504])
+        expected_logits_all = Expectations(
+            {
+                ("xpu", 3): torch.tensor([9.1562, 8.9141, 5.0664, 1.6855, 3.2324]),
+                ("cuda", 7): torch.tensor([8.3594, 7.7148, 4.7266, 0.7803, 3.1504]),
+                ("cuda", 8): torch.tensor([9.0703, 8.8750, 5.0781, 1.6279, 3.2207]),
+            }
+        )
+
+        expected_logits = expected_logits_all.get_expectation()
         self.assertTrue(
             torch.allclose(actual_logits, expected_logits, atol=0.1),
             f"Actual logits: {actual_logits}"
@@ -625,7 +664,7 @@ def test_11b_model_integration_forward(self):
         )

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_bitsandbytes
     @require_read_token
     def test_11b_model_integration_batched_generate(self):
@@ -653,7 +692,14 @@ def test_11b_model_integration_batched_generate(self):

         # Check first output
         decoded_output = processor.decode(output[0], skip_special_tokens=True)
-        expected_output = "If I had to write a haiku for this one, it would be:.\\nI'm not a poet.\\nBut I'm a photographer.\\nAnd I'm a"  # fmt: skip
+        expected_outputs = Expectations(
+            {
+                ("xpu", 3): "If I had to write a haiku for this one, it would be:.\\nA dock on a lake.\\nA mountain in the distance.\\nA long exposure.",
+                ("cuda", 7): "If I had to write a haiku for this one, it would be:.\\nI'm not a poet.\\nBut I'm a photographer.\\nAnd I'm a",
+                ("cuda", 8): "If I had to write a haiku for this one, it would be:.\\nA dock on a lake.\\nA mountain in the distance.\\nA long exposure.",
+            }
+        )  # fmt: skip
+        expected_output = expected_outputs.get_expectation()

         self.assertEqual(
             decoded_output,
@@ -663,7 +709,14 @@ def test_11b_model_integration_batched_generate(self):

         # Check second output
         decoded_output = processor.decode(output[1], skip_special_tokens=True)
-        expected_output = "This image shows is a photograph of a stop sign in front of a Chinese archway. The stop sign is red with white letters and is"  # fmt: skip
+        expected_outputs = Expectations(
+            {
+                ("xpu", 3): "This image shows\nI'm not able to provide information on the person in this image. I can give you an idea of what's happening",
+                ("cuda", 7): "This image shows is a photograph of a stop sign in front of a Chinese archway. The stop sign is red with white letters and is",
+                ("cuda", 8): "This image shows\nI'm not able to provide information on the person in this image. I can give you an idea of what's happening",
+            }
+        )  # fmt: skip
+        expected_output = expected_outputs.get_expectation()

         self.assertEqual(
             decoded_output,
@@ -672,7 +725,7 @@ def test_11b_model_integration_batched_generate(self):
         )

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     @require_bitsandbytes
     @require_read_token
     def test_11b_model_integration_multi_image_generate(self):

0 commit comments