Skip to content

Commit 6673081

Browse files
yao-matrix and ydshieh
authored
enable 6 granite cases on xpu (#37569)
* enable 6 granite cases on XPU
  Signed-off-by: YAO Matrix <[email protected]>
* make them all pass on A100
  Signed-off-by: N <[email protected]>
* fix style
  Signed-off-by: YAO Matrix <[email protected]>
* update
---------
Signed-off-by: YAO Matrix <[email protected]>
Signed-off-by: N <[email protected]>
Co-authored-by: ydshieh <[email protected]>
1 parent 9167461 commit 6673081

File tree

3 files changed

+104
-24
lines changed

3 files changed

+104
-24
lines changed

tests/models/granite/test_modeling_granite.py

+28-8
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919

2020
from transformers import GraniteConfig, is_torch_available, set_seed
2121
from transformers.testing_utils import (
22+
Expectations,
2223
require_read_token,
2324
require_torch,
24-
require_torch_gpu,
25+
require_torch_accelerator,
2526
slow,
2627
torch_device,
2728
)
@@ -302,7 +303,7 @@ def test_model_rope_scaling(self):
302303
torch.testing.assert_close(yarn_sin_long, original_sin_long)
303304

304305

305-
@require_torch_gpu
306+
@require_torch_accelerator
306307
class GraniteIntegrationTest(unittest.TestCase):
307308
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
308309
# Depending on the hardware we get different logits / generations
@@ -328,15 +329,27 @@ def test_model_3b_logits_bf16(self):
328329
# Expected mean on dim = -1
329330

330331
# fmt: off
331-
EXPECTED_MEAN = torch.tensor([[-1.9798, -3.1626, -2.8062, -2.3777, -2.7091, -2.2338, -2.5924, -2.3974]])
332+
EXPECTED_MEANS = Expectations(
333+
{
334+
("xpu", 3): torch.tensor([[-3.1406, -2.5469, -2.6250, -2.1250, -2.6250, -2.6562, -2.6875, -2.9688]]),
335+
("cuda", 7): torch.tensor([[-1.9798, -3.1626, -2.8062, -2.3777, -2.7091, -2.2338, -2.5924, -2.3974]]),
336+
("cuda", 8): torch.tensor([[-3.1406, -2.5469, -2.6250, -2.1250, -2.6250, -2.6562, -2.6875, -2.9688]]),
337+
}
338+
)
339+
EXPECTED_MEAN = EXPECTED_MEANS.get_expectation()
332340

333-
torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1), rtol=1e-2, atol=1e-2)
341+
torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.mean(-1).float(), rtol=1e-2, atol=1e-2)
334342

335343
# slicing logits[0, 0, 0:15]
336-
EXPECTED_SLICE = torch.tensor([[4.8750, -2.1875, -2.1875, -2.1875, -2.1875, -2.8438, -2.1875, -2.1875,
337-
-2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875]])
344+
EXPECTED_SLICES = Expectations(
345+
{
346+
("xpu", 3): torch.tensor([[2.2031, -5.0625, -5.0625, -5.0625, -5.0625, -0.9180, -5.0625, -5.0625, -5.0625, -5.0625, -5.5312, -2.1719, -1.7891, -0.4922, -2.5469]]),
347+
("cuda", 7): torch.tensor([[4.8750, -2.1875, -2.1875, -2.1875, -2.1875, -2.8438, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875, -2.1875]]),
348+
("cuda", 8): torch.tensor([[2.0938, -5.0312, -5.0312, -5.0312, -5.0312, -1.0469, -5.0312, -5.0312, -5.0312, -5.0312, -5.5625, -2.1875, -1.7891, -0.5820, -2.6250]]),
349+
}
350+
)
351+
EXPECTED_SLICE = EXPECTED_SLICES.get_expectation()
338352
# fmt: on
339-
340353
self.assertTrue(
341354
torch.allclose(
342355
EXPECTED_SLICE.to(torch_device),
@@ -358,6 +371,13 @@ def test_model_3b_logits(self):
358371

359372
# fmt: off
360373
# Expected mean on dim = -1
361-
EXPECTED_MEAN = torch.tensor([[-2.0984, -3.1294, -2.8153, -2.3568, -2.7337, -2.2624, -2.6016, -2.4022]])
374+
EXPECTED_MEANS = Expectations(
375+
{
376+
("xpu", 3): torch.tensor([[-3.2693, -2.5957, -2.6234, -2.1675, -2.6386, -2.6850, -2.7039, -2.9656]]),
377+
("cuda", 7): torch.tensor([[-2.0984, -3.1294, -2.8153, -2.3568, -2.7337, -2.2624, -2.6016, -2.4022]]),
378+
("cuda", 8): torch.tensor([[-3.2934, -2.6019, -2.6258, -2.1691, -2.6394, -2.6876, -2.7032, -2.9688]]),
379+
}
380+
)
381+
EXPECTED_MEAN = EXPECTED_MEANS.get_expectation()
362382

363383
torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, atol=1e-2)

tests/models/granitemoe/test_modeling_granitemoe.py

+38-8
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919

2020
from transformers import AutoTokenizer, GraniteMoeConfig, is_torch_available, set_seed
2121
from transformers.testing_utils import (
22+
Expectations,
2223
require_read_token,
2324
require_torch,
24-
require_torch_gpu,
25+
require_torch_accelerator,
2526
slow,
2627
torch_device,
2728
)
@@ -301,7 +302,7 @@ def test_model_rope_scaling(self):
301302
torch.testing.assert_close(yarn_sin_long, original_sin_long)
302303

303304

304-
@require_torch_gpu
305+
@require_torch_accelerator
305306
class GraniteMoeIntegrationTest(unittest.TestCase):
306307
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
307308
# Depending on the hardware we get different logits / generations
@@ -325,13 +326,26 @@ def test_model_3b_logits(self):
325326

326327
# fmt: off
327328
# Expected mean on dim = -1
328-
EXPECTED_MEAN = torch.tensor([[-2.2122, -1.6632, -2.9269, -2.3344, -2.0143, -3.0146, -2.6839, -2.5610]])
329+
EXPECTED_MEANS = Expectations(
330+
{
331+
("xpu", 3): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]),
332+
("cuda", 7): torch.tensor([[-2.2122, -1.6632, -2.9269, -2.3344, -2.0143, -3.0146, -2.6839, -2.5610]]),
333+
("cuda", 8): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]),
334+
}
335+
)
336+
EXPECTED_MEAN = EXPECTED_MEANS.get_expectation()
329337

