Skip to content

Commit a67110b

Browse files
NickLucche and DarkLight1337
authored and committed
[Misc] Human-readable max-model-len cli arg (vllm-project#16181)
Signed-off-by: NickLucche <[email protected]> Signed-off-by: DarkLight1337 <[email protected]> Co-authored-by: Cyrus Leung <[email protected]>
1 parent 6204b90 commit a67110b

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

tests/engine/test_arg_utils.py

Lines changed: 37 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3-
from argparse import ArgumentTypeError
3+
from argparse import ArgumentError, ArgumentTypeError
44

55
import pytest
66

@@ -142,3 +142,39 @@ def test_composite_arg_parser(arg, expected, option):
142142
else:
143143
args = parser.parse_args([f"--{option}", arg])
144144
assert getattr(args, option.replace("-", "_")) == expected
145+
146+
147+
def test_human_readable_model_len():
    """Check how ``--max-model-len`` parses plain and suffixed values."""
    # `exit_on_error` is disabled so the invalid values at the end of this
    # test surface as exceptions that `pytest.raises` can observe.
    base_parser = FlexibleArgumentParser(exit_on_error=False)
    parser = EngineArgs.add_cli_args(base_parser)

    # Flag omitted entirely.
    assert parser.parse_args([]).max_model_len is None

    accepted = [
        # Plain integer
        ("1024", 1024),
        # Lowercase suffixes
        ("1m", 1_000_000),
        ("10k", 10_000),
        # Uppercase suffixes
        ("3K", 1024 * 3),
        ("10M", 2**20 * 10),
        # Fractional value with a lowercase suffix
        ("10.2k", 10200),
        # ...any remaining fractional part is dropped
        ("10.212345k", 10212),
    ]
    for raw, expected in accepted:
        parsed = parser.parse_args(["--max-model-len", raw])
        assert parsed.max_model_len == expected

    # Invalid inputs, including a fractional value with an uppercase suffix.
    rejected = ["1a", "pwd", "10.24", "1.23M"]
    for raw in rejected:
        with pytest.raises(ArgumentError):
            parser.parse_args(["--max-model-len", raw])

vllm/engine/arg_utils.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import argparse
44
import dataclasses
55
import json
6+
import re
67
import threading
78
from dataclasses import dataclass
89
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
@@ -368,10 +369,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
368369
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
369370
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
370371
parser.add_argument('--max-model-len',
371-
type=int,
372+
type=human_readable_int,
372373
default=EngineArgs.max_model_len,
373374
help='Model context length. If unspecified, will '
374-
'be automatically derived from the model config.')
375+
'be automatically derived from the model config. '
376+
'Supports k/m/g/K/M/G in human-readable format.\n'
377+
'Examples:\n'
378+
'- 1k → 1000\n'
379+
'- 1K → 1024\n')
375380
parser.add_argument(
376381
'--guided-decoding-backend',
377382
type=str,
@@ -1740,6 +1745,47 @@ def _warn_or_fallback(feature_name: str) -> bool:
17401745
return should_exit
17411746

17421747

1748+
def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.

    Lowercase suffixes (k/m/g/t) are decimal multipliers (powers of 1000)
    and may be combined with a fractional number; any fractional remainder
    of the result is truncated. Uppercase suffixes (K/M/G/T) are binary
    multipliers (powers of 1024) and only accept whole numbers. A value
    without a suffix is handed to ``int`` unchanged.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600

    Args:
        value: The raw command-line string.

    Returns:
        The parsed integer.

    Raises:
        argparse.ArgumentTypeError: If a fractional number is combined with
            a binary (uppercase) suffix.
        ValueError: If the value has no recognised suffix and is not a valid
            integer literal.
    """
    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match is None:
        # Regular plain number.
        return int(value)

    # Every suffix accepted by the pattern above has an entry in one of
    # these tables, so a matched value never falls through to `int(value)`.
    decimal_multiplier = {
        'k': 10**3,
        'm': 10**6,
        'g': 10**9,
        't': 10**12,
    }
    binary_multiplier = {
        'K': 2**10,
        'M': 2**20,
        'G': 2**30,
        'T': 2**40,
    }

    number, suffix = match.groups()
    whole, _, fraction = number.partition('.')

    if suffix in decimal_multiplier:
        # Exact integer arithmetic: going through `float` would turn e.g.
        # '1.001k' into 1000.9999999999999 and truncate it to 1000.
        scaled = int(whole + fraction) * decimal_multiplier[suffix]
        return scaled // 10**len(fraction)

    # Do not allow decimals with binary multipliers
    if fraction:
        raise argparse.ArgumentTypeError("Decimals are not allowed " \
            f"with binary suffixes like {suffix}. Did you mean to use " \
            f"{number}{suffix.lower()} instead?")
    return int(whole) * binary_multiplier[suffix]
1787+
1788+
17431789
# These functions are used by sphinx to build the documentation
17441790
def _engine_args_parser():
    """Return a new ``FlexibleArgumentParser`` after passing it through
    ``EngineArgs.add_cli_args``."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)

0 commit comments

Comments
 (0)