-
Notifications
You must be signed in to change notification settings - Fork 135
Add numba overload for Nonzero #1289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add numba overload for Nonzero #1289
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not how a numba dispatch is implemented. Please review the relevant documentation if you'd like to tackle this issue: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html
Hello, I have tried implementing the backend for numba using the documentation and have written some tests too. Would be grateful if you could review them and give feedback |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #1289 +/- ##
==========================================
+ Coverage 81.99% 82.00% +0.01%
==========================================
Files 188 188
Lines 48553 48488 -65
Branches 8673 8666 -7
==========================================
- Hits 39812 39765 -47
+ Misses 6579 6575 -4
+ Partials 2162 2148 -14
🚀 New features to boost your workflow:
|
@jessegrabowski Could you please review my changes? |
if a.ndim == 1: | ||
indices = np.where(a != 0)[0] | ||
return indices.astype(np.int64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ndim == 1
is not a special case in the C-backend. In that case, you get a 1-tuple. All backends should return the same thing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The modification returns also 1-tuple because this is already implemented in the nonzero
function itself
class Nonzero(Op):
"""
Return the indices of the elements that are non-zero.
Parameters
----------
a: array_like
Input array.
Returns
-------
indices: list
A list containing the indices of the non-zero elements of `a`.
See Also
--------
nonzero_values : Return the non-zero elements of the input array
flatnonzero : Return the indices of the non-zero elements of the
flattened input array.
"""
__props__ = ()
def make_node(self, a):
a = as_tensor_variable(a)
if a.ndim == 0:
raise ValueError("Nonzero only supports non-scalar arrays.")
output = [TensorType(dtype="int64", shape=(None,))() for i in range(a.ndim)]
return Apply(self, [a], output)
def perform(self, node, inp, out_):
a = inp[0]
result_tuple = np.nonzero(a)
for i, res in enumerate(result_tuple):
out_[i][0] = res.astype("int64")
def grad(self, inp, grads):
return [grad_undefined(self, 0, inp[0])]
_nonzero = Nonzero()
def nonzero(a, return_matrix=False):
"""
Returns one of the following:
If return_matrix is False (default, same as NumPy):
A tuple of vector arrays such that the ith element of the jth array
is the index of the ith non-zero element of the input array in the
jth dimension.
If return_matrix is True (same as PyTensor Op):
Returns a matrix of shape (ndim, number of nonzero elements) such
that element (i,j) is the index in the ith dimension of the jth
non-zero element.
Parameters
----------
a : array_like
Input array.
return_matrix : bool
If True, returns a symbolic matrix. If False, returns a tuple of
arrays. Defaults to False.
Returns
-------
tuple of vectors or matrix
See Also
--------
nonzero_values : Return the non-zero elements of the input array
flatnonzero : Return the indices of the non-zero elements of the
flattened input array.
"""
res = _nonzero(a)
if isinstance(res, list):
res = tuple(res)
else:
res = (res,)
if return_matrix:
if len(res) > 1:
return stack(res, 0)
elif len(res) == 1:
return shape_padleft(res[0])
else:
return res
If we do not handle the ndim = 1 case seperately, for the array [1,2,0] the result would be (([2],),) which is a tuple of tuple of lists. Instead it should be a tuple of lists which is what we get with the modification. I have replaced this modification with :
if(a.ndim == 1):
return result_tuple[0]
for efficiency.
if a.ndim == 0: | ||
raise ValueError("Nonzero only supports non-scalar arrays.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This input validation is done by the nonzero
Op itself, there's no need to repeat it here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Thanks for this
Description
Added a numba overload in
Nonzero
by adding the@nb.njit
decorator asnp.nonzero
is supported bynumba
already.Related Issue
Nonzero
#1279Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1289.org.readthedocs.build/en/1289/