Skip to content

Commit ff334ca

Browse files
authored
Update deprecated type hinting in vllm/profiler (vllm-project#18057)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 6223dd8 commit ff334ca

File tree

3 files changed

+23
-24
lines changed

3 files changed

+23
-24
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ exclude = [
8484
"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
8585
"vllm/platforms/**/*.py" = ["UP006", "UP035"]
8686
"vllm/plugins/**/*.py" = ["UP006", "UP035"]
87-
"vllm/profiler/**/*.py" = ["UP006", "UP035"]
8887
"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"]
8988
"vllm/spec_decode/**/*.py" = ["UP006", "UP035"]
9089
"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"]

vllm/profiler/layerwise_profile.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import copy
44
from collections import defaultdict
55
from dataclasses import asdict, dataclass, field
6-
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union
6+
from typing import Any, Callable, Optional, TypeAlias, Union
77

88
import pandas as pd
99
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
@@ -20,7 +20,7 @@
2020
class _ModuleTreeNode:
2121
event: _ProfilerEvent
2222
parent: Optional['_ModuleTreeNode'] = None
23-
children: List['_ModuleTreeNode'] = field(default_factory=list)
23+
children: list['_ModuleTreeNode'] = field(default_factory=list)
2424
trace: str = ""
2525

2626
@property
@@ -60,19 +60,19 @@ class ModelStatsEntry:
6060
@dataclass
6161
class _StatsTreeNode:
6262
entry: StatsEntry
63-
children: List[StatsEntry]
63+
children: list[StatsEntry]
6464
parent: Optional[StatsEntry]
6565

6666

6767
@dataclass
6868
class LayerwiseProfileResults(profile):
6969
_kineto_results: _ProfilerResult
70-
_kineto_event_correlation_map: Dict[int,
71-
List[_KinetoEvent]] = field(init=False)
72-
_event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False)
73-
_module_tree: List[_ModuleTreeNode] = field(init=False)
74-
_model_stats_tree: List[_StatsTreeNode] = field(init=False)
75-
_summary_stats_tree: List[_StatsTreeNode] = field(init=False)
70+
_kineto_event_correlation_map: dict[int,
71+
list[_KinetoEvent]] = field(init=False)
72+
_event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False)
73+
_module_tree: list[_ModuleTreeNode] = field(init=False)
74+
_model_stats_tree: list[_StatsTreeNode] = field(init=False)
75+
_summary_stats_tree: list[_StatsTreeNode] = field(init=False)
7676

7777
# profile metadata
7878
num_running_seqs: Optional[int] = None
@@ -82,7 +82,7 @@ def __post_init__(self):
8282
self._build_module_tree()
8383
self._build_stats_trees()
8484

85-
def print_model_table(self, column_widths: Dict[str, int] = None):
85+
def print_model_table(self, column_widths: dict[str, int] = None):
8686
_column_widths = dict(name=60,
8787
cpu_time_us=12,
8888
cuda_time_us=12,
@@ -100,7 +100,7 @@ def print_model_table(self, column_widths: Dict[str, int] = None):
100100
filtered_model_table,
101101
indent_style=lambda indent: "|" + "-" * indent + " "))
102102

103-
def print_summary_table(self, column_widths: Dict[str, int] = None):
103+
def print_summary_table(self, column_widths: dict[str, int] = None):
104104
_column_widths = dict(name=80,
105105
cuda_time_us=12,
106106
pct_cuda_time=12,
@@ -142,7 +142,7 @@ def convert_stats_to_dict(self) -> dict[str, Any]:
142142
}
143143

144144
@staticmethod
145-
def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int,
145+
def _indent_row_names_based_on_depth(depths_rows: list[tuple[int,
146146
StatsEntry]],
147147
indent_style: Union[Callable[[int],
148148
str],
@@ -229,7 +229,7 @@ def _total_cuda_time(self):
229229
[self._cumulative_cuda_time(root) for root in self._module_tree])
230230

231231
def _build_stats_trees(self):
232-
summary_dict: Dict[str, _StatsTreeNode] = {}
232+
summary_dict: dict[str, _StatsTreeNode] = {}
233233
total_cuda_time = self._total_cuda_time()
234234

235235
def pct_cuda_time(cuda_time_us):
@@ -238,7 +238,7 @@ def pct_cuda_time(cuda_time_us):
238238
def build_summary_stats_tree_df(
239239
node: _ModuleTreeNode,
240240
parent: Optional[_StatsTreeNode] = None,
241-
summary_trace: Tuple[str] = ()):
241+
summary_trace: tuple[str] = ()):
242242

243243
if event_has_module(node.event):
244244
name = event_module_repr(node.event)
@@ -313,8 +313,8 @@ def build_model_stats_tree_df(node: _ModuleTreeNode,
313313
self._model_stats_tree.append(build_model_stats_tree_df(root))
314314

315315
def _flatten_stats_tree(
316-
self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]:
317-
entries: List[Tuple[int, StatsEntry]] = []
316+
self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]:
317+
entries: list[tuple[int, StatsEntry]] = []
318318

319319
def df_traversal(node: _StatsTreeNode, depth=0):
320320
entries.append((depth, node.entry))
@@ -327,10 +327,10 @@ def df_traversal(node: _StatsTreeNode, depth=0):
327327
return entries
328328

329329
def _convert_stats_tree_to_dict(self,
330-
tree: List[_StatsTreeNode]) -> List[Dict]:
331-
root_dicts: List[Dict] = []
330+
tree: list[_StatsTreeNode]) -> list[dict]:
331+
root_dicts: list[dict] = []
332332

333-
def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]):
333+
def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]):
334334
curr_json_list.append({
335335
"entry": asdict(node.entry),
336336
"children": []

vllm/profiler/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import dataclasses
4-
from typing import Callable, Dict, List, Type, Union
4+
from typing import Callable, Union
55

66
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
77

@@ -30,14 +30,14 @@ def trim_string_back(string, width):
3030

3131
class TablePrinter:
3232

33-
def __init__(self, row_cls: Type[dataclasses.dataclass],
34-
column_widths: Dict[str, int]):
33+
def __init__(self, row_cls: type[dataclasses.dataclass],
34+
column_widths: dict[str, int]):
3535
self.row_cls = row_cls
3636
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
3737
self.column_widths = column_widths
3838
assert set(self.column_widths.keys()) == set(self.fieldnames)
3939

40-
def print_table(self, rows: List[dataclasses.dataclass]):
40+
def print_table(self, rows: list[dataclasses.dataclass]):
4141
self._print_header()
4242
self._print_line()
4343
for row in rows:

0 commit comments

Comments
 (0)