330338
torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, atol=1e-2)
331339

332340
# slicing logits[0, 0, 0:15]
333-
EXPECTED_SLICE = torch.tensor([[4.8785, -2.2890, -2.2892, -2.2885, -2.2890, -3.5007, -2.2897, -2.2892,
334-
-2.2895, -2.2891, -2.2887, -2.2882, -2.2889, -2.2898, -2.2892]])
341+
EXPECTED_SLICES = Expectations(
342+
{
343+
("xpu", 3): torch.tensor([[2.5479, -9.2123, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2161, -9.2122, -6.3100, -3.6223, -3.6377, -5.2542, -5.2523]]),
344+
("cuda", 7): torch.tensor([[4.8785, -2.2890, -2.2892, -2.2885, -2.2890, -3.5007, -2.2897, -2.2892, -2.2895, -2.2891, -2.2887, -2.2882, -2.2889, -2.2898, -2.2892]]),
345+
("cuda", 8): torch.tensor([[2.5479, -9.2124, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2162, -9.2122, -6.3101, -3.6224, -3.6377, -5.2542, -5.2524]]),
346+
}
347+
)
348+
EXPECTED_SLICE = EXPECTED_SLICES.get_expectation()
335349
# fmt: on
336350

337351
self.assertTrue(
@@ -346,10 +360,26 @@ def test_model_3b_logits(self):
346360
@slow
347361
def test_model_3b_generation(self):
348362
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
349-
EXPECTED_TEXT_COMPLETION = (
350-
"Simply put, the theory of relativity states that \n$$\n\\frac{d^2x^\\mu}{d\\tau^2} = "
351-
"\\frac{1}{c^2}\\frac{d^2x^\\mu}{dt^2}\n$$\nwhere $x^\\mu$ is a four-vector, $\\tau$ is the proper time"
363+
EXPECTED_TEXT_COMPLETIONS = Expectations(
364+
{
365+
("xpu", 3): (
366+
"Simply put, the theory of relativity states that 1) the speed of light is constant, and 2) the speed of light is the same for all observers.\n\n"
367+
"The first part is easy to understand. The second part is a little more difficult.\n\n"
368+
"The second part of the theory of relativity is a little more difficult to understand.\n"
369+
),
370+
("cuda", 7): (
371+
"Simply put, the theory of relativity states that \n$$\n\\frac{d^2x^\\mu}{d\\tau^2} = "
372+
"\\frac{1}{c^2}\\frac{d^2x^\\mu}{dt^2}\n$$\nwhere $x^\\mu$ is a four-vector, $\\tau$ is the proper time"
373+
),
374+
("cuda", 8): (
375+
"Simply put, the theory of relativity states that 1) the speed of light is constant, and 2) the speed of light is the same for all observers.\n\n"
376+
"The first part is easy to understand. The second part is a little more difficult.\n\n"
377+
"The second part of the theory of relativity is a little more difficult to understand.\n"
378+
),
379+
}
352380
)
381+
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
382+
353383
prompt = "Simply put, the theory of relativity states that "
354384
tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b")
355385
model = GraniteMoeForCausalLM.from_pretrained("ibm/PowerMoE-3b", device_map="auto")

tests/models/granitemoeshared/test_modeling_granitemoeshared.py

+38-8
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919

2020
from transformers import AutoTokenizer, GraniteMoeSharedConfig, is_torch_available, set_seed
2121
from transformers.testing_utils import (
22+
Expectations,
2223
require_read_token,
2324
require_torch,
24-
require_torch_gpu,
25+
require_torch_accelerator,
2526
slow,
2627
torch_device,
2728
)
@@ -304,7 +305,7 @@ def test_model_rope_scaling(self):
304305
torch.testing.assert_close(yarn_sin_long, original_sin_long)
305306

306307

307-
@require_torch_gpu
308+
@require_torch_accelerator
308309
class GraniteMoeSharedIntegrationTest(unittest.TestCase):
309310
# This variable is used to determine which CUDA device are we using for our runners (A10 or T4)
310311
# Depending on the hardware we get different logits / generations
@@ -328,13 +329,26 @@ def test_model_3b_logits(self):
328329

329330
# fmt: off
330331
# Expected mean on dim = -1
331-
EXPECTED_MEAN = torch.tensor([[-2.2122, -1.6632, -2.9269, -2.3344, -2.0143, -3.0146, -2.6839, -2.5610]])
332+
EXPECTED_MEANS = Expectations(
333+
{
334+
("xpu", 3): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]),
335+
("cuda", 7): torch.tensor([[-2.2122, -1.6632, -2.9269, -2.3344, -2.0143, -3.0146, -2.6839, -2.5610]]),
336+
("cuda", 8): torch.tensor([[-4.4005, -3.6689, -3.6187, -2.8308, -3.9871, -3.1001, -2.8738, -2.8063]]),
337+
}
338+
)
332339

