Hardcode common Op parametrizations to allow numba caching #1341

Open · wants to merge 1 commit into main
Conversation

@ricardoV94 (Member) commented on Apr 2, 2025

This is perhaps the ugliest PR of my life. Hardcoding common parametrizations of string-generated Ops helps a tiny bit with numba caching.
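The idea can be sketched as follows. This is a minimal, hypothetical illustration, not the actual PyTensor code: `make_add`, `add_2`, and `_HARDCODED_ADD` are made-up names. The underlying point is that numba's on-disk cache keys compiled functions by their source file, so functions generated at runtime from strings (via `exec`) cannot be cached across interpreter sessions, while ordinary module-level functions can. Hardcoding the most common parametrizations as plain functions makes those paths cacheable, with string generation kept only as a fallback for the long tail:

```python
# Hypothetical sketch of the hardcoding idea (illustrative names, not the
# PyTensor API). Module-level functions have a real source file, so numba
# can cache their compiled versions to disk; exec-generated ones cannot.

def add_2(a, b):  # hardcoded 2-input variant (cacheable by numba)
    return a + b

def add_3(a, b, c):  # hardcoded 3-input variant (cacheable by numba)
    return a + b + c

_HARDCODED_ADD = {2: add_2, 3: add_3}

def make_add(n_inputs):
    """Return an n-ary add, preferring a cache-friendly hardcoded variant."""
    fn = _HARDCODED_ADD.get(n_inputs)
    if fn is not None:
        return fn
    # Fallback for uncommon arities: string-generated function.
    # numba compiles this fine, but cannot cache it across sessions.
    args = ", ".join(f"x{i}" for i in range(n_inputs))
    body = " + ".join(f"x{i}" for i in range(n_inputs))
    namespace = {}
    exec(f"def add_n({args}): return {body}", namespace)
    return namespace["add_n"]
```

The trade-off is exactly the code duplication this PR accepts: each common parametrization is spelled out by hand so that the cache hit is possible.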

After caching, I see compile times of the logp-dlogp function for the nutpie README example going down:

  • first run in a fresh interpreter: 9.17s → 5.8s (1.6x speedup)
  • subsequent runs in the same interpreter: 6.7s → 5.5s (1.2x speedup)

These times include the PyMC-PyTensor compile time.
For reference, FAST_RUN takes 1.8s then 1.5s, and JAX takes 2.0s then 1.5s.

Test snippet:

```python
%env NUMBA_DEBUG_CACHE = 0

import time

import numpy as np
import pandas as pd
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic(
        "county_floor_effect", raw * sd, dims="county"
    )

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal(
        "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
    )

model = freeze_dims_and_data(model)
ip = np.concatenate([v.reshape(-1) for v in model.initial_point().values()])

# Time compilation + first evaluation three times per backend
for mode in ("FAST_RUN", "JAX", "NUMBA"):
    print(mode)
    for i in range(3):
        start = time.time()
        fn = model.logp_dlogp_function(ravel_inputs=True, mode=mode)._pytensor_function
        fn(ip)[1].mean()
        end = time.time()
        print(end - start)
    print()
```

A more aggressive approach may come out of #1326, which would render this code duplication unnecessary, but that's still too green to see the light of day (and may prove completely impractical).

Also, the Elemwise overload seems to always trigger a cache write whenever the interpreter is relaunched, even though store_core_outputs and the core_op can be cached on subsequent runs.

I don't know what's going on with that; perhaps @aseyboldt has an idea. Relevant snippet:

```python
%env NUMBA_DEBUG_CACHE = 1

import time

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out = pt.erf(x)  # Need a scalar Op that's not string-generated
x_test = np.arange(10).astype(x.dtype)

for i in range(10):
    start = time.time()
    fn = pytensor.function([x], out, mode="NUMBA")
    fn(x_test)
    end = time.time()
    print(end - start)
```

Output:

```
[cache] data saved to '/home/ricardo/Documents/pytensor/pytensor/link/numba/dispatch/__pycache__/vectorize_codegen.store_core_outputs.locals.func-42.py312.13.nbc'
0.6798737049102783
0.11181092262268066
0.09507083892822266
0.09530901908874512
0.09833431243896484
0.09389829635620117
0.09442591667175293
0.09900140762329102
0.09177708625793457
0.09354805946350098
```

@jessegrabowski (Member) left a comment:
Looks reasonable to me, and not too horribly ugly. Were there other obvious cases where we can do this?
