Skip to content

[CI] fix dump_input for str type #18697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/runai_model_streamer_test/test_weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ def test_runai_model_loader():
runai_model_streamer_tensors = {}
hf_safetensors_tensors = {}

for name, tensor in runai_safetensors_weights_iterator(safetensors):
for name, tensor in runai_safetensors_weights_iterator(
safetensors, True):
runai_model_streamer_tensors[name] = tensor

for name, tensor in safetensors_weights_iterator(safetensors):
for name, tensor in safetensors_weights_iterator(safetensors, True):
hf_safetensors_tensors[name] = tensor

assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors)
Expand Down
38 changes: 37 additions & 1 deletion tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

import enum
import json
import logging
import os
import sys
import tempfile
from dataclasses import dataclass
from json.decoder import JSONDecodeError
from tempfile import NamedTemporaryFile
from typing import Any
Expand All @@ -16,6 +17,7 @@
from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger,
enable_trace_function_call, init_logger)
from vllm.logging_utils import NewLineFormatter
from vllm.logging_utils.dump_input import prepare_object_to_dump


def f1(x):
Expand Down Expand Up @@ -216,3 +218,37 @@ def test_custom_logging_config_causes_an_error_if_configure_logging_is_off():
assert other_logger.handlers != root_logger.handlers
assert other_logger.level != root_logger.level
assert other_logger.propagate


def test_prepare_object_to_dump():
str_obj = 'str'
assert prepare_object_to_dump(str_obj) == "'str'"

list_obj = [1, 2, 3]
assert prepare_object_to_dump(list_obj) == '[1, 2, 3]'

dict_obj = {'a': 1, 'b': 'b'}
assert prepare_object_to_dump(dict_obj) in [
"{a: 1, b: 'b'}", "{b: 'b', a: 1}"
]

set_obj = {1, 2, 3}
assert prepare_object_to_dump(set_obj) == '[1, 2, 3]'

tuple_obj = ('a', 'b', 'c')
assert prepare_object_to_dump(tuple_obj) == "['a', 'b', 'c']"

class CustomEnum(enum.Enum):
A = enum.auto()
B = enum.auto()
C = enum.auto()

assert prepare_object_to_dump(CustomEnum.A) == repr(CustomEnum.A)

@dataclass
class CustomClass:
a: int
b: str

assert (prepare_object_to_dump(CustomClass(
1, 'b')) == "CustomClass(a=1, b='b')")
6 changes: 3 additions & 3 deletions vllm/logging_utils/dump_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

def prepare_object_to_dump(obj) -> str:
if isinstance(obj, str):
return "'{obj}'" # Double quotes
return f"'{obj}'" # Double quotes
elif isinstance(obj, dict):
dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \
for k, v in obj.items()})
Expand All @@ -42,9 +42,9 @@ def prepare_object_to_dump(obj) -> str:
return obj.anon_repr()
elif hasattr(obj, '__dict__'):
items = obj.__dict__.items()
dict_str = ','.join([f'{str(k)}={prepare_object_to_dump(v)}' \
dict_str = ', '.join([f'{str(k)}={prepare_object_to_dump(v)}' \
for k, v in items])
return (f"{type(obj).__name__}({dict_str})")
return f"{type(obj).__name__}({dict_str})"
else:
# Hacky way to make sure we can serialize the object in JSON format
try:
Expand Down