Skip to content

Don't run local uint constant indices in C/Python backends #1335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Mar 31, 2025

This optimization is actually a slow down, because the implementations always requset a cast of indices to intp. We used to do it explicitly in AdvancedSubtensor1 but the underlying numpy C function also does it: https://github.com/numpy/numpy/blob/c23886c226012a50da11066d3be98fd94e571101/numpy/_core/src/multiarray/item_selection.c#L247-L251

Local benchmark on a PyMC model that uses AdavncedSubtensor1/Inc shows a slowdown compared to when the optimization is not used.

This PR removes the rewrite from these backends. It also makes the rewrite apply only after specialization to avoid too many passes when other rewrites introduce things like x.shape[i] which will be removed anyway.


📚 Documentation preview 📚: https://pytensor--1335.org.readthedocs.build/en/1335/

Copy link

codecov bot commented Mar 31, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.01%. Comparing base (0b56ed9) to head (5739fd5).
Report is 3 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1335   +/-   ##
=======================================
  Coverage   82.01%   82.01%           
=======================================
  Files         203      203           
  Lines       48805    48798    -7     
  Branches     8688     8685    -3     
=======================================
- Hits        40026    40022    -4     
+ Misses       6627     6625    -2     
+ Partials     2152     2151    -1     
Files with missing lines Coverage Δ
pytensor/compile/mode.py 84.72% <ø> (ø)
pytensor/tensor/rewriting/subtensor.py 89.95% <100.00%> (+0.02%) ⬆️
pytensor/tensor/subtensor.py 89.48% <100.00%> (+0.17%) ⬆️

... and 1 file with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Indices are always cast to int64 by the underlying methods.

Also don't run in specialize, to reduce number of passes. Other rewrites may introduce temporar indexing operations (such as x.shape[i]) which always default to int64, and it's useless to optimize immediately.
@ricardoV94 ricardoV94 force-pushed the local_uint_constant_indices branch from 3ae469f to 5739fd5 Compare April 1, 2025 15:29
@jessegrabowski jessegrabowski merged commit d9b1085 into pymc-devs:main Apr 1, 2025
73 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants