Skip to content

Add debug mode to Numba linker #1234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,16 +448,46 @@
return opfromgraph


def numba_funcify_debug(op, node, **kwargs):
numba_fun = numba_funcify(op, node=node, **kwargs)

if node is None:
return numba_fun

Check warning on line 455 in pytensor/link/numba/dispatch/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/numba/dispatch/basic.py#L455

Added line #L455 was not covered by tests

args = ", ".join([f"i{i}" for i in range(len(node.inputs))])
str_op = str(op)

f_source = dedent(
f"""
def foo({args}):
print("\\nOp: ", "{str_op}")
print(" inputs: ", {args})
outs = numba_fun({args})
print(" outputs: ", outs)
return outs
"""
)

f = compile_function_src(
f_source,
"foo",
{**globals(), **{"numba_fun": numba_fun}},
)

return numba_njit(f)


@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="numba_funcified_fgraph",
op_conversion_fn=numba_funcify,
**kwargs,
):
return fgraph_to_python(
fgraph,
numba_funcify,
op_conversion_fn=op_conversion_fn,
type_conversion_fn=numba_typify,
fgraph_name=fgraph_name,
**kwargs,
Expand Down
8 changes: 8 additions & 0 deletions pytensor/link/numba/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,17 @@
class NumbaLinker(JITLinker):
"""A `Linker` that JIT-compiles NumPy-based operations using Numba."""

def __init__(self, *args, debug: bool = False, **kwargs):
super().__init__(*args, **kwargs)
self.debug = debug

def fgraph_convert(self, fgraph, **kwargs):
from pytensor.link.numba.dispatch import numba_funcify

if self.debug:
from pytensor.link.numba.dispatch.basic import numba_funcify_debug

kwargs.setdefault("op_conversion_fn", numba_funcify_debug)
return numba_funcify(fgraph, **kwargs)

def jit_compile(self, fn):
Expand Down
33 changes: 33 additions & 0 deletions tests/link/numba/test_linker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from textwrap import dedent

import pytest

from pytensor import function
from pytensor.compile.mode import Mode
from pytensor.link.numba import NumbaLinker
from pytensor.tensor import vector


pytest.importorskip("numba")


def test_debug_mode(capsys):
x = vector("x")
y = (x + 1).sum()

debug_mode = Mode(linker=NumbaLinker(debug=True))
fn = function([x], y, mode=debug_mode)

assert fn([0, 1]) == 3.0
captured = capsys.readouterr()
assert captured.out == dedent(
"""
Op: Add
inputs: [1.] [0. 1.]
outputs: [1. 2.]

Op: Sum{axes=None}
inputs: [1. 2.]
outputs: 3.0
"""
)