Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify subtensor shape inference #1299

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Mar 17, 2025

This PR simplifies the expressions returned by Subtensor.infer_shape. Some examples:

This PR also implements some basic minmax algebra rewrites to handle constraints implied by slice inputs. The following shape graph is now as succinct as it can be without knowing the shape of x (in which case it would be constant ofc)

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = x[1:-1][1:-1][1:-1].shape[0]
fn = pytensor.function([x], y)
fn.dprint()
Composite{maximum(0, (-6 + i0))} [id A] 1
 └─ Shape_i{0} [id B] 0
    └─ x [id C]

Inner graphs:

Composite{maximum(0, (-6 + i0))} [id A]
 ← maximum [id D] 'o0'
    ├─ 0 [id E]
    └─ add [id F]
       ├─ -6 [id G]
       └─ i0 [id H]

If you dare check what was the output before:

Composite{...} [id A] 1
 └─ Shape_i{0} [id B] 0
    └─ x [id C]
Inner graphs:
Composite{...} [id A]
 ← sub [id D] 'o0'
    ├─ Switch [id E] 't31'
    │  ├─ LT [id F]
    │  │  ├─ Switch [id G] 't17'
    │  │  │  ├─ GE [id H]
    │  │  │  │  ├─ Switch [id I] 't2'
    │  │  │  │  │  ├─ LT [id J]
    │  │  │  │  │  │  ├─ add [id K] 't37'
    │  │  │  │  │  │  │  ├─ -1 [id L]
    │  │  │  │  │  │  │  └─ sub [id M] 't34'
    │  │  │  │  │  │  │     ├─ Switch [id N] 't33'
    │  │  │  │  │  │  │     │  ├─ LT [id O]
    │  │  │  │  │  │  │     │  │  ├─ Switch [id P] 't19'
    │  │  │  │  │  │  │     │  │  │  ├─ GE [id Q]
    │  │  │  │  │  │  │     │  │  │  │  ├─ Switch [id R] 't5'
    │  │  │  │  │  │  │     │  │  │  │  │  ├─ LT [id S]
    │  │  │  │  │  │  │     │  │  │  │  │  │  ├─ add [id T] 't39'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │  ├─ -1 [id L]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │  └─ sub [id U] 't36'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     ├─ Switch [id V] 't35'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  ├─ LT [id W]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  ├─ Switch [id X] 't21'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  ├─ GE [id Y]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  ├─ Switch [id Z] 't12'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  ├─ LT [id BA]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  │  ├─ add [id BB] 't4'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  │  │  ├─ -1 [id L]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  │  │  └─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  ├─ -1 [id BE]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  └─ add [id BB] 't4'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  └─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  ├─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  └─ Switch [id Z] 't12'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  ├─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  └─ Switch [id X] 't21'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     └─ Switch [id BF]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        ├─ LT [id BG]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  ├─ Switch [id BH] 't15'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  ├─ LT [id BI]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  ├─ Switch [id BJ] 't45'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  ├─ GE [id BK]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  │  ├─ 1 [id BL]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  │  └─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  ├─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  └─ 1 [id BL]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  ├─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  └─ Switch [id BJ] 't45'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  └─ Switch [id V] 't35'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        ├─ Switch [id BH] 't15'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        └─ Switch [id V] 't35'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │           └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  ├─ -1 [id BE]
    │  │  │  │  │  │  │     │  │  │  │  │  └─ add [id T] 't39'
    │  │  │  │  │  │  │     │  │  │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  └─ sub [id U] 't36'
    │  │  │  │  │  │  │     │  │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  ├─ sub [id U] 't36'
    │  │  │  │  │  │  │     │  │  │  │  └─ ···
    │  │  │  │  │  │  │     │  │  │  └─ Switch [id R] 't5'
    │  │  │  │  │  │  │     │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  ├─ 0 [id BD]
    │  │  │  │  │  │  │     │  └─ Switch [id P] 't19'
    │  │  │  │  │  │  │     │     └─ ···
    │  │  │  │  │  │  │     └─ Switch [id BM]
    │  │  │  │  │  │  │        ├─ LT [id BN]
    │  │  │  │  │  │  │        │  ├─ Switch [id BO] 't13'
    │  │  │  │  │  │  │        │  │  ├─ LT [id BP]
    │  │  │  │  │  │  │        │  │  │  ├─ Switch [id BQ] 't43'
    │  │  │  │  │  │  │        │  │  │  │  ├─ GE [id BR]
    │  │  │  │  │  │  │        │  │  │  │  │  ├─ 1 [id BL]
    │  │  │  │  │  │  │        │  │  │  │  │  └─ sub [id U] 't36'
    │  │  │  │  │  │  │        │  │  │  │  │     └─ ···
    │  │  │  │  │  │  │        │  │  │  │  ├─ sub [id U] 't36'
    │  │  │  │  │  │  │        │  │  │  │  │  └─ ···
    │  │  │  │  │  │  │        │  │  │  │  └─ 1 [id BL]
    │  │  │  │  │  │  │        │  │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │        │  │  ├─ 0 [id BD]
    │  │  │  │  │  │  │        │  │  └─ Switch [id BQ] 't43'
    │  │  │  │  │  │  │        │  │     └─ ···
    │  │  │  │  │  │  │        │  └─ Switch [id N] 't33'
    │  │  │  │  │  │  │        │     └─ ···
    │  │  │  │  │  │  │        ├─ Switch [id BO] 't13'
    │  │  │  │  │  │  │        │  └─ ···
    │  │  │  │  │  │  │        └─ Switch [id N] 't33'
    │  │  │  │  │  │  │           └─ ···
    │  │  │  │  │  │  └─ 0 [id BD]
    │  │  │  │  │  ├─ -1 [id BE]
    │  │  │  │  │  └─ add [id K] 't37'
    │  │  │  │  │     └─ ···
    │  │  │  │  └─ sub [id M] 't34'
    │  │  │  │     └─ ···
    │  │  │  ├─ sub [id M] 't34'
    │  │  │  │  └─ ···
    │  │  │  └─ Switch [id I] 't2'
    │  │  │     └─ ···
    │  │  └─ 0 [id BD]
    │  ├─ 0 [id BD]
    │  └─ Switch [id G] 't17'
    │     └─ ···
    └─ Switch [id BS]
       ├─ LT [id BT]
       │  ├─ Switch [id BU] 't9'
       │  │  ├─ LT [id BV]
       │  │  │  ├─ Switch [id BW] 't42'
       │  │  │  │  ├─ GE [id BX]
       │  │  │  │  │  ├─ 1 [id BL]
       │  │  │  │  │  └─ sub [id M] 't34'
       │  │  │  │  │     └─ ···
       │  │  │  │  ├─ sub [id M] 't34'
       │  │  │  │  │  └─ ···
       │  │  │  │  └─ 1 [id BL]
       │  │  │  └─ 0 [id BD]
       │  │  ├─ 0 [id BD]
       │  │  └─ Switch [id BW] 't42'
       │  │     └─ ···
       │  └─ Switch [id E] 't31'
       │     └─ ···
       ├─ Switch [id BU] 't9'
       │  └─ ···
       └─ Switch [id E] 't31'
          └─ ···

