Skip to content

Commit 63e3993

Browse files
xendo and Jerzy Zagorski authored
[Frontend] [Neuron] Parse literals out of override-neuron-config (#8959)
Co-authored-by: Jerzy Zagorski <[email protected]>
1 parent f5d72b2 commit 63e3993

File tree

2 files changed

+37
-20
lines changed

2 files changed

+37
-20
lines changed

tests/engine/test_arg_utils.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,42 @@ def test_bad_nullable_kvs(arg):
4242
nullable_kvs(arg)
4343

4444

45-
@pytest.mark.parametrize(("arg", "expected"), [
46-
(None, None),
47-
("{}", {}),
48-
('{"num_crops": 4}', {
49-
"num_crops": 4
50-
}),
51-
('{"foo": {"bar": "baz"}}', {
52-
"foo": {
53-
"bar": "baz"
54-
}
55-
}),
45+
# yapf: disable
46+
@pytest.mark.parametrize(("arg", "expected", "option"), [
47+
(None, None, "mm-processor-kwargs"),
48+
("{}", {}, "mm-processor-kwargs"),
49+
(
50+
'{"num_crops": 4}',
51+
{
52+
"num_crops": 4
53+
},
54+
"mm-processor-kwargs"
55+
),
56+
(
57+
'{"foo": {"bar": "baz"}}',
58+
{
59+
"foo":
60+
{
61+
"bar": "baz"
62+
}
63+
},
64+
"mm-processor-kwargs"
65+
),
66+
(
67+
'{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}',
68+
{
69+
"cast_logits_dtype": "bfloat16",
70+
"sequence_parallel_norm": True,
71+
"sequence_parallel_norm_threshold": 2048,
72+
},
73+
"override-neuron-config"
74+
),
5675
])
57-
def test_mm_processor_kwargs_prompt_parser(arg, expected):
76+
# yapf: enable
77+
def test_composite_arg_parser(arg, expected, option):
5878
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
5979
if arg is None:
6080
args = parser.parse_args([])
6181
else:
62-
args = parser.parse_args(["--mm-processor-kwargs", arg])
63-
assert args.mm_processor_kwargs == expected
82+
args = parser.parse_args([f"--{option}", arg])
83+
assert getattr(args, option.replace("-", "_")) == expected

vllm/engine/arg_utils.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -800,13 +800,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
800800
"lower performance.")
801801
parser.add_argument(
802802
'--override-neuron-config',
803-
type=lambda configs: {
804-
str(key): value
805-
for key, value in
806-
(config.split(':') for config in configs.split(','))
807-
},
803+
type=json.loads,
808804
default=None,
809-
help="override or set neuron device configuration.")
805+
help="Override or set neuron device configuration. "
806+
"e.g. {\"cast_logits_dtype\": \"bfloat16\"}.")
810807

811808
parser.add_argument(
812809
'--scheduling-policy',

0 commit comments

Comments
 (0)