Skip to content

Add example on freeze_data_and_dims #7594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/api/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ Model
:maxdepth: 2

model/core
model/transform
model/conditioning
model/optimization
model/fgraph
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,3 @@ Model Conditioning
observe
change_value_transforms
remove_value_transforms


Model Optimization
------------------
.. currentmodule:: pymc.model.transform.optimization
.. autosummary::
:toctree: generated/

freeze_dims_and_data
7 changes: 7 additions & 0 deletions docs/source/api/model/optimization.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Model Optimization
------------------
.. currentmodule:: pymc.model.transform.optimization
.. autosummary::
:toctree: generated/

freeze_dims_and_data
22 changes: 22 additions & 0 deletions pymc/model/transform/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,28 @@ def freeze_dims_and_data(
-------
Model
A new model with the specified dimensions and data frozen.


Examples
--------
.. code-block:: python

import pymc as pm
import pytensor.tensor as pt

from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model() as m:
x = pm.Data("x", [0, 1, 2] * 1000)
y = pm.Normal("y", mu=pt.unique(x).mean())

# pt.unique(x).mean() has to be computed in every logp function evaluation
print("Logp eval time (1000x): ", m.profile(m.logp()).fct_call_time)

# pt.uniqe(x).mean() is cached in the logp function
frozen_m = freeze_dims_and_data(m)
print("Logp eval time (1000x): ", frozen_m.profile(frozen_m.logp()).fct_call_time)

"""
fg, memo = fgraph_from_model(model)

Expand Down