
Replace batched_convolution by PyTensor native implementation #1583

Draft · wants to merge 6 commits into main

Conversation

ricardoV94
Copy link
Contributor

@ricardoV94 ricardoV94 commented Mar 24, 2025

This will depend on the not-yet-released pymc-devs/pytensor#1318.

It will also require a dependency bump on PyMC before we can use it here.

The PyTensor implementation gives us batching out of the box, without a graph-compile or runtime penalty. These are the runtime differences I measured before and after my changes:

import numpy as np
import pytensor.tensor as pt
import pytensor
from pymc_marketing.mmm.transformers import batched_convolution

for kernel_size in (8, 80, 800):
    for data_size in (12, 120, 1200):
        data = pt.vector("x", shape=(data_size,))
        kernel = pt.vector("kernel", shape=(kernel_size,))

        out = batched_convolution(data, kernel, mode="After")
        fn = pytensor.function([data, kernel], out, trust_input=True)

        rng = np.random.default_rng((kernel_size, data_size))
        data_test = rng.normal(size=data.type.shape)
        kernel_test = rng.normal(size=kernel.type.shape)

        print(f"{kernel_size=}, {data_size=}")
        %timeit fn(data_test, kernel_test)

# Before changes:
# kernel_size=8, data_size=12
# 35.3 μs ± 4.67 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# kernel_size=8, data_size=120
# 29.7 μs ± 432 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# kernel_size=8, data_size=1200
# 30.1 μs ± 391 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# kernel_size=80, data_size=12
# 398 μs ± 1.82 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# kernel_size=80, data_size=120
# 402 μs ± 5.93 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# kernel_size=80, data_size=1200
# 401 μs ± 6.74 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# kernel_size=800, data_size=12
# 7.28 ms ± 45.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# kernel_size=800, data_size=120
# 7.97 ms ± 469 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# kernel_size=800, data_size=1200
# 7.5 ms ± 406 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# After changes:
# kernel_size=8, data_size=12
# 6.06 μs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# kernel_size=8, data_size=120
# 6.32 μs ± 18.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# kernel_size=8, data_size=1200
# 9.18 μs ± 67.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

# kernel_size=80, data_size=12
# 6.67 μs ± 180 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# kernel_size=80, data_size=120
# 9.81 μs ± 60.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# kernel_size=80, data_size=1200
# 39.3 μs ± 1.15 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# kernel_size=800, data_size=12
# 9.22 μs ± 21.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# kernel_size=800, data_size=120
# 24.6 μs ± 425 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# kernel_size=800, data_size=1200
# 179 μs ± 2.92 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Not seen here are the widely different compile times for very large kernel_sizes: the old implementation unrolled into a series of set_subtensor operations, so the graph (and compile time) grew with the number of lags. Now the compile time is constant in the number of lags. It's also no longer necessary for the number of lags to be statically known.
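To illustrate the difference (a NumPy sketch of the graph structure, not the actual PyTensor ops): the old approach emitted one shift-and-add per lag, so the number of operations grew with the kernel size, while the new approach is a single convolution node regardless of kernel size.

```python
import numpy as np

def conv_unrolled(x, w):
    # Old style: one shift-and-add per lag, so the "graph" here
    # contains len(w) separate operations.
    out = np.zeros_like(x, dtype=float)
    for i, wi in enumerate(w):
        shifted = np.concatenate([np.zeros(i), x[: len(x) - i]])
        out += wi * shifted
    return out

def conv_single_op(x, w):
    # New style: a single convolution call; op count is constant
    # in len(w), and len(w) need not be known statically.
    return np.convolve(x, w)[: len(x)]

x = np.arange(5, dtype=float)
w = np.array([1.0, 0.5, 0.25])
assert np.allclose(conv_unrolled(x, w), conv_single_op(x, w))
```

Both compute the same causal convolution; only the number of operations needed to express it differs, which is what drives the compile-time gap.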

The gradient should provide a nice speedup as well, with room for further cleverness outside of pymc-marketing: pymc-devs/pytensor#1320

In practice users won't be using a huge number of lags, so the runtime speedup is perhaps just the 3-4x I saw locally for kernel_size=8, not the near-1000x I got for other combinations. The delta may be higher or lower in other backends; feel free to check.

The compile time and simpler graphs are what drove me to make these changes.


📚 Documentation preview 📚: https://pymc-marketing--1583.org.readthedocs.build/en/1583/

@ricardoV94 ricardoV94 changed the title Replace batched_convolution by pytensor native impl Replace batched_convolution by pytensor native implementation Mar 24, 2025
@github-actions github-actions bot added the MMM label Mar 24, 2025
@cetagostini
Copy link
Contributor

Insane improvement, that's awesome!

@juanitorduz
Copy link
Collaborator

Nice 😎 !

@cluhmann
Copy link
Contributor

Can we get some examples of where this helps us in the library? And if you could provide a vignette of how to leverage this new convolution, that would be great too. I see that you import convolve from pytensor.tensor.signal but then don't use it (directly). Is that where the magic is happening?

@ricardoV94
Copy link
Contributor Author

ricardoV94 commented Mar 27, 2025

> Can we get some examples of where this helps us in the library? And if you could provide a vignette of how to leverage this new convolution, that would be great too. I see that you import convolve from pytensor.tensor.signal but then don't use it (directly). Is that where the magic is happening?

There's nothing you need to do to leverage it. It's used by the geometric adstock, so any MMM models with it (most) will benefit from the PyTensor implementation, which is better than the custom one we had (at least the 3-4x runtime reported in the benchmark above) and has much faster compile times.
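For context, here is a minimal NumPy sketch of what the geometric adstock convolution computes (illustrative only; the real implementation lives in pymc_marketing.mmm.transformers and runs inside the PyTensor graph):

```python
import numpy as np

def geometric_adstock_np(x, alpha, l_max):
    # Weights decay geometrically: w[i] = alpha**i
    w = alpha ** np.arange(l_max)
    out = np.zeros_like(x, dtype=float)
    n = len(x)
    for i in range(min(l_max, n)):
        # Spend at time t keeps contributing at t + i, scaled by w[i]
        out[i:] += w[i] * x[: n - i]
    return out

# A single spend of 100 at t=0 carries over with geometric decay
x = np.zeros(8)
x[0] = 100.0
print(geometric_adstock_np(x, alpha=0.5, l_max=4))
# -> [100.   50.   25.   12.5   0.    0.    0.    0. ]
```

This per-lag shift-and-add loop is exactly the kind of thing the new single convolution op replaces in the compiled graph.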

The import of convolve was just left over from older code I was using to create the benchmark. Updated it.

@ricardoV94 ricardoV94 changed the title Replace batched_convolution by pytensor native implementation Replace batched_convolution by PyTensor native implementation Mar 27, 2025
@cetagostini
Copy link
Contributor

I see pymc-devs/pytensor#1318 merged. What would be next to merge here? @ricardoV94

Do you need a hand?

@ricardoV94
Copy link
Contributor Author

> I see pymc-devs/pytensor#1318 merged. What would be next to merge here? @ricardoV94
>
> Do you need a hand?

It needed a PyMC release; did that now.

Copy link

codecov bot commented Apr 3, 2025

Codecov Report

Attention: Patch coverage is 62.50000% with 6 lines in your changes missing coverage. Please review.

Project coverage is 67.04%. Comparing base (a87f1fa) to head (8d1c5de).
Report is 2 commits behind head on main.

Files with missing lines | Patch % | Lines
--- | --- | ---
pymc_marketing/mmm/transformers.py | 62.50% | 6 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (a87f1fa) and HEAD (8d1c5de): HEAD has 6 fewer uploads than BASE (17 vs 23).
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1583       +/-   ##
===========================================
- Coverage   93.53%   67.04%   -26.49%     
===========================================
  Files          55       55               
  Lines        6357     6345       -12     
===========================================
- Hits         5946     4254     -1692     
- Misses        411     2091     +1680     

☔ View full report in Codecov by Sentry.

@ricardoV94
Copy link
Contributor Author

ricardoV94 commented Apr 5, 2025

Seems like there's another, perhaps unrelated failure, with the latest PyTensor

@juanitorduz
Copy link
Collaborator

Seems like there's another, perhaps unrelated failure, with the latest PyTensor

For reference:

ERROR    pytensor.graph.rewriting.basic:basic.py:1757 Rewrite failure due to: local_shape_to_shape_i
ERROR    pytensor.graph.rewriting.basic:basic.py:1758 node: Shape(Blockwise{Subtensor{start:stop}, (i00),(),()->(o00)}.0)
ERROR    pytensor.graph.rewriting.basic:basic.py:1759 TRACEBACK:
ERROR    pytensor.graph.rewriting.basic:basic.py:1760 Traceback (most recent call last):
  File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/rewriting/basic.py", line 1968, in process_node
    fgraph.replace_all_validate_remove(  # type: ignore
  File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/features.py", line 634, in replace_all_validate_remove
    chk = fgraph.replace_all_validate(replacements, reason=reason, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/features.py", line 579, in replace_all_validate
    fgraph.replace(r, new_r, reason=reason, verbose=False, **kwargs)
  File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/fg.py", line 535, in replace
    self.change_node_input(
  File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/fg.py", line 456, in change_node_input
    self.import_var(new_var, reason=reason, import_missing=import_missing)
  File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/fg.py", line 323, in import_var
    self.import_node(var.owner, reason=reason, import_missing=import_missing)
  File "/opt/hostedtoolcache/Python/3.12.9/x64/lib/python3.12/site-packages/pytensor/graph/fg.py", line 388, in import_node
    raise MissingInputError(error_msg, variable=var)
pytensor.graph.utils.MissingInputError: Input 0 (<Scalar(int64, shape=())>) of the graph (indices start from 0), used to compute Add(<Scalar(int64, shape=())>, Sub.0), was not provided and not given a value. Use the PyTensor flag exception_verbosity='high', for more information on this error.

Any ideas why this could be?

@ricardoV94
Copy link
Contributor Author

Not yet, I'll have to jump on the interpreter.

@ricardoV94
Copy link
Contributor Author

Found the issue, will be fixed by pymc-devs/pytensor#1353

@ricardoV94
Copy link
Contributor Author

We should add a lower pin on the latest minor version of PyTensor, separately from the PyMC pin, because PyMC doesn't pin it.
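For example, something like the following in the package metadata (version numbers here are hypothetical; the exact minimums would come from whichever PyTensor release includes pymc-devs/pytensor#1318 and #1353):

```toml
[project]
dependencies = [
    # Pin PyMC and PyTensor separately: PyMC itself does not
    # lower-pin PyTensor, so we must enforce the minimum here.
    "pymc>=5.21",
    "pytensor>=2.30",
]
```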

@juanitorduz
Copy link
Collaborator

> We should add a lower pin on the latest minor version of PyTensor, separately from the PyMC pin, because PyMC doesn't pin it.

Sure! Sounds good!

5 participants