Replace batched_convolution by PyTensor native implementation #1583
Conversation
Insane improvement, that's awesome!
Nice 😎!
Can we get some examples of where this helps us in the library? And if you can provide a vignette of how to leverage this new convolution, that would be great too. I see that you import …
There's nothing you need to do to leverage it. It's used by the geometric adstock, so any MMM model with it (most) will benefit from the PyTensor implementation, which is better than the custom one we had (at least 3-4x faster at runtime, as reported in the vignette) and has much faster compile times. The import of …
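To make that concrete, here is a minimal sketch of the code path that benefits, using the public geometric_adstock transformer; the exact signature is taken from pymc-marketing's docs and should be checked against your installed version:

```python
# Minimal sketch: any model using the geometric adstock picks up the faster
# convolution automatically -- no user-facing changes are needed.
import numpy as np

from pymc_marketing.mmm.transformers import geometric_adstock

spend = np.random.default_rng(42).uniform(size=100)  # e.g. weekly channel spend

# geometric_adstock calls batched_convolution internally, which now lowers to
# the PyTensor-native convolution op.
adstocked = geometric_adstock(spend, alpha=0.5, l_max=8, normalize=True)
print(adstocked.eval().shape)  # (100,)
```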
I see pymc-devs/pytensor#1318 merged. What would be next to merge here? @ricardoV94 Do you need a hand?
It needed a pymc release; did that just now.
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1583        +/- ##
===========================================
- Coverage   93.53%   67.04%    -26.49%
===========================================
  Files          55       55
  Lines        6357     6345        -12
===========================================
- Hits         5946     4254      -1692
- Misses        411     2091      +1680

☔ View full report in Codecov by Sentry.
Seems like there's another, perhaps unrelated, failure with the latest PyTensor.
For reference: ERROR pytensor.graph.rewriting.basic:basic.py:1757 Rewrite failure due to: local_shape_to_shape_i
ERROR pytensor.graph.rewriting.basic:basic.py:1758 node: Shape(Blockwise{Subtensor{start:stop}, (i00),(),()->(o00)}.0)
ERROR pytensor.graph.rewriting.basic:basic.py:1759 TRACEBACK:
ERROR pytensor.graph.rewriting.basic:basic.py:1760 Traceback (most recent call last):
File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/rewriting/basic.py", line 1968, in process_node
fgraph.replace_all_validate_remove( # type: ignore
File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/features.py", line 634, in replace_all_validate_remove
chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/features.py", line 579, in replace_all_validate
fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/fg.py", line 535, in replace
self.change_node_input(
File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/fg.py", line 456, in change_node_input
self.import_var(new_var, reason=reason, import_missing=import_missing)
File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/fg.py", line 323, in import_var
self.import_node(var.owner, reason=reason, import_missing=import_missing)
File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/fg.py", line 388, in import_node
raise MissingInputError(error_msg, variable=var)
pytensor.graph.utils.MissingInputError: Input 0 (<Scalar(int64, shape=())>) of the graph (indices start from 0), used to compute Add(<Scalar(int64, shape=())>, Sub.0), was not provided and not given a value. Use the PyTensor flag exception_verbosity='high', for more information on this error.

Any ideas why this could be?
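For what it's worth, the error message itself points at one way to get more context while reproducing; a minimal sketch:

```python
# Turn on verbose error reporting before re-running the failing compile,
# as the MissingInputError message suggests.
import pytensor

pytensor.config.exception_verbosity = "high"
# ...then re-run the failing pytensor.function(...) / .eval() call to get
# a fuller printout of the graph around the missing input.
```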
Not yet, I'll have to jump on the interpreter.
Found the issue, will be fixed by pymc-devs/pytensor#1353.
We should lower-pin the last minor version of PyTensor separately from the pymc pin, because pymc doesn't pin it.
Sure! Sounds good!
This will depend on the not-yet-released pymc-devs/pytensor#1318. That will also require a dependency bump on PyMC before we can use it here.
The PyTensor impl gives us batching out of the box, without a graph-compile or runtime penalty. I got these runtime differences before and after my changes:
Not seen here are also the widely different compile times for very large kernel_sizes, which in the old implementation would be a series of unrolled set_subtensor calls. Now the compile time is constant in the number of lags, and it's also no longer needed to have statically known lags. The gradient should provide a nice speedup as well, with room for further cleverness outside of pymc-marketing: pymc-devs/pytensor#1320
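For illustration, a hedged sketch of the difference; I'm assuming the native op landed as convolve1d in pytensor.tensor.signal per pymc-devs/pytensor#1318 (check the exact import path for your version), and the "old style" below uses plain additions rather than the original set_subtensor writes, but shows the same unrolled-graph growth:

```python
import pytensor.tensor as pt
from pytensor.tensor.signal import convolve1d  # assumed location of the new op

x = pt.vector("x")  # data series
w = pt.vector("w")  # adstock weights

# Old style: one update per lag, so the graph (and compile time) grows
# linearly with l_max, and l_max must be statically known.
def unrolled_conv(x, w, l_max):
    padded = pt.concatenate([pt.zeros(l_max - 1), x])
    out = pt.zeros_like(x)
    for i in range(l_max):  # unrolled Python loop -> l_max graph nodes
        out = out + w[i] * padded[l_max - 1 - i : padded.shape[0] - i]
    return out

out_old = unrolled_conv(x, w, l_max=8)

# New style: a single native op; constant graph size, and the kernel length
# can stay symbolic.
out_new = convolve1d(x, w, mode="full")[: x.shape[0]]
```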
In practice users won't be doing a crazy number of lags, so the runtime speedup is perhaps just the 3-4x I saw locally for kernel_size=8, and not the 1000x I got for other combinations. It may be a higher/lower delta in other backends, feel free to check. The compile time and simpler graphs are what drove me to make these changes.
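If you want to check the delta on your own backend, a quick benchmark sketch (numbers will vary by machine; the NUMBA/JAX modes assume those extras are installed):

```python
import timeit

import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc_marketing.mmm.transformers import geometric_adstock

x = pt.vector("x")
y = geometric_adstock(x, alpha=0.6, l_max=8, normalize=True)

# Default C backend; try pytensor.function([x], y, mode="NUMBA") or
# mode="JAX" to compare other backends.
fn = pytensor.function([x], y)
data = np.random.default_rng(0).uniform(size=10_000)
print(timeit.timeit(lambda: fn(data), number=100))
```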
📚 Documentation preview 📚: https://pymc-marketing--1583.org.readthedocs.build/en/1583/