Skip to content

Support more cases of numba advanced indexing #1254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 3, 2025

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Feb 27, 2025

Support more cases of multi-dimensional advanced indexing and updating in Numba

Extends pre-existing rewrite to ravel multidimensional integer indices, to handle multiple inputs (if no broadcasting is needed) and to place them consecutively if they were not. It also extends it to Set/IncSubtensor as long as y is not broadcasted and advanced indices are consecutive.

The following cases should now be supported without object mode:

  • Advanced integer indexing (not mixed with basic or boolean indexing) that do not require broadcasting of indices
  • Consecutive advanced integer indexing updating (set/inc) (not mixed with basic or boolean indexing) that do not require broadcasting of indices or y.

Also fixes bug in infer_shape of AdvancedIndexing with slices (which were mistakenly treated as NoneSlices)

Example of new kind of graphs that are supported.

import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.tensor("x", shape=(10, 10, 3))
inds = pt.matrix("inds", shape=(50, 2), dtype=int)
y = x[inds, inds]
x_test = np.zeros((10, 10, 3)) 
inds_test = np.ones((50, 2), dtype=int)

fn = pytensor.function([x, inds], [y, pt.grad(y.sum(), x)], mode="NUMBA")
fn.dprint(print_shape=True)
fn.trust_input = True
print(fn(x_test, inds_test)[0].shape)
%timeit fn(x_test, inds_test)

On my machine that runs in 8us, and before with object mode it was 60us.

A case with a single matrix advanced indexing shows up in the logp of the Categorical, the gradient of which was not supported by numba without object mode before and now is.


📚 Documentation preview 📚: https://pytensor--1254.org.readthedocs.build/en/1254/

@ricardoV94 ricardoV94 added bug Something isn't working enhancement New feature or request numba indexing labels Feb 27, 2025
@ricardoV94 ricardoV94 requested review from aseyboldt and jessegrabowski and removed request for aseyboldt and jessegrabowski February 27, 2025 19:23
@ricardoV94 ricardoV94 force-pushed the more_numba_advanced_indexing branch 2 times, most recently from d537caf to a1fc205 Compare February 28, 2025 08:19
Copy link

codecov bot commented Feb 28, 2025

Codecov Report

Attention: Patch coverage is 86.95652% with 9 lines in your changes missing coverage. Please review.

Project coverage is 81.99%. Comparing base (2a7f3e1) to head (d1f18a1).
Report is 5 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/subtensor.py 77.77% 5 Missing and 1 partial ⚠️
pytensor/tensor/rewriting/subtensor.py 92.85% 1 Missing and 2 partials ⚠️

❌ Your patch status has failed because the patch coverage (86.95%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1254      +/-   ##
==========================================
- Coverage   81.99%   81.99%   -0.01%     
==========================================
  Files         188      188              
  Lines       48553    48600      +47     
  Branches     8673     8685      +12     
==========================================
+ Hits        39812    39849      +37     
- Misses       6579     6586       +7     
- Partials     2162     2165       +3     
Files with missing lines Coverage Δ
pytensor/link/numba/dispatch/subtensor.py 95.34% <ø> (ø)
pytensor/tensor/rewriting/subtensor.py 90.12% <92.85%> (-0.13%) ⬇️
pytensor/tensor/subtensor.py 89.35% <77.77%> (-0.37%) ⬇️

@ricardoV94 ricardoV94 force-pushed the more_numba_advanced_indexing branch from a1fc205 to ec99fca Compare February 28, 2025 09:24
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some typos and questions

…g in Numba

Extends pre-existing rewrite to ravel multiple integer indices, and to place them consecutively. The following cases should now be supported without object mode:
* Advanced integer indexing (not mixed with basic or boolean indexing) that do not require broadcasting of indices
* Consecutive advanced integer indexing updating (set/inc) (not mixed with basic or boolean indexing) that do not require broadcasting of indices or y.
@ricardoV94 ricardoV94 force-pushed the more_numba_advanced_indexing branch from ec99fca to cd649ab Compare March 1, 2025 08:33
@ricardoV94 ricardoV94 force-pushed the more_numba_advanced_indexing branch from a0d9ecf to 8ad0812 Compare March 3, 2025 16:03
Started failing in 0.4.36: jax-ml/jax#26888

Skip failing JAX test

Started failing in 0.4.36 and seems to be fixed in 0.5.1
@ricardoV94 ricardoV94 force-pushed the more_numba_advanced_indexing branch from 8ad0812 to d1f18a1 Compare March 3, 2025 16:45
@ricardoV94 ricardoV94 merged commit 757a10c into pymc-devs:main Mar 3, 2025
71 checks passed
@ricardoV94 ricardoV94 deleted the more_numba_advanced_indexing branch March 3, 2025 17:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request indexing numba
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants