File tree 3 files changed +72
-1
lines changed
3 files changed +72
-1
lines changed Original file line number Diff line number Diff line change @@ -448,16 +448,46 @@ def opfromgraph(*inputs):
448
448
return opfromgraph
449
449
450
450
451
+ def numba_funcify_debug (op , node , ** kwargs ):
452
+ numba_fun = numba_funcify (op , node = node , ** kwargs )
453
+
454
+ if node is None :
455
+ return numba_fun
456
+
457
+ args = ", " .join ([f"i{ i } " for i in range (len (node .inputs ))])
458
+ str_op = str (op )
459
+
460
+ f_source = dedent (
461
+ f"""
462
+ def foo({ args } ):
463
+ print("\\ nOp: ", "{ str_op } ")
464
+ print(" inputs: ", { args } )
465
+ outs = numba_fun({ args } )
466
+ print(" outputs: ", outs)
467
+ return outs
468
+ """
469
+ )
470
+
471
+ f = compile_function_src (
472
+ f_source ,
473
+ "foo" ,
474
+ {** globals (), ** {"numba_fun" : numba_fun }},
475
+ )
476
+
477
+ return numba_njit (f )
478
+
479
+
451
480
@numba_funcify .register (FunctionGraph )
452
481
def numba_funcify_FunctionGraph (
453
482
fgraph ,
454
483
node = None ,
455
484
fgraph_name = "numba_funcified_fgraph" ,
485
+ op_conversion_fn = numba_funcify ,
456
486
** kwargs ,
457
487
):
458
488
return fgraph_to_python (
459
489
fgraph ,
460
- numba_funcify ,
490
+ op_conversion_fn = op_conversion_fn ,
461
491
type_conversion_fn = numba_typify ,
462
492
fgraph_name = fgraph_name ,
463
493
** kwargs ,
Original file line number Diff line number Diff line change 4
4
class NumbaLinker (JITLinker ):
5
5
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
6
6
7
+ def __init__ (self , * args , debug : bool = False , ** kwargs ):
8
+ super ().__init__ (* args , ** kwargs )
9
+ self .debug = debug
10
+
7
11
def fgraph_convert (self , fgraph , ** kwargs ):
8
12
from pytensor .link .numba .dispatch import numba_funcify
9
13
14
+ if self .debug :
15
+ from pytensor .link .numba .dispatch .basic import numba_funcify_debug
16
+
17
+ kwargs .setdefault ("op_conversion_fn" , numba_funcify_debug )
10
18
return numba_funcify (fgraph , ** kwargs )
11
19
12
20
def jit_compile (self , fn ):
Original file line number Diff line number Diff line change
1
+ from textwrap import dedent
2
+
3
+ import pytest
4
+
5
+ from pytensor import function
6
+ from pytensor .compile .mode import Mode
7
+ from pytensor .link .numba import NumbaLinker
8
+ from pytensor .tensor import vector
9
+
10
+
11
+ pytest .importorskip ("numba" )
12
+
13
+
14
+ def test_debug_mode (capsys ):
15
+ x = vector ("x" )
16
+ y = (x + 1 ).sum ()
17
+
18
+ debug_mode = Mode (linker = NumbaLinker (debug = True ))
19
+ fn = function ([x ], y , mode = debug_mode )
20
+
21
+ assert fn ([0 , 1 ]) == 3.0
22
+ captured = capsys .readouterr ()
23
+ assert captured .out == dedent (
24
+ """
25
+ Op: Add
26
+ inputs: [1.] [0. 1.]
27
+ outputs: [1. 2.]
28
+
29
+ Op: Sum{axes=None}
30
+ inputs: [1. 2.]
31
+ outputs: 3.0
32
+ """
33
+ )
You can’t perform that action at this time.
0 commit comments