Speedup Scan in different backends #1281
Conversation
Force-pushed ab9763f to 375fec1
Force-pushed 375fec1 to ab57a79
isinstance(..., int) does not recognize numpy.integers; also remove maxsize logic
This will always require a roll at the end, for a minimal gain
Force-pushed ab57a79 to 249dfae
@@ -29,7 +29,7 @@ def scan(*outer_inputs):
     # Extract JAX scan inputs
     outer_inputs = list(outer_inputs)
     n_steps = outer_inputs[0]  # JAX `length`
-    seqs = op.outer_seqs(outer_inputs)  # JAX `xs`
+    seqs = [seq[:n_steps] for seq in op.outer_seqs(outer_inputs)]  # JAX `xs`
PyTensor Scan allows sequences to be longer than the number of steps, in which case the extra entries just go unused. The Scan save-memory rewrite doesn't bother with trimming them.
JAX, however, doesn't allow it. Fixing the constant n_steps optimization revealed this issue.
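A minimal sketch of why the trimming is needed (a standalone example, not code from this PR): jax.lax.scan requires the leading dimension of xs to match length, so longer sequences have to be sliced first.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new = carry + x
    return new, new

xs = jnp.arange(10.0)  # sequence longer than the number of steps
n_steps = 5

# jax.lax.scan(step, 0.0, xs, length=n_steps) would fail: xs has 10 rows but length is 5.
# Slicing the sequences first, as the PR now does, makes it valid:
carry, ys = jax.lax.scan(step, 0.0, xs[:n_steps], length=n_steps)
print(carry, ys)  # 10.0 [ 0.  1.  3.  6. 10.]
```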
Codecov Report
❌ Your patch status has failed because the patch coverage (94.11%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files:
@@ Coverage Diff @@
##             main    #1281      +/-   ##
==========================================
+ Coverage   81.98%   82.03%   +0.04%
==========================================
  Files         188      188
  Lines       48575    48567       -8
  Branches     8685     8675      -10
==========================================
+ Hits        39826    39841      +15
+ Misses       6585     6574      -11
+ Partials     2164     2152      -12
Force-pushed acb486e to af8ccb1
Big PR, but it all looks kosher as far as I can tell. Maybe some extra tests for new functionality (checking the destroy map of input buffers in numba mode)?
Side comment: it's a bit frustrating that the order of the tests in the benchmark output isn't consistent. Is that possible to fix (instead of sorting by speed)?
@@ -417,4 +442,4 @@ def scan({", ".join(outer_in_names)}):

    scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})

-    return numba_basic.numba_njit(scan_op_fn)
+    return numba_basic.numba_njit(scan_op_fn, boundscheck=False)
boundscheck defaults to False in numba, do we override that?
Probably not, but I found this easier to read.
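For context, a minimal standalone sketch (not code from the PR) of what the flag controls; numba's boundscheck defaults to False unless it is overridden globally, e.g. via the NUMBA_BOUNDSCHECK environment variable.

```python
import numpy as np
from numba import njit

def read_past_end(x):
    return x[x.shape[0]]  # deliberately out of bounds

# Default compilation: boundscheck=False, so the bad read goes unchecked
# (garbage value or a crash). Not called here for that reason.
unchecked = njit(read_past_end)

# With bounds checking explicitly enabled, the bad read is caught at runtime:
checked = njit(boundscheck=True)(read_past_end)

try:
    checked(np.arange(3))
except IndexError as exc:
    print("caught:", exc)
```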
        else:
            return idx + length, 1
    else:
        return switch(lt(idx, 0), idx + length, idx), 1
Nit, and unrelated to this PR: we should have a helper function like numpy's normalize-negative-index to make it more obvious what's being done here.
This is kind of the symbolic version of it. It's used to normalize and infer the shape of an index operation
Yeah, I understand that, but I think ideally we'd have return pt.indexing.normalize_negative_indices(idx, length) here.
To be clear, I am not asking for this in this PR, just mentioning it.
But this is pt.indexing.normalize_negative_indices itself.
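A rough sketch of the kind of helper being discussed (the name, module, and signature are illustrative, not an existing PyTensor API):

```python
import pytensor.tensor as pt
from pytensor.tensor import lt, switch

def normalize_negative_indices(idx, length):
    """Map a possibly negative symbolic index into the range [0, length)."""
    return switch(lt(idx, 0), idx + length, idx)

# Wraps the switch(lt(...)) pattern from the diff above behind a readable name.
i = pt.iscalar("i")
n = pt.iscalar("n")
normalized = normalize_negative_indices(i, n)
```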
output, _ = scan(
    f_pow2,
    sequences=[],
    outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
If I understand correctly, we should have a case with two outputs_info with different taps lengths to test the changes to oldest_inner_mitsot.
sure
Added a test
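A minimal sketch of what such a case could look like (illustrative only; the step function, names, and values are made up and this is not the test that was added):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor import scan

init_a = pt.dvector("init_a")  # 2 initial values -> taps [-2, -1]
init_b = pt.dvector("init_b")  # 3 initial values -> taps [-3, -1]

def step(a_tm2, a_tm1, b_tm3, b_tm1):
    # Two recurring outputs with different numbers of taps
    return a_tm2 + a_tm1, b_tm3 * b_tm1

(a_trace, b_trace), _ = scan(
    step,
    outputs_info=[
        {"initial": init_a, "taps": [-2, -1]},
        {"initial": init_b, "taps": [-3, -1]},
    ],
    n_steps=5,
)
fn = pytensor.function([init_a, init_b], [a_trace[-1], b_trace[-1]])
print(fn(np.array([0.0, 1.0]), np.array([1.0, 2.0, 3.0])))
```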
    [init_x] if constant_n_steps else [init_x, n_steps],
    [output[-1]],
    test_vals,
    numba_mode="NUMBA",
Add a test that the relevant buffers have the expected destroy_map?
yup
Added a test for that specifically
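A rough illustration of what such a check could look like (the graph and the expected destroy_map value are assumptions, not the actual test in this PR): compile a scan in NUMBA mode and inspect the destroy_map of the Scan node in the compiled graph.

```python
import pytensor
import pytensor.tensor as pt
from pytensor import scan
from pytensor.scan.op import Scan

init_x = pt.dscalar("init_x")
trace, _ = scan(lambda x: x * 2, outputs_info=[init_x], n_steps=10)

# Only the last state is requested, so the buffer should be reusable in place
fn = pytensor.function([init_x], trace[-1], mode="NUMBA")

scan_nodes = [n for n in fn.maker.fgraph.apply_nodes if isinstance(n.op, Scan)]
assert len(scan_nodes) == 1
print(scan_nodes[0].op.destroy_map)  # e.g. {0: [1]} if the state buffer is destroyed
```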
Not sure; I want to move to asv, maybe its output is saner. I do think I could have compared two runs explicitly, but didn't remember to do that.
Force-pushed af8ccb1 to af218a6
…e discarded immediately
Graph was being broken by Scalar/Tensor conversions that prevented fusion
Force-pushed af218a6 to cae417f
This PR does a bunch of small optimizations for the Scan, mostly in the Numba backend, but to a lesser extent also in the C and JAX backends.
Tweaks:
- Fix a failure in trimming the number of steps in Scan when the new n_steps would be a constant integer: isinstance(x, int) does not work for np.integer (a minimal illustration is sketched right after this list).
- Do not try to reduce the buffer size just to save the space occupied by the initial values. Scan is always created with a buffer for the recurring inputs of size taps + steps. The scan returned to the user slices away the taps, so the user doesn't see the initial values there. When the user wants the whole returned trace (instead of, say, only the last state) AND no gradient is requested (the gradient makes use of the initial taps in the trace), the Scan save-memory rewrite would trim the buffer to size steps. This, however, always requires a roll at the end of the Scan, because the discarded entries (the taps) are at the beginning (see the second sketch after this list). After this PR the save-memory rewrite no longer optimizes this specific case, as the default slice is likely better than the roll, which requires a copy. Benchmarks confirm this.
- Be more aggressive with the buffer memory save in the JIT backends. As described in Missed scan rewrites #787, Scan keeps one more entry in the circular buffer for recurring states than strictly needed. This happens because the Python-C implementation makes clever use of the inner compiled PyTensor function to manually specify the output buffers. In that case it's not always safe for PyTensor to overwrite the buffer, because the inputs may still be needed to compute other outputs, so an extra entry was provided for the output. This is not a concern in the JIT backends, because there is no such machinery to exploit. As such, the scan save-memory rewrite was split so it can be more aggressive in the JIT backends.
- Minor tweaks in the Numba dispatch of Scan so it doesn't try to roll when the buffer is already aligned. I didn't see any stable changes, but it seems cleaner. Also set boundscheck explicitly to False, though I think that may already be the default. Neither mattered in my benchmarks.
- Allow inplacing in the inner function for sitsot and the last mitsot when the buffer is so small that the last tap gets discarded immediately anyway (this happens when the user asks for trace[-1] of a scan). This provides substantial speedups in some cases. We should do the same for mit-mots, but I haven't yet grokked how they are supposed to behave. Issue to track it: Allow inner-function inplace of mit-mot in Scan backend #1282.
Note this inplacing doesn't work right now when n_steps is symbolic, as it relies on PyTensor having figured out that the static shape of the buffer matches the size of the taps. More work is needed for this: #1282
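The isinstance pitfall mentioned in the first tweak above, as a minimal standalone illustration (not the actual patch):

```python
import numpy as np

n_steps = np.int64(5)  # constants pulled out of a graph are often numpy integers

print(isinstance(n_steps, int))                # False: np.int64 is not a Python int
print(isinstance(n_steps, (int, np.integer)))  # True: also accepts numpy integer types
```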
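And a small numpy sketch of why trimming the buffer down to size steps forces a roll (the numbers and the buffer simulation are only illustrative, not how Scan is implemented internally):

```python
import numpy as np

taps, steps = 2, 5
full = np.arange(taps + steps)  # [0 1 | 2 3 4 5 6]: initial taps, then the computed states
trace = full[taps:]             # default buffer: the user trace is a cheap slice -> [2 3 4 5 6]

# With a circular buffer of only `steps` entries, the writes wrap around the
# slots that initially held the taps, so the states end up out of order:
circ = np.empty(steps, dtype=int)
for t in range(steps):
    circ[(taps + t) % steps] = full[taps + t]
print(circ)                  # [5 6 2 3 4]
print(np.roll(circ, -taps))  # [2 3 4 5 6]: recovering the trace needs a roll, i.e. a copy
```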
This is also related to how subtensor_merge works: local_subtensor_merge can complicate graphs #112. This is mostly because we cannot know whether the number of steps will be zero, in which case it's invalid for the user to select trace[-1]. The graph is overly complicated; it is perhaps better to just add an assert that n_steps > 0, like we do for the while-scan optimization. Anyway, the graph was even worse before this PR, because a bunch of ScalarFromTensor and TensorFromScalar would show up in the buffer-size logic. This PR cleans it up, so at least everything happens in one single fused graph.
Graph before/after for the mitsot test:
Before
After
Benchmarks
You can see meaningful gains in the following conditions:
Due to inplacing on the inner graph:
Note test_mit_sot_buffer_benchmark[1000-False] doesn't show a difference, because the buffer size is not known statically. Tracked in #1283.
Due to not trimming away the taps when the whole (user-facing) trace is requested:
Before
After
📚 Documentation preview 📚: https://pytensor--1281.org.readthedocs.build/en/1281/