Skip to content

Add numba overload for Nonzero #1289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2025

Conversation

Abhinav-Khot
Copy link
Contributor

@Abhinav-Khot Abhinav-Khot commented Mar 12, 2025

Description

Added a numba overload in Nonzero by adding the @nb.njit decorator as np.nonzero is supported by numba already.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1289.org.readthedocs.build/en/1289/

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not how a numba dispatch is implemented. Please review the relevant documentation if you'd like to tackle this issue: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html

@Abhinav-Khot
Copy link
Contributor Author

This is not how a numba dispatch is implemented. Please review the relevant documentation if you'd like to tackle this issue: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html

Hello,

I have tried implementing the backend for numba using the documentation and have written some tests too. Would be grateful if you could review them and give feedback

Copy link

codecov bot commented Mar 12, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.00%. Comparing base (2a7f3e1) to head (b336df9).
Report is 36 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1289      +/-   ##
==========================================
+ Coverage   81.99%   82.00%   +0.01%     
==========================================
  Files         188      188              
  Lines       48553    48488      -65     
  Branches     8673     8666       -7     
==========================================
- Hits        39812    39765      -47     
+ Misses       6579     6575       -4     
+ Partials     2162     2148      -14     
Files with missing lines Coverage Δ
pytensor/link/numba/dispatch/basic.py 78.50% <100.00%> (-0.59%) ⬇️

... and 27 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@Abhinav-Khot
Copy link
Contributor Author

@jessegrabowski Could you please review my changes?

Comment on lines 759 to 761
if a.ndim == 1:
indices = np.where(a != 0)[0]
return indices.astype(np.int64)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ndim == 1 is not a special case in the C-backend. In that case, you get a 1-tuple. All backends should return the same thing.

Copy link
Contributor Author

@Abhinav-Khot Abhinav-Khot Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The modification returns also 1-tuple because this is already implemented in the nonzero function itself

class Nonzero(Op):
    """
    Return the indices of the elements that are non-zero.

    Parameters
    ----------
    a: array_like
        Input array.

    Returns
    -------
    indices: list
        A list containing the indices of the non-zero elements of `a`.

    See Also
    --------
    nonzero_values : Return the non-zero elements of the input array
    flatnonzero : Return the indices of the non-zero elements of the
        flattened input array.

    """

    __props__ = ()

    def make_node(self, a):
        a = as_tensor_variable(a)
        if a.ndim == 0:
            raise ValueError("Nonzero only supports non-scalar arrays.")
        output = [TensorType(dtype="int64", shape=(None,))() for i in range(a.ndim)]
        return Apply(self, [a], output)

    def perform(self, node, inp, out_):
        a = inp[0]

        result_tuple = np.nonzero(a)
        for i, res in enumerate(result_tuple):
            out_[i][0] = res.astype("int64")

    def grad(self, inp, grads):
        return [grad_undefined(self, 0, inp[0])]


_nonzero = Nonzero()


def nonzero(a, return_matrix=False):
    """
    Returns one of the following:

        If return_matrix is False (default, same as NumPy):
            A tuple of vector arrays such that the ith element of the jth array
            is the index of the ith non-zero element of the input array in the
            jth dimension.

        If return_matrix is True (same as PyTensor Op):
            Returns a matrix of shape (ndim, number of nonzero elements) such
            that element (i,j) is the index in the ith dimension of the jth
            non-zero element.

    Parameters
    ----------
    a : array_like
        Input array.
    return_matrix : bool
        If True, returns a symbolic matrix. If False, returns a tuple of
        arrays. Defaults to False.

    Returns
    -------
    tuple of vectors or matrix

    See Also
    --------
    nonzero_values : Return the non-zero elements of the input array
    flatnonzero : Return the indices of the non-zero elements of the
        flattened input array.

    """
    res = _nonzero(a)
    if isinstance(res, list):
        res = tuple(res)
    else:
        res = (res,)

    if return_matrix:
        if len(res) > 1:
            return stack(res, 0)
        elif len(res) == 1:
            return shape_padleft(res[0])
    else:
        return res

If we do not handle the ndim = 1 case seperately, for the array [1,2,0] the result would be (([2],),) which is a tuple of tuple of lists. Instead it should be a tuple of lists which is what we get with the modification. I have replaced this modification with :

if(a.ndim == 1):
       return result_tuple[0]

for efficiency.

Comment on lines 754 to 755
if a.ndim == 0:
raise ValueError("Nonzero only supports non-scalar arrays.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This input validation is done by the nonzero Op itself, there's no need to repeat it here.

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Thanks for this

@jessegrabowski jessegrabowski merged commit b75c18f into pymc-devs:main Mar 21, 2025
73 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add numba overload for Nonzero
2 participants