@@ -126,7 +126,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
126
126
# (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
127
127
# on pass that failed accuracy check.
128
128
def validate_inference (
129
- rtol = None , atol = None , device = torch .device (torch .cuda .current_device ())
129
+ rtol = None ,
130
+ atol = None ,
131
+ device = torch .device (torch .cuda .current_device ()),
132
+ suppress_accuracy_check_failure = True ,
130
133
):
131
134
def _validate_inference (pass_ : PassFunc ) -> PassFunc :
132
135
"""
@@ -141,48 +144,51 @@ def pass_with_validation(
141
144
* args ,
142
145
** kwargs ,
143
146
) -> fx .GraphModule :
144
- input_tensors = extract_example_tensors_from_input (input , device )
145
- res0 = module (* input_tensors )
146
- processed_module = pass_ (module , input , * args , ** kwargs )
147
- res1 = processed_module (* input_tensors )
148
- tensor_res_0 = _collect_tensors (res0 )
149
- tensor_res_1 = _collect_tensors (res1 )
150
- relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
151
-
152
- for kk , (x , y ) in enumerate (zip (tensor_res_0 , tensor_res_1 )):
153
- kwargs2 = {"equal_nan" : True }
154
- if rtol :
155
- kwargs2 ["rtol" ] = rtol
156
- if atol :
157
- kwargs2 ["atol" ] = atol
158
- kwargs2 [
159
- "msg"
160
- ] = (
161
- lambda msg : f"Pass { pass_ } failed correctness check due at output { kk } :\n { msg } "
162
- )
163
- # If tensors are on different devices, make sure to compare
164
- # their copies that are on the same device.
165
- if x .get_device () != y .get_device ():
166
- x = x .cpu ()
167
- y = y .cpu ()
168
- try :
169
- torch .testing .assert_close (x , y , ** kwargs2 )
170
- except Exception as e :
171
- if relax_accuracy_check_failure :
172
- _LOGGER .error (f"{ e } " )
173
- kwargs2 ["rtol" ] *= FINAL_CHECK_RTOL_MULTIPLIER
174
- kwargs2 ["atol" ] *= FINAL_CHECK_ATOL_MULTIPLIER
175
- new_atol = kwargs2 ["atol" ]
176
- new_rtol = kwargs2 ["rtol" ]
177
- _LOGGER .info (
178
- f"Do a sanity check to see whether things are completely wrong with { new_atol = } , { new_rtol = } "
179
- )
147
+ if suppress_accuracy_check_failure :
148
+ return pass_ (module , input , * args , ** kwargs )
149
+ else :
150
+ input_tensors = extract_example_tensors_from_input (input , device )
151
+ res0 = module (* input_tensors )
152
+ processed_module = pass_ (module , input , * args , ** kwargs )
153
+ res1 = processed_module (* input_tensors )
154
+ tensor_res_0 = _collect_tensors (res0 )
155
+ tensor_res_1 = _collect_tensors (res1 )
156
+ relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
157
+
158
+ for kk , (x , y ) in enumerate (zip (tensor_res_0 , tensor_res_1 )):
159
+ kwargs2 = {"equal_nan" : True }
160
+ if rtol :
161
+ kwargs2 ["rtol" ] = rtol
162
+ if atol :
163
+ kwargs2 ["atol" ] = atol
164
+ kwargs2 [
165
+ "msg"
166
+ ] = (
167
+ lambda msg : f"Pass { pass_ } failed correctness check due at output { kk } :\n { msg } "
168
+ )
169
+ # If tensors are on different devices, make sure to compare
170
+ # their copies that are on the same device.
171
+ if x .get_device () != y .get_device ():
172
+ x = x .cpu ()
173
+ y = y .cpu ()
174
+ try :
180
175
torch .testing .assert_close (x , y , ** kwargs2 )
181
- return processed_module
182
- else :
183
- raise e
184
-
185
- return processed_module
176
+ except Exception as e :
177
+ if relax_accuracy_check_failure :
178
+ _LOGGER .error (f"{ e } " )
179
+ kwargs2 ["rtol" ] *= FINAL_CHECK_RTOL_MULTIPLIER
180
+ kwargs2 ["atol" ] *= FINAL_CHECK_ATOL_MULTIPLIER
181
+ new_atol = kwargs2 ["atol" ]
182
+ new_rtol = kwargs2 ["rtol" ]
183
+ _LOGGER .info (
184
+ f"Do a sanity check to see whether things are completely wrong with { new_atol = } , { new_rtol = } "
185
+ )
186
+ torch .testing .assert_close (x , y , ** kwargs2 )
187
+ return processed_module
188
+ else :
189
+ raise e
190
+
191
+ return processed_module
186
192
187
193
return pass_with_validation
188
194
0 commit comments