diff --git a/docs/source/api/model.rst b/docs/source/api/model.rst index 6b7e93e2ad..1e674dc208 100644 --- a/docs/source/api/model.rst +++ b/docs/source/api/model.rst @@ -7,5 +7,6 @@ Model :maxdepth: 2 model/core - model/transform + model/conditioning + model/optimization model/fgraph diff --git a/docs/source/api/model/transform.rst b/docs/source/api/model/conditioning.rst similarity index 56% rename from docs/source/api/model/transform.rst rename to docs/source/api/model/conditioning.rst index 3e83176b17..c84dceb924 100644 --- a/docs/source/api/model/transform.rst +++ b/docs/source/api/model/conditioning.rst @@ -9,12 +9,3 @@ Model Conditioning observe change_value_transforms remove_value_transforms - - -Model Optimization ------------------- -.. currentmodule:: pymc.model.transform.optimization -.. autosummary:: - :toctree: generated/ - - freeze_dims_and_data diff --git a/docs/source/api/model/optimization.rst b/docs/source/api/model/optimization.rst new file mode 100644 index 0000000000..eb208cd4d2 --- /dev/null +++ b/docs/source/api/model/optimization.rst @@ -0,0 +1,7 @@ +Model Optimization +------------------ +.. currentmodule:: pymc.model.transform.optimization +.. autosummary:: + :toctree: generated/ + + freeze_dims_and_data diff --git a/pymc/model/transform/optimization.py b/pymc/model/transform/optimization.py index bcf828ba3e..187e4ee444 100644 --- a/pymc/model/transform/optimization.py +++ b/pymc/model/transform/optimization.py @@ -57,6 +57,28 @@ def freeze_dims_and_data( ------- Model A new model with the specified dimensions and data frozen. + + + Examples + -------- + .. code-block:: python + + import pymc as pm + import pytensor.tensor as pt + + from pymc.model.transform.optimization import freeze_dims_and_data + + with pm.Model() as m: + x = pm.Data("x", [0, 1, 2] * 1000) + y = pm.Normal("y", mu=pt.unique(x).mean()) + + # pt.unique(x).mean() has to be computed in every logp function evaluation + print("Logp eval time (1000x): ", m.profile(m.logp()).fct_call_time) + + # pt.uniqe(x).mean() is cached in the logp function + frozen_m = freeze_dims_and_data(m) + print("Logp eval time (1000x): ", frozen_m.profile(frozen_m.logp()).fct_call_time) + """ fg, memo = fgraph_from_model(model)