340+
EXPECTED_MEAN = EXPECTED_MEANS.get_expectation()
333341
torch.testing.assert_close(EXPECTED_MEAN.to(torch_device), out.logits.float().mean(-1), rtol=1e-2, atol=1e-2)
334342

335343
# slicing logits[0, 0, 0:15]
336-
EXPECTED_SLICE = torch.tensor([[4.8785, -2.2890, -2.2892, -2.2885, -2.2890, -3.5007, -2.2897, -2.2892,
337-
-2.2895, -2.2891, -2.2887, -2.2882, -2.2889, -2.2898, -2.2892]])
344+
EXPECTED_SLICES = Expectations(
345+
{
346+
("xpu", 3): torch.tensor([[2.5479, -9.2123, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2161, -9.2122, -6.3100, -3.6223, -3.6377, -5.2542, -5.2523]]),
347+
("cuda", 7): torch.tensor([[4.8785, -2.2890, -2.2892, -2.2885, -2.2890, -3.5007, -2.2897, -2.2892, -2.2895, -2.2891, -2.2887, -2.2882, -2.2889, -2.2898, -2.2892]]),
348+
("cuda", 8): torch.tensor([[2.5479, -9.2123, -9.2121, -9.2175, -9.2122, -1.5024, -9.2121, -9.2122, -9.2161, -9.2122, -6.3100, -3.6223, -3.6377, -5.2542, -5.2523]]),
349+
}
350+
)
351+
EXPECTED_SLICE = EXPECTED_SLICES.get_expectation()
338352
# fmt: on
339353

340354
self.assertTrue(
@@ -349,10 +363,26 @@ def test_model_3b_logits(self):
349363
@slow
350364
def test_model_3b_generation(self):
351365
# ground truth text generated with dola_layers="low", repetition_penalty=1.2
352-
EXPECTED_TEXT_COMPLETION = (
353-
"Simply put, the theory of relativity states that \n$$\n\\frac{d^2x^\\mu}{d\\tau^2} = "
354-
"\\frac{1}{c^2}\\frac{d^2x^\\mu}{dt^2}\n$$\nwhere $x^\\mu$ is a four-vector, $\\tau$ is the proper time"
366+
EXPECTED_TEXT_COMPLETIONS = Expectations(
367+
{
368+
("xpu", 3): (
369+
"Simply put, the theory of relativity states that 1) the speed of light is constant, and 2) the speed of light is the same for all observers.\n\n"
370+
"The first part is easy to understand. The second part is a little more difficult.\n\n"
371+
"The second part of the theory of relativity is a little more difficult to understand.\n"
372+
),
373+
("cuda", 7): (
374+
"Simply put, the theory of relativity states that \n$$\n\\frac{d^2x^\\mu}{d\\tau^2} = "
375+
"\\frac{1}{c^2}\\frac{d^2x^\\mu}{dt^2}\n$$\nwhere $x^\\mu$ is a four-vector, $\\tau$ is the proper time"
376+
),
377+
("cuda", 8): (
378+
"Simply put, the theory of relativity states that 1) the speed of light is constant, and 2) the speed of light is the same for all observers.\n\n"
379+
"The first part is easy to understand. The second part is a little more difficult.\n\n"
380+
"The second part of the theory of relativity is a little more difficult to understand.\n"
381+
),
382+
}
355383
)
384+
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
385+
356386
prompt = "Simply put, the theory of relativity states that "
357387
tokenizer = AutoTokenizer.from_pretrained("ibm/PowerMoE-3b")
358388
model = GraniteMoeSharedForCausalLM.from_pretrained("ibm/PowerMoE-3b", device_map="auto")

0 commit comments

Comments (0)