|
6 | 6 | import pytest
|
7 | 7 | import torch
|
8 | 8 |
|
| 9 | +from vllm.platforms import current_platform |
9 | 10 | from vllm.utils import make_tensor_with_pad
|
10 | 11 | from vllm.v1.sample.metadata import SamplingMetadata
|
11 | 12 | from vllm.v1.sample.sampler import Sampler
|
12 | 13 |
|
13 | 14 | VOCAB_SIZE = 1024
|
14 | 15 | NUM_OUTPUT_TOKENS = 20
|
15 |
| -CUDA_DEVICES = [ |
16 |
| - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) |
| 16 | +TORCH_DEVICES = [ |
| 17 | + f"{current_platform.device_type}:{i}" |
| 18 | + for i in range(1 if current_platform.get_device_count() == 1 else 2) |
17 | 19 | ]
|
18 | 20 | MAX_NUM_PROMPT_TOKENS = 64
|
19 | 21 |
|
@@ -224,7 +226,7 @@ def _create_weighted_output_token_list(
|
224 | 226 | return output_token_ids, sorted_token_ids_in_output
|
225 | 227 |
|
226 | 228 |
|
227 |
| -@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 229 | +@pytest.mark.parametrize("device", TORCH_DEVICES) |
228 | 230 | @pytest.mark.parametrize("batch_size", [1, 2, 32])
|
229 | 231 | def test_sampler_min_tokens_penalty(device: str, batch_size: int):
|
230 | 232 | """
|
@@ -254,7 +256,7 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
|
254 | 256 | assert logits[batch_idx][token_id] != -float("inf")
|
255 | 257 |
|
256 | 258 |
|
257 |
| -@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 259 | +@pytest.mark.parametrize("device", TORCH_DEVICES) |
258 | 260 | @pytest.mark.parametrize("batch_size", [1, 2, 32])
|
259 | 261 | @pytest.mark.parametrize("presence_penalty", [-2.0, 2.0])
|
260 | 262 | def test_sampler_presence_penalty(device: str, batch_size: int,
|
@@ -299,7 +301,7 @@ def test_sampler_presence_penalty(device: str, batch_size: int,
|
299 | 301 | assert penalized_token_id not in output_token_ids[batch_idx]
|
300 | 302 |
|
301 | 303 |
|
302 |
| -@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 304 | +@pytest.mark.parametrize("device", TORCH_DEVICES) |
303 | 305 | @pytest.mark.parametrize("batch_size", [1, 2, 32])
|
304 | 306 | @pytest.mark.parametrize("frequency_penalty", [-2.0, 2.0])
|
305 | 307 | def test_sampler_frequency_penalty(device: str, batch_size: int,
|
@@ -352,7 +354,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
|
352 | 354 | assert penalized_token_id not in distinct_sorted_token_ids_in_output
|
353 | 355 |
|
354 | 356 |
|
355 |
| -@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 357 | +@pytest.mark.parametrize("device", TORCH_DEVICES) |
356 | 358 | @pytest.mark.parametrize("batch_size", [1, 2, 32])
|
357 | 359 | @pytest.mark.parametrize("repetition_penalty", [0.1, 1.9])
|
358 | 360 | def test_sampler_repetition_penalty(device: str, batch_size: int,
|
@@ -398,7 +400,7 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
|
398 | 400 | or non_penalized_token_id in output_tokens)
|
399 | 401 |
|
400 | 402 |
|
401 |
| -@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 403 | +@pytest.mark.parametrize("device", TORCH_DEVICES) |
402 | 404 | @pytest.mark.parametrize("batch_size", [1, 2, 32])
|
403 | 405 | @pytest.mark.parametrize("min_p", [0.0, 0.1])
|
404 | 406 | def test_sampler_min_p(device: str, batch_size: int, min_p: float):
|
@@ -438,7 +440,7 @@ def test_sampler_min_p(device: str, batch_size: int, min_p: float):
|
438 | 440 | assert logits[batch_idx][token_id] != -float("inf")
|
439 | 441 |
|
440 | 442 |
|
441 |
| -@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 443 | +@pytest.mark.parametrize("device", TORCH_DEVICES) |
442 | 444 | @pytest.mark.parametrize("batch_size", [1, 2, 32])
|
443 | 445 | @pytest.mark.parametrize("bias_value", [-0.1, 1.2])
|
444 | 446 | def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
|
@@ -472,7 +474,7 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
|
472 | 474 | assert logits_for_req[token_id] == pytest.approx(1e-2)
|
473 | 475 |
|
474 | 476 |
|
475 |
| -@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 477 | +@pytest.mark.parametrize("device", TORCH_DEVICES) |
476 | 478 | @pytest.mark.parametrize("batch_size", [1, 2, 32])
|
477 | 479 | @pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
|
478 | 480 | def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
@@ -513,7 +515,7 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
|
513 | 515 | assert logits_for_req[token_id] != -float("inf")
|
514 | 516 |
|
515 | 517 |
|
516 |
| -@pytest.mark.parametrize("device", CUDA_DEVICES) |
| 518 | +@pytest.mark.parametrize("device", TORCH_DEVICES) |
517 | 519 | @pytest.mark.parametrize("batch_size", [1, 2, 32])
|
518 | 520 | @pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
|
519 | 521 | def test_sampler_bad_words(device: str, batch_size: int,
|
|
0 commit comments