Simplify subtensor shape inference #1299
Draft
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR simplifies the expressions returned by
Subtensor.infer_shape
. Some examples:This PR also implements some basic minmax algebra rewrites to handle constraints implied by slice inputs. The following shape graph is now as succinct as it can be without knowing the shape of
x
(in which case it would be constant ofc)If you dare check what was the output before:
This is rather important as slices are the most common source of dynamic shapes. We want to make sure these are clever.
This starts working towards fixing #112, it already simplifies standard slice shapes well enough by avoiding the
canonical_form_slice
monster. The next step is for the rewrite to also avoid it, although the number of parametrized combinations between two adjacent slices is no small feat. We may not need to merge allsubtensors
, after all a slice is a pretty cheap operation.The most important thing is to have a good inference on the shape of multiple slices, as the scan save memory rewrite uses that to decide how many steps (cute but not critical) and how many entries to store in the buffer (rather important).
The new rewrites themselves also suggest a way we can start doing more clever type inference stuff. The basics for knowing upper/lower bounding are here. It probably makes sense to offer this as a feature, but for now they are just used two of the new rewrites.
Related Issue
local_subtensor_merge
can complicate graphs #112Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1299.org.readthedocs.build/en/1299/