Skip to content

Commit 2d477c4

Browse files
committed
Add debug mode to Numba linker
1 parent 5d4e9e0 commit 2d477c4

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

Diff for: pytensor/link/numba/dispatch/basic.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -448,16 +448,46 @@ def opfromgraph(*inputs):
448448
return opfromgraph
449449

450450

451+
def numba_funcify_debug(op, node, **kwargs):
452+
numba_fun = numba_funcify(op, node=node, **kwargs)
453+
454+
if node is None:
455+
return numba_fun
456+
457+
args = ", ".join([f"i{i}" for i in range(len(node.inputs))])
458+
str_op = str(op)
459+
460+
f_source = dedent(
461+
f"""
462+
def foo({args}):
463+
print("\\nOp: ", "{str_op}")
464+
print(" inputs: ", {args})
465+
outs = numba_fun({args})
466+
print(" outputs: ", outs)
467+
return outs
468+
"""
469+
)
470+
471+
f = compile_function_src(
472+
f_source,
473+
"foo",
474+
{**globals(), **{"numba_fun": numba_fun}},
475+
)
476+
477+
return numba_njit(f)
478+
479+
451480
@numba_funcify.register(FunctionGraph)
452481
def numba_funcify_FunctionGraph(
453482
fgraph,
454483
node=None,
455484
fgraph_name="numba_funcified_fgraph",
485+
op_conversion_fn=numba_funcify,
456486
**kwargs,
457487
):
458488
return fgraph_to_python(
459489
fgraph,
460-
numba_funcify,
490+
op_conversion_fn=op_conversion_fn,
461491
type_conversion_fn=numba_typify,
462492
fgraph_name=fgraph_name,
463493
**kwargs,

Diff for: pytensor/link/numba/linker.py

+8
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,17 @@
44
class NumbaLinker(JITLinker):
55
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""
66

7+
def __init__(self, *args, debug: bool = False, **kwargs):
8+
super().__init__(*args, **kwargs)
9+
self.debug = debug
10+
711
def fgraph_convert(self, fgraph, **kwargs):
812
from pytensor.link.numba.dispatch import numba_funcify
913

14+
if self.debug:
15+
from pytensor.link.numba.dispatch.basic import numba_funcify_debug
16+
17+
kwargs.setdefault("op_conversion_fn", numba_funcify_debug)
1018
return numba_funcify(fgraph, **kwargs)
1119

1220
def jit_compile(self, fn):

Diff for: tests/link/numba/test_linker.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from textwrap import dedent
2+
3+
import pytest
4+
5+
from pytensor import function
6+
from pytensor.compile.mode import Mode
7+
from pytensor.link.numba import NumbaLinker
8+
from pytensor.tensor import vector
9+
10+
11+
pytest.importorskip("numba")
12+
13+
14+
def test_debug_mode(capsys):
15+
x = vector("x")
16+
y = (x + 1).sum()
17+
18+
debug_mode = Mode(linker=NumbaLinker(debug=True))
19+
fn = function([x], y, mode=debug_mode)
20+
21+
assert fn([0, 1]) == 3.0
22+
captured = capsys.readouterr()
23+
assert captured.out == dedent(
24+
"""
25+
Op: Add
26+
inputs: [1.] [0. 1.]
27+
outputs: [1. 2.]
28+
29+
Op: Sum{axes=None}
30+
inputs: [1. 2.]
31+
outputs: 3.0
32+
"""
33+
)

0 commit comments

Comments
 (0)