File tree Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Expand file tree Collapse file tree 1 file changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -386,6 +386,18 @@ def test_perform(self):
386
386
)
387
387
388
388
def test_grad (self ):
389
+ if isinstance (self .core_op , Solve ) and config .floatX == "float32" :
390
+ # This tolerance relaxation is needed because of the LU-solve rewrite. Ideally, we shouldn't need it. See
391
+ # discussion here: https://github.com/pymc-devs/pytensor/pull/1396
392
+ atol = 1e-1
393
+ rtol = 1e-4
394
+ elif config .floatX == "float32" :
395
+ atol = 1e-4
396
+ rtol = 1e-5
397
+ else : # config.floatX == "float64"
398
+ atol = 1e-6
399
+ rtol = 1e-7
400
+
389
401
base_inputs = [
390
402
tensor (shape = (None ,) * len (param_sig )) for param_sig in self .params_sig
391
403
]
@@ -414,8 +426,8 @@ def test_grad(self):
414
426
np .testing .assert_allclose (
415
427
pt_out ,
416
428
np_out ,
417
- rtol = 1e-7 if config . floatX == "float64" else 1e-5 ,
418
- atol = 1e-6 if config . floatX == "float64" else 1e-4 ,
429
+ rtol = rtol ,
430
+ atol = atol ,
419
431
)
420
432
421
433
You can’t perform that action at this time.
0 commit comments