|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 |
|
| 3 | +import json |
3 | 4 | from argparse import ArgumentError, ArgumentTypeError
|
| 5 | +from contextlib import nullcontext |
| 6 | +from dataclasses import dataclass, field |
| 7 | +from typing import Literal, Optional |
4 | 8 |
|
5 | 9 | import pytest
|
6 | 10 |
|
7 |
| -from vllm.config import PoolerConfig |
8 |
| -from vllm.engine.arg_utils import EngineArgs, nullable_kvs |
| 11 | +from vllm.config import PoolerConfig, config |
| 12 | +from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, |
| 13 | + get_type, is_not_builtin, is_type, |
| 14 | + nullable_kvs, optional_type) |
9 | 15 | from vllm.utils import FlexibleArgumentParser
|
10 | 16 |
|
11 | 17 |
|
| 18 | +@pytest.mark.parametrize(("type", "value", "expected"), [ |
| 19 | + (int, "42", 42), |
| 20 | + (int, "None", None), |
| 21 | + (float, "3.14", 3.14), |
| 22 | + (float, "None", None), |
| 23 | + (str, "Hello World!", "Hello World!"), |
| 24 | + (str, "None", None), |
| 25 | + (json.loads, '{"foo":1,"bar":2}', { |
| 26 | + "foo": 1, |
| 27 | + "bar": 2 |
| 28 | + }), |
| 29 | + (json.loads, "foo=1,bar=2", { |
| 30 | + "foo": 1, |
| 31 | + "bar": 2 |
| 32 | + }), |
| 33 | + (json.loads, "None", None), |
| 34 | +]) |
| 35 | +def test_optional_type(type, value, expected): |
| 36 | + optional_type_func = optional_type(type) |
| 37 | + context = nullcontext() |
| 38 | + if value == "foo=1,bar=2": |
| 39 | + context = pytest.warns(DeprecationWarning) |
| 40 | + with context: |
| 41 | + assert optional_type_func(value) == expected |
| 42 | + |
| 43 | + |
| 44 | +@pytest.mark.parametrize(("type_hint", "type", "expected"), [ |
| 45 | + (int, int, True), |
| 46 | + (int, float, False), |
| 47 | + (list[int], list, True), |
| 48 | + (list[int], tuple, False), |
| 49 | + (Literal[0, 1], Literal, True), |
| 50 | +]) |
| 51 | +def test_is_type(type_hint, type, expected): |
| 52 | + assert is_type(type_hint, type) == expected |
| 53 | + |
| 54 | + |
| 55 | +@pytest.mark.parametrize(("type_hints", "type", "expected"), [ |
| 56 | + ({float, int}, int, True), |
| 57 | + ({int, tuple[int]}, int, True), |
| 58 | + ({int, tuple[int]}, float, False), |
| 59 | + ({str, Literal["x", "y"]}, Literal, True), |
| 60 | +]) |
| 61 | +def test_contains_type(type_hints, type, expected): |
| 62 | + assert contains_type(type_hints, type) == expected |
| 63 | + |
| 64 | + |
| 65 | +@pytest.mark.parametrize(("type_hints", "type", "expected"), [ |
| 66 | + ({int, float}, int, int), |
| 67 | + ({int, float}, str, None), |
| 68 | + ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]), |
| 69 | +]) |
| 70 | +def test_get_type(type_hints, type, expected): |
| 71 | + assert get_type(type_hints, type) == expected |
| 72 | + |
| 73 | + |
| 74 | +@config |
| 75 | +@dataclass |
| 76 | +class DummyConfigClass: |
| 77 | + regular_bool: bool = True |
| 78 | + """Regular bool with default True""" |
| 79 | + optional_bool: Optional[bool] = None |
| 80 | + """Optional bool with default None""" |
| 81 | + optional_literal: Optional[Literal["x", "y"]] = None |
| 82 | + """Optional literal with default None""" |
| 83 | + tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3)) |
| 84 | + """Tuple with default (1, 2, 3)""" |
| 85 | + tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2)) |
| 86 | + """Tuple with default (1, 2)""" |
| 87 | + list_n: list[int] = field(default_factory=lambda: [1, 2, 3]) |
| 88 | + """List with default [1, 2, 3]""" |
| 89 | + |
| 90 | + |
| 91 | +@pytest.mark.parametrize(("type_hint", "expected"), [ |
| 92 | + (int, False), |
| 93 | + (DummyConfigClass, True), |
| 94 | +]) |
| 95 | +def test_is_not_builtin(type_hint, expected): |
| 96 | + assert is_not_builtin(type_hint) == expected |
| 97 | + |
| 98 | + |
| 99 | +def test_get_kwargs(): |
| 100 | + kwargs = get_kwargs(DummyConfigClass) |
| 101 | + print(kwargs) |
| 102 | + |
| 103 | + # bools should not have their type set |
| 104 | + assert kwargs["regular_bool"].get("type") is None |
| 105 | + assert kwargs["optional_bool"].get("type") is None |
| 106 | + # optional literals should have None as a choice |
| 107 | + assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"] |
| 108 | + # tuples should have the correct nargs |
| 109 | + assert kwargs["tuple_n"]["nargs"] == "+" |
| 110 | + assert kwargs["tuple_2"]["nargs"] == 2 |
| 111 | + # lists should work |
| 112 | + assert kwargs["list_n"]["type"] is int |
| 113 | + assert kwargs["list_n"]["nargs"] == "+" |
| 114 | + |
| 115 | + |
12 | 116 | @pytest.mark.parametrize(("arg", "expected"), [
|
13 | 117 | (None, dict()),
|
14 | 118 | ("image=16", {
|
|
0 commit comments