@@ -102,6 +102,53 @@ def bounded_method(*args, **kwargs):
102
102
return dec_for_method
103
103
104
104
105
+ def log_perf_before_after (pass_ : PassFunc ) -> PassFunc :
106
+ """
107
+ Wraps a pass function to log perf of the module before and after the pass
108
+ """
109
+
110
+ @wraps (pass_ )
111
+ def check_perf_with_before_after_log (
112
+ module : fx .GraphModule , input : Input
113
+ ) -> fx .GraphModule :
114
+ def benchmark_torch_function (iters : int , f , * args ) -> float :
115
+ """Estimates the average time duration for a single inference call in second
116
+
117
+ If the input is batched, then the estimation is for the batches inference call.
118
+
119
+ Args:
120
+ iters: number of inference iterations to run
121
+ f: a function to perform a single inference call
122
+
123
+ Returns:
124
+ estimated average time duration in second for a single inference call
125
+ """
126
+ with torch .inference_mode ():
127
+ f (* args )
128
+ torch .cuda .synchronize ()
129
+ start_event = torch .cuda .Event (enable_timing = True )
130
+ end_event = torch .cuda .Event (enable_timing = True )
131
+ # print("== Start benchmark iterations")
132
+ with torch .inference_mode ():
133
+ start_event .record ()
134
+ for _ in range (iters ):
135
+ f (* args )
136
+ end_event .record ()
137
+ torch .cuda .synchronize ()
138
+ # print("== End benchmark iterations")
139
+ return (start_event .elapsed_time (end_event ) * 1.0e-3 ) / iters
140
+
141
+ time_before = benchmark_torch_function (100 , lambda : module (* input ))
142
+ _LOGGER .info (f"[{ pass_ } ] Perf Before(eager mode): { time_before } " )
143
+
144
+ module = pass_ (module , input )
145
+ time_after = benchmark_torch_function (100 , lambda : module (* input ))
146
+ _LOGGER .info (f"[{ pass_ } ] Perf After(eager mode): { time_after } " )
147
+ return module
148
+
149
+ return check_perf_with_before_after_log
150
+
151
+
105
152
def log_before_after (pass_ : PassFunc ) -> PassFunc :
106
153
"""
107
154
Wraps a pass function to log the module graph before and after the pass
0 commit comments