3
3
import copy
4
4
from collections import defaultdict
5
5
from dataclasses import asdict , dataclass , field
6
- from typing import Any , Callable , Dict , List , Optional , Tuple , TypeAlias , Union
6
+ from typing import Any , Callable , Optional , TypeAlias , Union
7
7
8
8
import pandas as pd
9
9
from torch ._C ._autograd import DeviceType , _KinetoEvent , _ProfilerResult
20
20
class _ModuleTreeNode :
21
21
event : _ProfilerEvent
22
22
parent : Optional ['_ModuleTreeNode' ] = None
23
- children : List ['_ModuleTreeNode' ] = field (default_factory = list )
23
+ children : list ['_ModuleTreeNode' ] = field (default_factory = list )
24
24
trace : str = ""
25
25
26
26
@property
@@ -60,19 +60,19 @@ class ModelStatsEntry:
60
60
@dataclass
61
61
class _StatsTreeNode :
62
62
entry : StatsEntry
63
- children : List [StatsEntry ]
63
+ children : list [StatsEntry ]
64
64
parent : Optional [StatsEntry ]
65
65
66
66
67
67
@dataclass
68
68
class LayerwiseProfileResults (profile ):
69
69
_kineto_results : _ProfilerResult
70
- _kineto_event_correlation_map : Dict [int ,
71
- List [_KinetoEvent ]] = field (init = False )
72
- _event_correlation_map : Dict [int , List [FunctionEvent ]] = field (init = False )
73
- _module_tree : List [_ModuleTreeNode ] = field (init = False )
74
- _model_stats_tree : List [_StatsTreeNode ] = field (init = False )
75
- _summary_stats_tree : List [_StatsTreeNode ] = field (init = False )
70
+ _kineto_event_correlation_map : dict [int ,
71
+ list [_KinetoEvent ]] = field (init = False )
72
+ _event_correlation_map : dict [int , list [FunctionEvent ]] = field (init = False )
73
+ _module_tree : list [_ModuleTreeNode ] = field (init = False )
74
+ _model_stats_tree : list [_StatsTreeNode ] = field (init = False )
75
+ _summary_stats_tree : list [_StatsTreeNode ] = field (init = False )
76
76
77
77
# profile metadata
78
78
num_running_seqs : Optional [int ] = None
@@ -82,7 +82,7 @@ def __post_init__(self):
82
82
self ._build_module_tree ()
83
83
self ._build_stats_trees ()
84
84
85
- def print_model_table (self , column_widths : Dict [str , int ] = None ):
85
+ def print_model_table (self , column_widths : dict [str , int ] = None ):
86
86
_column_widths = dict (name = 60 ,
87
87
cpu_time_us = 12 ,
88
88
cuda_time_us = 12 ,
@@ -100,7 +100,7 @@ def print_model_table(self, column_widths: Dict[str, int] = None):
100
100
filtered_model_table ,
101
101
indent_style = lambda indent : "|" + "-" * indent + " " ))
102
102
103
- def print_summary_table (self , column_widths : Dict [str , int ] = None ):
103
+ def print_summary_table (self , column_widths : dict [str , int ] = None ):
104
104
_column_widths = dict (name = 80 ,
105
105
cuda_time_us = 12 ,
106
106
pct_cuda_time = 12 ,
@@ -142,7 +142,7 @@ def convert_stats_to_dict(self) -> dict[str, Any]:
142
142
}
143
143
144
144
@staticmethod
145
- def _indent_row_names_based_on_depth (depths_rows : List [ Tuple [int ,
145
+ def _indent_row_names_based_on_depth (depths_rows : list [ tuple [int ,
146
146
StatsEntry ]],
147
147
indent_style : Union [Callable [[int ],
148
148
str ],
@@ -229,7 +229,7 @@ def _total_cuda_time(self):
229
229
[self ._cumulative_cuda_time (root ) for root in self ._module_tree ])
230
230
231
231
def _build_stats_trees (self ):
232
- summary_dict : Dict [str , _StatsTreeNode ] = {}
232
+ summary_dict : dict [str , _StatsTreeNode ] = {}
233
233
total_cuda_time = self ._total_cuda_time ()
234
234
235
235
def pct_cuda_time (cuda_time_us ):
@@ -238,7 +238,7 @@ def pct_cuda_time(cuda_time_us):
238
238
def build_summary_stats_tree_df (
239
239
node : _ModuleTreeNode ,
240
240
parent : Optional [_StatsTreeNode ] = None ,
241
- summary_trace : Tuple [str ] = ()):
241
+ summary_trace : tuple [str ] = ()):
242
242
243
243
if event_has_module (node .event ):
244
244
name = event_module_repr (node .event )
@@ -313,8 +313,8 @@ def build_model_stats_tree_df(node: _ModuleTreeNode,
313
313
self ._model_stats_tree .append (build_model_stats_tree_df (root ))
314
314
315
315
def _flatten_stats_tree (
316
- self , tree : List [_StatsTreeNode ]) -> List [ Tuple [int , StatsEntry ]]:
317
- entries : List [ Tuple [int , StatsEntry ]] = []
316
+ self , tree : list [_StatsTreeNode ]) -> list [ tuple [int , StatsEntry ]]:
317
+ entries : list [ tuple [int , StatsEntry ]] = []
318
318
319
319
def df_traversal (node : _StatsTreeNode , depth = 0 ):
320
320
entries .append ((depth , node .entry ))
@@ -327,10 +327,10 @@ def df_traversal(node: _StatsTreeNode, depth=0):
327
327
return entries
328
328
329
329
def _convert_stats_tree_to_dict (self ,
330
- tree : List [_StatsTreeNode ]) -> List [ Dict ]:
331
- root_dicts : List [ Dict ] = []
330
+ tree : list [_StatsTreeNode ]) -> list [ dict ]:
331
+ root_dicts : list [ dict ] = []
332
332
333
- def df_traversal (node : _StatsTreeNode , curr_json_list : List [ Dict ]):
333
+ def df_traversal (node : _StatsTreeNode , curr_json_list : list [ dict ]):
334
334
curr_json_list .append ({
335
335
"entry" : asdict (node .entry ),
336
336
"children" : []
0 commit comments