Skip to content

[Misc] Human-readable max-model-len cli arg #16181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from argparse import ArgumentTypeError
from argparse import ArgumentError, ArgumentTypeError

import pytest

Expand Down Expand Up @@ -142,3 +142,39 @@ def test_composite_arg_parser(arg, expected, option):
else:
args = parser.parse_args([f"--{option}", arg])
assert getattr(args, option.replace("-", "_")) == expected


def test_human_readable_model_len():
    """Check how ``--max-model-len`` parses plain and suffixed values."""
    # `exit_on_error` is disabled so the invalid values at the end can be
    # asserted on with `pytest.raises`.
    parser = EngineArgs.add_cli_args(
        FlexibleArgumentParser(exit_on_error=False))

    # Without the flag the attribute is left as None.
    assert parser.parse_args([]).max_model_len is None

    accepted = [
        # Plain integer
        ("1024", 1024),
        # Lowercase suffixes are decimal multipliers
        ("1m", 1_000_000),
        ("10k", 10_000),
        # Uppercase suffixes are binary multipliers
        ("3K", 1024 * 3),
        ("10M", 2**20 * 10),
        # Decimal values are accepted with decimal multipliers...
        ("10.2k", 10200),
        # ...and any fractional remainder is dropped
        ("10.212345k", 10212),
    ]
    for raw, expected in accepted:
        parsed = parser.parse_args(["--max-model-len", raw])
        assert parsed.max_model_len == expected

    # Rejected: unknown suffix, non-numeric text, a bare decimal, and a
    # decimal combined with a binary multiplier.
    rejected = ["1a", "pwd", "10.24", "1.23M"]
    for raw in rejected:
        with pytest.raises(ArgumentError):
            parser.parse_args(["--max-model-len", raw])
50 changes: 48 additions & 2 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import dataclasses
import json
import re
import threading
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Expand Down Expand Up @@ -368,10 +369,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument('--max-model-len',
type=int,
type=human_readable_int,
default=EngineArgs.max_model_len,
help='Model context length. If unspecified, will '
'be automatically derived from the model config.')
'be automatically derived from the model config. '
'Supports k/m/g/K/M/G in human-readable format.\n'
'Examples:\n'
'- 1k → 1000\n'
'- 1K → 1024\n')
parser.add_argument(
'--guided-decoding-backend',
type=str,
Expand Down Expand Up @@ -1739,6 +1744,47 @@ def _warn_or_fallback(feature_name: str) -> bool:
return should_exit


def human_readable_int(value):
    """Parse human-readable integers like '1k', '2M', etc.

    Lowercase suffixes (k/m/g/t) are decimal multipliers (powers of 1000)
    and may be combined with a decimal number; any fractional remainder of
    the result is dropped. Uppercase suffixes (K/M/G/T) are binary
    multipliers (powers of 1024) and only accept whole numbers. A value
    without a suffix is handed to ``int`` unchanged.

    Examples:
    - '1k' -> 1,000
    - '1K' -> 1,024
    - '25.6k' -> 25,600

    Args:
        value: The string to parse; surrounding whitespace is ignored.

    Returns:
        The parsed value as an ``int``.

    Raises:
        argparse.ArgumentTypeError: If a decimal number is combined with a
            binary suffix (e.g. '1.5K').
        ValueError: If the value is neither a suffixed number nor a valid
            plain integer literal (raised by ``int``).
    """
    value = value.strip()
    match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value)
    if match is None:
        # Regular plain number.
        return int(value)

    # Every suffix accepted by the regex above must appear in exactly one
    # of these tables.
    decimal_multiplier = {
        'k': 10**3,
        'm': 10**6,
        'g': 10**9,
        't': 10**12,
    }
    binary_multiplier = {
        'K': 2**10,
        'M': 2**20,
        'G': 2**30,
        'T': 2**40,
    }

    number, suffix = match.groups()
    whole, _, fraction = number.partition('.')

    if suffix in binary_multiplier:
        # Do not allow decimals with binary multipliers
        if fraction:
            raise argparse.ArgumentTypeError("Decimals are not allowed " \
                f"with binary suffixes like {suffix}. Did you mean to use " \
                f"{number}{suffix.lower()} instead?")
        return int(whole) * binary_multiplier[suffix]

    mult = decimal_multiplier[suffix]
    # Use exact integer arithmetic rather than float multiplication:
    # float('1.001') * 1000 is 1000.9999999999999, which would truncate to
    # 1000 instead of 1001. Floor division drops any remainder smaller than
    # one unit, matching the intended truncation.
    scaled_fraction = int(fraction or '0') * mult // 10**len(fraction)
    return int(whole) * mult + scaled_fraction


# These functions are used by sphinx to build the documentation
def _engine_args_parser():
    """Return a new ``FlexibleArgumentParser`` passed through
    ``EngineArgs.add_cli_args``."""
    parser = FlexibleArgumentParser()
    return EngineArgs.add_cli_args(parser)
Expand Down