
Speedup Scan in different backends #1281


Merged: 9 commits merged into pymc-devs:main on Mar 13, 2025

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Mar 10, 2025

This PR does a bunch of small optimizations for the Scan, mostly in the Numba backend, but to a lesser extent also in the C and JAX backends.

Tweaks:

  1. Fix a failure in trimming the number of steps in Scan when the new n_steps would be a constant integer. The isinstance(x, int) check returns False for np.integer.
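A quick illustration of the failure mode (illustrative code, not the actual PyTensor internals): a NumPy integer scalar does not pass an `isinstance(x, int)` check, so the constant path was skipped.

```python
import numpy as np

n_steps = np.int64(5)  # what a constant-folded step count may look like
print(isinstance(n_steps, int))         # False: np.integer does not subclass int
print(isinstance(n_steps, np.integer))  # True

def is_constant_integer(x):
    # The robust check accepts both Python ints and NumPy integer scalars.
    return isinstance(x, (int, np.integer))
```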

  2. Do not try to reduce the buffer size just to save the space occupied by the initial values. Scan is always created with a buffer for the recurring inputs of size taps + steps. The scan returned to the user slices away the taps, so the user doesn't see the initial values there. When the user wants the whole returned trace (instead of, say, only the last state) AND no gradient is requested (the gradient makes use of the initial taps in the trace), the scan save-memory rewrite would trim the buffer down to size steps. This, however, always requires a roll at the end of the Scan, because the discarded entries (the taps) sit at the beginning. After this PR the scan save-memory rewrite skips this specific case, as the default slice is likely better than the roll, which requires a copy. Benchmarks confirm this.
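The slice-vs-roll trade-off can be sketched with NumPy (hypothetical buffer names, sizes chosen for illustration): slicing off the taps creates a view, while realigning a trimmed circular buffer with `np.roll` forces a full copy.

```python
import numpy as np

taps, steps = 2, 5
# Buffer as created by Scan: initial taps followed by the computed steps.
buffer = np.arange(taps + steps, dtype=float)

# Kept by this PR: allocate taps + steps and slice the taps away (a view, no copy).
trace_slice = buffer[taps:]
print(trace_slice.base is buffer)  # True: just a view

# What the old rewrite produced: a buffer of size `steps` used circularly,
# which must be rolled at the end to put the trace back in order (a copy).
circular = np.arange(steps, dtype=float)
trace_rolled = np.roll(circular, -taps)
print(trace_rolled.base is circular)  # False: np.roll copied the data
```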

  3. Be more aggressive in the buffer memory save in the JIT backends. As described in Missed scan rewrites #787, Scan keeps one more entry in the circular buffer for recurring states than strictly needed. This happens because the Python/C implementation makes clever use of the inner compiled PyTensor function to manually specify the output buffers. In that case it's not always safe for PyTensor to overwrite the buffer, because the inputs may still be needed to compute other outputs, so an extra entry was provided for the output. This is not a concern in the JIT backends because there's no such machinery to exploit. As such, the scan save-memory rewrite was split to be more aggressive in the JIT backends.
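A minimal sketch of why the extra entry is unnecessary when the backend controls the buffer itself (illustrative recurrence and helper, not the actual dispatch code): each new value can overwrite the oldest tap as soon as that tap has been read.

```python
import numpy as np

def run_circular(n_steps, taps_init, buffer_size):
    # Sit/mit-sot style recurrence x[t] = x[t-2] + x[t-1] over a circular buffer.
    taps = len(taps_init)
    buf = np.empty(buffer_size)
    buf[:taps] = taps_init
    for t in range(n_steps):
        prev2 = buf[t % buffer_size]
        prev1 = buf[(t + 1) % buffer_size]
        # With buffer_size == taps this overwrites the slot prev2 was read from,
        # which is safe here because prev2 has already been consumed.
        buf[(t + taps) % buffer_size] = prev2 + prev1
    return buf[(n_steps + taps - 1) % buffer_size]  # last state

# Same result with the minimal buffer (taps) and the padded one (taps + 1):
print(run_circular(5, [0.0, 1.0], buffer_size=2))  # 8.0
print(run_circular(5, [0.0, 1.0], buffer_size=3))  # 8.0
```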

  4. Minor tweaks in the Numba dispatch of Scan to not attempt a roll when the buffer is already aligned. I didn't see any stable changes, but it seems cleaner. Also set boundscheck explicitly to False, although I think that may be the default. Neither mattered in my benchmarks.
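The alignment check can be sketched like this (hypothetical helper name; the real Numba dispatch differs): skip the roll when the final write position already coincides with the start of the buffer.

```python
import numpy as np

def maybe_roll(buffer, final_pos):
    # Only realign the circular buffer when it is actually misaligned;
    # np.roll always copies, so skipping it avoids needless work.
    shift = final_pos % buffer.shape[0]
    if shift == 0:
        return buffer  # already aligned
    return np.roll(buffer, -shift, axis=0)
```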

  5. Allow inplacing in the inner function for sit-sots and the last mit-sot when the buffer is so small that the oldest tap is discarded immediately anyway (this happens when the user asks for trace[-1] of a scan). This provides substantial speedups in some cases. We should do the same for mit-mots, but I haven't yet grokked how they are supposed to behave. Issue to track it: Allow inner-function inplace of mit-mot in Scan backend #1282

Note this inplace doesn't work right now when n_steps is symbolic, as it relies on PyTensor having figured out that the static shape of the buffer matches the size of the taps. We need more work for this: #1282
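The trace[-1] case can be sketched as follows (an illustrative loop, not the generated inner function): with a buffer holding a single entry, the inner function may overwrite the tap it just consumed.

```python
import numpy as np

def last_state_inplace(n_steps, init):
    state = np.asarray(init).copy()  # buffer of size 1: only the last state survives
    for _ in range(n_steps):
        state *= 2.0  # in-place update: the consumed tap is discarded immediately
    return state

print(last_state_inplace(3, 1.0))  # 1.0 doubled three times -> 8.0
```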

  6. Try to simplify the symbolic graph for the buffer/n-steps size when the scan doesn't have constant n_steps. The graph is a mess, due to how subtensor merging works: local_subtensor_merge can complicate graphs #112. This is mostly because we cannot know whether the number of steps will be zero, in which case it's invalid for the user to select trace[-1]. The graph is overly complicated; it's perhaps better to just add an assert that n_steps > 0, as we do for the while-scan optimization. In any case, the graph was even worse before this PR because of a bunch of ScalarFromTensor and TensorFromScalar nodes that showed up in the buffer size logic. This PR cleans it up, so at least everything happens in one single fused graph. Graph before for the mit-sot test:
Before
Subtensor{i} [id A] 25
 ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] 24
 │  ├─ Composite{...}.0 [id C] 16
 │  │  ├─ Composite{...}.0 [id D] 6
 │  │  │  ├─ Composite{...}.1 [id E] 0
 │  │  │  │  └─ n_steps [id F]
 │  │  │  ├─ TensorFromScalar [id G] 5
 │  │  │  │  └─ mul [id H] 4
 │  │  │  │     ├─ add [id I] 2
 │  │  │  │     │  ├─ -1 [id J]
 │  │  │  │     │  └─ ScalarFromTensor [id K] 1
 │  │  │  │     │     └─ Composite{...}.2 [id E] 0
 │  │  │  │     │        └─ ···
 │  │  │  │     └─ 1 [id L]
 │  │  │  ├─ Composite{...}.3 [id E] 0
 │  │  │  │  └─ ···
 │  │  │  ├─ Composite{...}.4 [id E] 0
 │  │  │  │  └─ ···
 │  │  │  ├─ TensorFromScalar [id M] 3
 │  │  │  │  └─ add [id I] 2
 │  │  │  │     └─ ···
 │  │  │  ├─ Composite{...}.2 [id E] 0
 │  │  │  │  └─ ···
 │  │  │  └─ Composite{...}.5 [id E] 0
 │  │  │     └─ ···
 │  │  ├─ TensorFromScalar [id N] 10
 │  │  │  └─ add [id O] 9
 │  │  │     ├─ ScalarFromTensor [id P] 8
 │  │  │     │  └─ Composite{...}.1 [id D] 6
 │  │  │     │     └─ ···
 │  │  │     └─ ScalarFromTensor [id Q] 7
 │  │  │        └─ Composite{...}.0 [id E] 0
 │  │  │           └─ ···
 │  │  ├─ Composite{...}.1 [id D] 6
 │  │  │  └─ ···
 │  │  ├─ TensorFromScalar [id R] 15
 │  │  │  └─ sub [id S] 14
 │  │  │     ├─ add [id T] 13
 │  │  │     │  ├─ ScalarFromTensor [id U] 12
 │  │  │     │  │  └─ Switch [id V] 11
 │  │  │     │  │     ├─ Composite{...}.2 [id D] 6
 │  │  │     │  │     │  └─ ···
 │  │  │     │  │     ├─ TensorFromScalar [id N] 10
 │  │  │     │  │     │  └─ ···
 │  │  │     │  │     └─ Composite{...}.1 [id D] 6
 │  │  │     │  │        └─ ···
 │  │  │     │  └─ 1 [id L]
 │  │  │     └─ 2 [id W]
 │  │  └─ n_steps [id F]
 │  └─ SetSubtensor{:stop} [id X] 23
 │     ├─ AllocEmpty{dtype='float64'} [id Y] 22
 │     │  └─ Composite{...}.2 [id C] 16
 │     │     └─ ···
 │     ├─ init_x [id Z]
 │     └─ 2 [id BA]
 └─ add [id BB] 21
    ├─ sub [id BC] 20
    │  ├─ sub [id BD] 19
    │  │  ├─ ScalarFromTensor [id U] 12
    │  │  │  └─ ···
    │  │  └─ ScalarFromTensor [id BE] 18
    │  │     └─ Composite{...}.0 [id C] 16
    │  │        └─ ···
    │  └─ 2 [id W]
    └─ ScalarFromTensor [id BF] 17
       └─ Composite{...}.1 [id C] 16
          └─ ···

Inner graphs:

Scan{scan_fn, while_loop=False, inplace=all} [id B]
 ← Composite{((2.0 * i0) + i1)} [id BG]
    ├─ *1-<Scalar(float64, shape=())> [id BH] -> [id X]
    └─ *0-<Scalar(float64, shape=())> [id BI] -> [id X]

Composite{...} [id C]
 ← maximum [id BJ] 'o0'
    ├─ minimum [id BK]
    │  ├─ i3 [id BL]
    │  └─ i4 [id BM]
    └─ 1 [id BN]
 ← maximum [id BO] 'o1'
    ├─ add [id BP]
    │  ├─ sub [id BQ]
    │  │  ├─ maximum [id BJ] 'o0'
    │  │  │  └─ ···
    │  │  └─ Switch [id BR]
    │  │     ├─ i0 [id BS]
    │  │     ├─ i1 [id BT]
    │  │     └─ i2 [id BU]
    │  └─ 2 [id BV]
    └─ 2 [id BV]
 ← add [id BW] 'o2'
    ├─ Switch [id BX]
    │  ├─ GT [id BY]
    │  │  ├─ 2 [id BV]
    │  │  └─ maximum [id BO] 'o1'
    │  │     └─ ···
    │  ├─ add [id BZ]
    │  │  ├─ maximum [id BO] 'o1'
    │  │  │  └─ ···
    │  │  └─ 2 [id BV]
    │  └─ sub [id CA]
    │     ├─ maximum [id BO] 'o1'
    │     │  └─ ···
    │     └─ 2 [id BV]
    └─ 2 [id CB]

Composite{...} [id D]
 ← LT [id CC] 'o0'
    ├─ Switch [id CD] 'o1'
    │  ├─ LT [id CE]
    │  │  ├─ i4 [id CF]
    │  │  └─ 0 [id CG]
    │  ├─ i6 [id CH]
    │  └─ Switch [id CI]
    │     ├─ GE [id CJ]
    │     │  ├─ i4 [id CF]
    │     │  └─ i5 [id CK]
    │     ├─ i3 [id CL]
    │     └─ Switch [id CM]
    │        ├─ i2 [id CN]
    │        ├─ i3 [id CL]
    │        └─ add [id CO]
    │           ├─ i0 [id CP]
    │           └─ i1 [id CQ]
    └─ 0 [id CG]
 ← Switch [id CD] 'o1'
    └─ ···
 ← LT [id CR] 'o2'
    ├─ Switch [id CD] 'o1'
    │  └─ ···
    └─ 0 [id CG]

Composite{...} [id E]
 ← add [id CS] 'o0'
    ├─ 2 [id CT]
    └─ i0 [id CU]
 ← Switch [id CV] 'o1'
    ├─ LT [id CW]
    │  ├─ 2 [id CT]
    │  └─ add [id CS] 'o0'
    │     └─ ···
    ├─ 2 [id CT]
    └─ add [id CS] 'o0'
       └─ ···
 ← sub [id CX] 'o2'
    ├─ add [id CS] 'o0'
    │  └─ ···
    └─ Switch [id CV] 'o1'
       └─ ···
 ← LE [id CY] 'o3'
    ├─ sub [id CX] 'o2'
    │  └─ ···
    └─ 0 [id CZ]
 ← add [id DA] 'o4'
    ├─ 3 [id DB]
    └─ i0 [id CU]
 ← sub [id DC] 'o5'
    ├─ -3 [id DD]
    └─ i0 [id CU]

Composite{((2.0 * i0) + i1)} [id BG]
 ← add [id DE] 'o0'
    ├─ mul [id DF]
    │  ├─ 2.0 [id DG]
    │  └─ i0 [id DH]
    └─ i1 [id DI]
After
Subtensor{i} [id A] 5
 ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] 4
 │  ├─ Composite{...}.0 [id C] 0
 │  │  └─ n_steps [id D]
 │  └─ SetSubtensor{:stop} [id E] 3
 │     ├─ AllocEmpty{dtype='float64'} [id F] 2
 │     │  └─ Composite{...}.2 [id C] 0
 │     │     └─ ···
 │     ├─ init_x [id G]
 │     └─ 2 [id H]
 └─ ScalarFromTensor [id I] 1
    └─ Composite{...}.1 [id C] 0
       └─ ···

Inner graphs:

Scan{scan_fn, while_loop=False, inplace=all} [id B]
 ← Composite{((2.0 * i0) + i1)} [id J]
    ├─ *1-<Scalar(float64, shape=())> [id K] -> [id E]
    └─ *0-<Scalar(float64, shape=())> [id L] -> [id E]

Composite{...} [id C]
 ← maximum [id M] 'o0'
    ├─ minimum [id N]
    │  ├─ sub [id O]
    │  │  ├─ add [id P]
    │  │  │  ├─ Switch [id Q] 't24'
    │  │  │  │  ├─ LT [id R]
    │  │  │  │  │  ├─ Switch [id S] 't5'
    │  │  │  │  │  │  ├─ LT [id T]
    │  │  │  │  │  │  │  ├─ sub [id U] 't16'
    │  │  │  │  │  │  │  │  ├─ add [id V] 't20'
    │  │  │  │  │  │  │  │  │  ├─ 1 [id W]
    │  │  │  │  │  │  │  │  │  └─ i0 [id X]
    │  │  │  │  │  │  │  │  └─ Switch [id Y] 't35'
    │  │  │  │  │  │  │  │     ├─ LT [id Z]
    │  │  │  │  │  │  │  │     │  ├─ 2 [id BA]
    │  │  │  │  │  │  │  │     │  └─ add [id BB] 't37'
    │  │  │  │  │  │  │  │     │     ├─ 2 [id BA]
    │  │  │  │  │  │  │  │     │     └─ i0 [id X]
    │  │  │  │  │  │  │  │     ├─ 2 [id BA]
    │  │  │  │  │  │  │  │     └─ add [id BB] 't37'
    │  │  │  │  │  │  │  │        └─ ···
    │  │  │  │  │  │  │  └─ 0 [id BC]
    │  │  │  │  │  │  ├─ sub [id BD]
    │  │  │  │  │  │  │  ├─ -3 [id BE]
    │  │  │  │  │  │  │  └─ i0 [id X]
    │  │  │  │  │  │  └─ Switch [id BF]
    │  │  │  │  │  │     ├─ GE [id BG]
    │  │  │  │  │  │     │  ├─ sub [id U] 't16'
    │  │  │  │  │  │     │  │  └─ ···
    │  │  │  │  │  │     │  └─ sub [id BH] 't2'
    │  │  │  │  │  │     │     ├─ add [id BB] 't37'
    │  │  │  │  │  │     │     │  └─ ···
    │  │  │  │  │  │     │     └─ Switch [id Y] 't35'
    │  │  │  │  │  │     │        └─ ···
    │  │  │  │  │  │     ├─ add [id BI] 't12'
    │  │  │  │  │  │     │  ├─ 3 [id BJ]
    │  │  │  │  │  │     │  └─ i0 [id X]
    │  │  │  │  │  │     └─ Switch [id BK]
    │  │  │  │  │  │        ├─ LE [id BL]
    │  │  │  │  │  │        │  ├─ sub [id BH] 't2'
    │  │  │  │  │  │        │  │  └─ ···
    │  │  │  │  │  │        │  └─ 0 [id BC]
    │  │  │  │  │  │        ├─ add [id BI] 't12'
    │  │  │  │  │  │        │  └─ ···
    │  │  │  │  │  │        └─ add [id V] 't20'
    │  │  │  │  │  │           └─ ···
    │  │  │  │  │  └─ 0 [id BC]
    │  │  │  │  ├─ add [id BM]
    │  │  │  │  │  ├─ Switch [id S] 't5'
    │  │  │  │  │  │  └─ ···
    │  │  │  │  │  └─ add [id BB] 't37'
    │  │  │  │  │     └─ ···
    │  │  │  │  └─ Switch [id S] 't5'
    │  │  │  │     └─ ···
    │  │  │  └─ 1 [id BN]
    │  │  └─ 2 [id BO]
    │  └─ i0 [id X]
    └─ 1 [id BN]
 ← add [id BP] 'o1'
    ├─ sub [id BQ]
    │  ├─ sub [id BR]
    │  │  ├─ Switch [id Q] 't24'
    │  │  │  └─ ···
    │  │  └─ maximum [id M] 'o0'
    │  │     └─ ···
    │  └─ 2 [id BO]
    └─ maximum [id BS] 't21'
       ├─ add [id BT]
       │  ├─ sub [id BU]
       │  │  ├─ maximum [id M] 'o0'
       │  │  │  └─ ···
       │  │  └─ Switch [id Q] 't24'
       │  │     └─ ···
       │  └─ 2 [id BO]
       └─ 2 [id BO]
 ← add [id BV] 'o2'
    ├─ Switch [id BW]
    │  ├─ GT [id BX]
    │  │  ├─ 2 [id BO]
    │  │  └─ maximum [id BS] 't21'
    │  │     └─ ···
    │  ├─ add [id BY]
    │  │  ├─ maximum [id BS] 't21'
    │  │  │  └─ ···
    │  │  └─ 2 [id BO]
    │  └─ sub [id BZ]
    │     ├─ maximum [id BS] 't21'
    │     │  └─ ···
    │     └─ 2 [id BO]
    └─ 2 [id BA]

Composite{((2.0 * i0) + i1)} [id J]
 ← add [id CA] 'o0'
    ├─ mul [id CB]
    │  ├─ 2.0 [id CC]
    │  └─ i0 [id CD]
    └─ i1 [id CE]

Benchmarks

You can see meaningful gains in the following conditions:

Due to inplacing on the inner graph:

  • test_sit_sot_buffer_benchmark[512-2-unit] (2.25x faster median)
  • test_mit_sot_buffer_benchmark[1000-True] (1.95x faster median)

Note test_mit_sot_buffer_benchmark[1000-False] doesn't show a difference because the buffer size is not known statically. Tracked in #1283

Due to not trimming away the taps when the whole (user-facing) trace is requested:

  • test_sit_sot_buffer_benchmark[512-256-whole] (1.5x faster median)

Before

----------------------------------------------------------------------------------------------------- benchmark: 20 tests ------------------------------------------------------------------------------------------------------
Name (time in us)                                          Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_sit_sot_buffer_benchmark[10-2-whole+init]          3.8770 (1.0)       66.8960 (1.87)       4.5554 (1.06)      1.1789 (1.56)       4.2780 (1.04)      0.5610 (6.23)    3001;3787      219.5211 (0.94)      54307           1
test_sit_sot_buffer_benchmark[10-2-aligned]             3.9080 (1.01)      35.8180 (1.0)        4.2945 (1.0)       0.7580 (1.0)        4.1170 (1.0)       0.1410 (1.57)    2528;9835      232.8580 (1.0)       67079           1
test_sit_sot_buffer_benchmark[10-2-misaligned]          3.9080 (1.01)      56.5660 (1.58)       4.4659 (1.04)      0.9656 (1.27)       4.1380 (1.01)      0.6810 (7.57)    4298;3823      223.9209 (0.96)      63984           1
test_sit_sot_buffer_benchmark[10-2-whole]               3.9170 (1.01)      50.4640 (1.41)       4.4979 (1.05)      1.0227 (1.35)       4.1480 (1.01)      0.6210 (6.90)    4653;4466      222.3239 (0.95)      59446           1
test_sit_sot_buffer_benchmark[10-2-unit]                3.9470 (1.02)     157.8550 (4.41)       4.3113 (1.00)      1.3718 (1.81)       4.1380 (1.01)      0.0900 (1.0)     1978;4611      231.9495 (1.00)      53982           1
test_sit_sot_buffer_benchmark[512-2-whole+init]        41.8180 (10.79)    177.1510 (4.95)      46.0673 (10.73)     4.6780 (6.17)      44.5440 (10.82)     4.8192 (53.55)    1270;520       21.7074 (0.09)      17353           1
test_sit_sot_buffer_benchmark[512-2-misaligned]        41.9580 (10.82)     94.9780 (2.65)      45.2083 (10.53)     4.2844 (5.65)      42.4600 (10.31)     5.1900 (57.67)    1807;637       22.1198 (0.09)      17949           1
test_sit_sot_buffer_benchmark[512-2-unit]              41.9590 (10.82)     98.9660 (2.76)      44.1287 (10.28)     3.8746 (5.11)      42.4700 (10.32)     1.0223 (11.36)   1805;4116       22.6610 (0.10)      18717           1
test_sit_sot_buffer_benchmark[512-2-aligned]           42.3590 (10.93)    115.2060 (3.22)      46.1430 (10.74)     4.2601 (5.62)      44.8840 (10.90)     0.3510 (3.90)    1460;3471       21.6717 (0.09)      18511           1
test_sit_sot_buffer_benchmark[512-2-whole]             43.0510 (11.10)    206.8370 (5.77)      47.1326 (10.98)     4.2129 (5.56)      45.6160 (11.08)     0.5510 (6.12)    1338;3636       21.2167 (0.09)      16529           1
test_sit_sot_buffer_benchmark[512-256-misaligned]     101.9010 (26.28)    410.7290 (11.47)    118.6487 (27.63)    20.9061 (27.58)    110.5570 (26.85)     6.7223 (74.69)    713;1557        8.4282 (0.04)       8409           1
test_sit_sot_buffer_benchmark[512-256-aligned]        111.2180 (28.69)    453.5300 (12.66)    126.8957 (29.55)    15.6212 (20.61)    118.3620 (28.75)    13.4450 (149.39)   1692;534        7.8805 (0.03)       7750           1
test_sit_sot_buffer_benchmark[512-256-whole+init]     112.6810 (29.06)    299.4810 (8.36)     132.7842 (30.92)    18.8627 (24.88)    121.3370 (29.47)    28.7640 (319.60)   1068;109        7.5310 (0.03)       6934           1
test_sit_sot_buffer_benchmark[512-256-unit]           114.5450 (29.54)    807.0400 (22.53)    131.1149 (30.53)    21.1005 (27.84)    120.2350 (29.20)    23.3340 (259.27)    637;229        7.6269 (0.03)       6618           1
test_sit_sot_buffer_benchmark[512-256-whole]          139.1810 (35.90)    631.2920 (17.62)    182.8883 (42.59)    34.8052 (45.91)    179.6315 (43.63)    12.8540 (142.82)  1005;1368        5.4678 (0.02)       4444           1

test_mit_sot_buffer_benchmark[1-False]                  4.7290 (1.22)      64.0800 (1.79)       5.1788 (1.21)      1.3283 (1.75)       4.9990 (1.21)      0.1110 (1.23)    1335;3921      193.0949 (0.83)      40976           1
test_mit_sot_buffer_benchmark[1000-True]              131.0860 (33.81)    325.9300 (9.10)     137.4081 (32.00)     9.6476 (12.73)    132.3770 (32.15)     8.9045 (98.94)     731;309        7.2776 (0.03)       6785           1
test_mit_sot_buffer_benchmark[1000-False]             132.2880 (34.12)    241.2920 (6.74)     139.4650 (32.48)    10.2454 (13.52)    133.7705 (32.49)     9.3020 (103.36)    790;446        7.1703 (0.03)       6176           1

test_scan_multiple_output                              68.3980 (17.64)    414.0350 (11.56)     95.3713 (22.21)    25.3289 (33.41)     84.4630 (20.52)    32.4210 (360.23)     535;54       10.4853 (0.05)       3278           1
test_vector_taps_benchmark                            321.8020 (83.00)    498.2430 (13.91)    338.0677 (78.72)    17.4038 (22.96)    330.5240 (80.28)    27.1010 (301.12)     507;29        2.9580 (0.01)       2740           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After

----------------------------------------------------------------------------------------------------- benchmark: 20 tests ------------------------------------------------------------------------------------------------------
Name (time in us)                                          Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_sit_sot_buffer_benchmark[10-2-unit]                3.5460 (1.0)       30.9580 (1.0)        4.2143 (1.0)       0.8748 (1.03)       3.8870 (1.0)       0.6100 (5.08)    4511;4073      237.2855 (1.0)       58786           1
test_sit_sot_buffer_benchmark[10-2-aligned]             3.9480 (1.11)      60.8740 (1.97)       4.5460 (1.08)      1.1993 (1.41)       4.2480 (1.09)      0.5210 (4.34)    3571;3929      219.9745 (0.93)      58748           1
test_sit_sot_buffer_benchmark[10-2-misaligned]          4.0470 (1.14)      65.7730 (2.12)       4.5319 (1.08)      1.1273 (1.33)       4.2780 (1.10)      0.2610 (2.18)    2999;6837      220.6601 (0.93)      62074           1
test_sit_sot_buffer_benchmark[10-2-whole]               4.0770 (1.15)      50.5450 (1.63)       4.6089 (1.09)      1.0177 (1.20)       4.3380 (1.12)      0.1200 (1.0)     4726;9100      216.9707 (0.91)      56393           1
test_sit_sot_buffer_benchmark[10-2-whole+init]          4.0870 (1.15)      62.8380 (2.03)       4.5927 (1.09)      1.1138 (1.31)       4.3280 (1.11)      0.1700 (1.42)   3271;11037      217.7381 (0.92)      53749           1
test_sit_sot_buffer_benchmark[512-2-unit]              18.5050 (5.22)      82.6150 (2.67)      21.1160 (5.01)      4.4988 (5.30)      18.8750 (4.86)      4.2680 (35.57)   3554;2332       47.3576 (0.20)      31687           1
test_sit_sot_buffer_benchmark[512-2-whole]             44.9340 (12.67)    799.0560 (25.81)     48.7547 (11.57)     8.9102 (10.50)     45.9560 (11.82)     2.9060 (24.22)    963;1910       20.5109 (0.09)      14211           1
test_sit_sot_buffer_benchmark[512-2-whole+init]        44.9540 (12.68)    170.6100 (5.51)      48.7200 (11.56)     6.3039 (7.43)      46.3470 (11.92)     2.0450 (17.04)   1532;2643       20.5255 (0.09)      16528           1
test_sit_sot_buffer_benchmark[512-2-aligned]           46.1160 (13.01)    156.8040 (5.07)      48.4194 (11.49)     3.7840 (4.46)      46.8080 (12.04)     2.0340 (16.95)   1597;1777       20.6529 (0.09)      17057           1
test_sit_sot_buffer_benchmark[512-2-misaligned]        46.1960 (13.03)    125.4850 (4.05)      48.3574 (11.47)     3.7634 (4.44)      46.7580 (12.03)     1.8640 (15.53)   1820;2029       20.6794 (0.09)      17159           1
test_sit_sot_buffer_benchmark[512-256-aligned]        101.7210 (28.69)    248.5050 (8.03)     115.9539 (27.51)    14.5836 (17.19)    109.2740 (28.11)     6.3492 (52.91)   1489;1583        8.6241 (0.04)       8587           1
test_sit_sot_buffer_benchmark[512-256-misaligned]     112.8110 (31.81)    259.5560 (8.38)     125.5590 (29.79)    17.0010 (20.04)    118.3720 (30.45)     5.5700 (46.42)   1052;1289        7.9644 (0.03)       8056           1
test_sit_sot_buffer_benchmark[512-256-whole]          113.4730 (32.00)    291.7460 (9.42)     131.2832 (31.15)    18.1112 (21.35)    120.8060 (31.08)    25.2845 (210.70)   1683;117        7.6171 (0.03)       6953           1
test_sit_sot_buffer_benchmark[512-256-unit]           117.1990 (33.05)    282.5890 (9.13)     137.9032 (32.72)    21.3947 (25.22)    126.0060 (32.42)    26.6100 (221.75)    922;243        7.2515 (0.03)       7317           1
test_sit_sot_buffer_benchmark[512-256-whole+init]     118.4920 (33.42)    299.9210 (9.69)     126.3285 (29.98)    16.2633 (19.17)    120.1050 (30.90)     5.6185 (46.82)     714;892        7.9159 (0.03)       6859           1

test_mit_sot_buffer_benchmark[1-False]                  4.1180 (1.16)      70.0510 (2.26)       4.5866 (1.09)      0.8482 (1.0)        4.4190 (1.14)      0.1510 (1.26)    1522;6659      218.0249 (0.92)      45702           1
test_mit_sot_buffer_benchmark[1000-True]               66.8950 (18.86)    145.5430 (4.70)      69.4895 (16.49)     4.9975 (5.89)      67.5870 (17.39)     1.0120 (8.43)    1272;2391       14.3907 (0.06)      12095           1
test_mit_sot_buffer_benchmark[1000-False]             144.9420 (40.87)    307.1260 (9.92)     151.2527 (35.89)     9.9623 (11.74)    147.0650 (37.84)     6.0610 (50.51)     558;553        6.6115 (0.03)       5809           1

test_scan_multiple_output                              69.3700 (19.56)    458.3890 (14.81)     77.6957 (18.44)    13.1526 (15.51)     72.2860 (18.60)     7.7642 (64.70)     439;448       12.8707 (0.05)       4729           1
test_vector_taps_benchmark                            328.4050 (92.61)    531.9460 (17.18)    344.4654 (81.74)    18.5545 (21.87)    337.0910 (86.72)    23.8542 (198.78)     407;60        2.9030 (0.01)       2529           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


📚 Documentation preview 📚: https://pytensor--1281.org.readthedocs.build/en/1281/

@ricardoV94 ricardoV94 force-pushed the scan_save_mem branch 6 times, most recently from ab9763f to 375fec1 Compare March 10, 2025 14:00
@ricardoV94 ricardoV94 marked this pull request as ready for review March 10, 2025 15:41
@@ -29,7 +29,7 @@ def scan(*outer_inputs):
# Extract JAX scan inputs
outer_inputs = list(outer_inputs)
n_steps = outer_inputs[0] # JAX `length`
seqs = op.outer_seqs(outer_inputs) # JAX `xs`
seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)] # JAX `xs`
Member Author

@ricardoV94 ricardoV94 Mar 10, 2025


PyTensor Scan allows sequences to be longer than the number of steps, in which case the extra entries are simply unused. The Scan save-memory rewrite doesn't bother trimming them.

JAX, however, doesn't allow it. Fixing the constant n_steps optimization revealed this issue.


codecov bot commented Mar 10, 2025

Codecov Report

Attention: Patch coverage is 94.11765% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.03%. Comparing base (c0860f8) to head (cae417f).
Report is 15 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/scan/rewriting.py 92.59% 0 Missing and 2 partials ⚠️
pytensor/tensor/subtensor.py 90.47% 1 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (94.11%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1281      +/-   ##
==========================================
+ Coverage   81.98%   82.03%   +0.04%     
==========================================
  Files         188      188              
  Lines       48575    48567       -8     
  Branches     8685     8675      -10     
==========================================
+ Hits        39826    39841      +15     
+ Misses       6585     6574      -11     
+ Partials     2164     2152      -12     
Files with missing lines Coverage Δ
pytensor/compile/mode.py 84.72% <100.00%> (ø)
pytensor/configdefaults.py 73.54% <ø> (ø)
pytensor/link/jax/dispatch/scan.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/scan.py 96.05% <100.00%> (+0.07%) ⬆️
pytensor/scan/op.py 84.82% <100.00%> (+0.09%) ⬆️
pytensor/scan/utils.py 87.56% <100.00%> (ø)
pytensor/tensor/basic.py 91.16% <100.00%> (-0.21%) ⬇️
pytensor/scan/rewriting.py 82.62% <92.59%> (+0.93%) ⬆️
pytensor/tensor/subtensor.py 89.30% <90.47%> (-0.06%) ⬇️

... and 8 files with indirect coverage changes


@ricardoV94 ricardoV94 force-pushed the scan_save_mem branch 2 times, most recently from acb486e to af8ccb1 Compare March 12, 2025 17:09
Member

@jessegrabowski jessegrabowski left a comment


Big PR, but it all looks kosher as far as I can tell. Maybe add some extra tests for the new functionality (checking the destroy_map of input buffers in numba mode)?

Side comment: it's a bit frustrating that the order of the tests in the benchmark output isn't consistent. Is it possible to fix that (e.g. sort by name instead of by speed)?

@@ -417,4 +442,4 @@ def scan({", ".join(outer_in_names)}):

scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})

return numba_basic.numba_njit(scan_op_fn)
return numba_basic.numba_njit(scan_op_fn, boundscheck=False)
Member


boundscheck defaults to False in numba, do we override that?

Member Author


Probably not, but I found it easier to read with the default spelled out.

else:
return idx + length, 1
else:
return switch(lt(idx, 0), idx + length, idx), 1
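For context, the symbolic switch(lt(idx, 0), idx + length, idx) in the hunk above mirrors NumPy-style negative-index normalization, which in plain Python reads (illustrative sketch):

```python
def normalize_negative_index(idx, length):
    # A negative index counts from the end of the axis, as in NumPy.
    return idx + length if idx < 0 else idx

print(normalize_negative_index(-1, 10))  # 9
print(normalize_negative_index(3, 10))   # 3
```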
Member


Nit, and unrelated to this PR: we should have a helper function like NumPy's negative-index normalization to make it more obvious what's being done here.

Member Author


This is kind of the symbolic version of it. It's used to normalize indices and infer the shape of an index operation.

Member


Yeah, I understand that, but I think ideally we'd have return pt.indexing.normalize_negative_indices(idx, length) here.

To be clear, I am not asking for this in this PR, just mentioning it.

Member Author


But this is pt.indexing.normalize_negative_indices itself

output, _ = scan(
f_pow2,
sequences=[],
outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
Member


If I understand correctly, we should have a case with two outputs_info entries with different tap lengths to test the changes to oldest_inner_mitsot.

Member Author


sure

Member Author


Added a test

[init_x] if constant_n_steps else [init_x, n_steps],
[output[-1]],
test_vals,
numba_mode="NUMBA",
Member


Add a test that the relevant buffers have the expected destroy_map?

Member Author


yup

Member Author


Added a test for that specifically

@ricardoV94
Member Author

Side comment: it's a bit frustrating that the order of the tests in the benchmark output isn't consistent, is that possible to fix (instead of sort by speed)

Not sure; I want to move to asv, maybe its output is saner. I could have compared the two runs explicitly, but I didn't remember to do that.

Graph was being broken by Scalar/Tensor conversions that prevented fusion
@ricardoV94 ricardoV94 merged commit c822a8e into pymc-devs:main Mar 13, 2025
72 of 73 checks passed
@ricardoV94 ricardoV94 changed the title Mar 18, 2025