Skip to content

Commit 0e00ec9

Browse files
hmellorwuisawesome
authored andcommitted
Improve conversion from dataclass configs to argparse arguments (#17303)
Signed-off-by: Harry Mellor <[email protected]>
1 parent d5a5513 commit 0e00ec9

File tree

4 files changed

+245
-154
lines changed

4 files changed

+245
-154
lines changed

tests/engine/test_arg_utils.py

Lines changed: 106 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,118 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import json
34
from argparse import ArgumentError, ArgumentTypeError
5+
from contextlib import nullcontext
6+
from dataclasses import dataclass, field
7+
from typing import Literal, Optional
48

59
import pytest
610

7-
from vllm.config import PoolerConfig
8-
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
11+
from vllm.config import PoolerConfig, config
12+
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
13+
get_type, is_not_builtin, is_type,
14+
nullable_kvs, optional_type)
915
from vllm.utils import FlexibleArgumentParser
1016

1117

18+
@pytest.mark.parametrize(("type", "value", "expected"), [
19+
(int, "42", 42),
20+
(int, "None", None),
21+
(float, "3.14", 3.14),
22+
(float, "None", None),
23+
(str, "Hello World!", "Hello World!"),
24+
(str, "None", None),
25+
(json.loads, '{"foo":1,"bar":2}', {
26+
"foo": 1,
27+
"bar": 2
28+
}),
29+
(json.loads, "foo=1,bar=2", {
30+
"foo": 1,
31+
"bar": 2
32+
}),
33+
(json.loads, "None", None),
34+
])
35+
def test_optional_type(type, value, expected):
36+
optional_type_func = optional_type(type)
37+
context = nullcontext()
38+
if value == "foo=1,bar=2":
39+
context = pytest.warns(DeprecationWarning)
40+
with context:
41+
assert optional_type_func(value) == expected
42+
43+
44+
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
45+
(int, int, True),
46+
(int, float, False),
47+
(list[int], list, True),
48+
(list[int], tuple, False),
49+
(Literal[0, 1], Literal, True),
50+
])
51+
def test_is_type(type_hint, type, expected):
52+
assert is_type(type_hint, type) == expected
53+
54+
55+
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
56+
({float, int}, int, True),
57+
({int, tuple[int]}, int, True),
58+
({int, tuple[int]}, float, False),
59+
({str, Literal["x", "y"]}, Literal, True),
60+
])
61+
def test_contains_type(type_hints, type, expected):
62+
assert contains_type(type_hints, type) == expected
63+
64+
65+
@pytest.mark.parametrize(("type_hints", "type", "expected"), [
66+
({int, float}, int, int),
67+
({int, float}, str, None),
68+
({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
69+
])
70+
def test_get_type(type_hints, type, expected):
71+
assert get_type(type_hints, type) == expected
72+
73+
74+
@config
75+
@dataclass
76+
class DummyConfigClass:
77+
regular_bool: bool = True
78+
"""Regular bool with default True"""
79+
optional_bool: Optional[bool] = None
80+
"""Optional bool with default None"""
81+
optional_literal: Optional[Literal["x", "y"]] = None
82+
"""Optional literal with default None"""
83+
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
84+
"""Tuple with default (1, 2, 3)"""
85+
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
86+
"""Tuple with default (1, 2)"""
87+
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
88+
"""List with default [1, 2, 3]"""
89+
90+
91+
@pytest.mark.parametrize(("type_hint", "expected"), [
92+
(int, False),
93+
(DummyConfigClass, True),
94+
])
95+
def test_is_not_builtin(type_hint, expected):
96+
assert is_not_builtin(type_hint) == expected
97+
98+
99+
def test_get_kwargs():
100+
kwargs = get_kwargs(DummyConfigClass)
101+
print(kwargs)
102+
103+
# bools should not have their type set
104+
assert kwargs["regular_bool"].get("type") is None
105+
assert kwargs["optional_bool"].get("type") is None
106+
# optional literals should have None as a choice
107+
assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
108+
# tuples should have the correct nargs
109+
assert kwargs["tuple_n"]["nargs"] == "+"
110+
assert kwargs["tuple_2"]["nargs"] == 2
111+
# lists should work
112+
assert kwargs["list_n"]["type"] is int
113+
assert kwargs["list_n"]["nargs"] == "+"
114+
115+
12116
@pytest.mark.parametrize(("arg", "expected"), [
13117
(None, dict()),
14118
("image=16", {

0 commit comments

Comments
 (0)