This is rather important as slices are the most common source of dynamic shapes. We want to make sure these are clever.

This starts working towards fixing #112, it already simplifies standard slice shapes well enough by avoiding the canonical_form_slice monster. The next step is for the rewrite to also avoid it, although the number of parametrized combinations between two adjacent slices is no small feat. We may not need to merge all subtensors, after all a slice is a pretty cheap operation.

The most important thing is to have a good inference on the shape of multiple slices, as the scan save memory rewrite uses that to decide how many steps (cute but not critical) and how many entries to store in the buffer (rather important).

The new rewrites themselves also suggest a way we can start doing more clever type inference stuff. The basics for knowing upper/lower bounding are here. It probably makes sense to offer this as a feature, but for now they are just used two of the new rewrites.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1299.org.readthedocs.build/en/1299/

@ricardoV94 ricardoV94 changed the title Simplify subtensor shape Simplify subtensor shape inference Mar 17, 2025
@ricardoV94 ricardoV94 force-pushed the simplify_subtensor_shape branch 4 times, most recently from a2d4aca to c9e539d Compare March 21, 2025 15:48
@ricardoV94 ricardoV94 force-pushed the simplify_subtensor_shape branch from c9e539d to 7923437 Compare March 25, 2025 10:20
@ricardoV94 ricardoV94 force-pushed the simplify_subtensor_shape branch from 7923437 to 9a9fbbc Compare March 25, 2025 13:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant