Forward compile_kwargs to ADVI when init = "advi+..." #7640

Merged · 5 commits · Jan 9, 2025
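
For context, a minimal sketch of the call this PR fixes. The model and the `compile_kwargs` value are illustrative, not taken from the PR:

```python
import pymc as pm

with pm.Model():
    pm.Normal("x", 0, 1)
    # Previously, compile_kwargs reached the NUTS step function but was
    # silently dropped by the ADVI-based initializers; this PR forwards
    # it to pm.fit as well.
    idata = pm.sample(
        init="advi+adapt_diag",
        compile_kwargs={"mode": "FAST_COMPILE"},  # illustrative value
    )
```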
26 changes: 0 additions & 26 deletions pymc/pytensorf.py
@@ -22,7 +22,6 @@
import pytensor.tensor as pt
import scipy.sparse as sps

from pytensor import scalar
from pytensor.compile import Function, Mode, get_mode
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import grad
@@ -415,31 +414,6 @@ def hessian_diag(f, vars=None, negate_output=True):
return empty_gradient


class IdentityOp(scalar.UnaryScalarOp):
@staticmethod
def st_impl(x):
return x

def impl(self, x):
return x

def grad(self, inp, grads):
return grads

def c_code(self, node, name, inp, out, sub):
return f"{out[0]} = {inp[0]};"

def __eq__(self, other):
return isinstance(self, type(other))

def __hash__(self):
return hash(type(self))


scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
identity = Elemwise(scalar_identity, name="identity")


def make_shared_replacements(point, vars, model):
"""
Make shared replacements for all *other* variables than the ones passed.
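
The deleted `IdentityOp` duplicated an op that pytensor already provides. A sketch of the drop-in equivalent, using the same construction the `opvi.py` change below adopts:

```python
import pytensor.tensor as pt
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.tensor.elemwise import Elemwise

# Elementwise wrapper around pytensor's built-in scalar identity op,
# replacing the hand-rolled IdentityOp removed above.
identity = Elemwise(scalar_identity, name="identity")

y = identity(pt.as_tensor([1.0, 2.0]))
print(y.eval())  # [1. 2.]
```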
4 changes: 4 additions & 0 deletions pymc/sampling/mcmc.py
@@ -1553,6 +1553,7 @@ def init_nuts(
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
compile_kwargs=compile_kwargs,
)
approx_sample = approx.sample(
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
@@ -1566,6 +1567,7 @@
potential = quadpotential.QuadPotentialDiagAdapt(
n, mean, cov, weight, rng=random_seed_list[0]
)

elif init == "advi":
approx = pm.fit(
random_seed=random_seed_list[0],
@@ -1575,6 +1577,7 @@
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
compile_kwargs=compile_kwargs,
)
approx_sample = approx.sample(
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
@@ -1592,6 +1595,7 @@
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
compile_kwargs=compile_kwargs,
)
approx_sample = approx.sample(
draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
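
All three ADVI-based branches ("advi+adapt_diag", "advi", "advi_map") now pass `compile_kwargs` through to `pm.fit`, so the kwargs reach the compiled step function. A hedged sketch of the equivalent standalone call, with an illustrative model:

```python
import pymc as pm

with pm.Model():
    pm.Normal("mu", 0, 1)
    approx = pm.fit(
        n=1000,
        method="advi",
        obj_optimizer=pm.adagrad_window,
        compile_kwargs={"mode": "FAST_RUN"},  # now forwarded to pytensor.function
    )
```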
19 changes: 14 additions & 5 deletions pymc/variational/inference.py
@@ -82,9 +82,18 @@

def run_profiling(self, n=1000, score=None, **kwargs):
score = self._maybe_score(score)
fn_kwargs = kwargs.pop("fn_kwargs", {})
fn_kwargs["profile"] = True
step_func = self.objective.step_function(score=score, fn_kwargs=fn_kwargs, **kwargs)
if "fn_kwargs" in kwargs:
warnings.warn(
"fn_kwargs is deprecated, please use compile_kwargs instead", DeprecationWarning
)
compile_kwargs = kwargs.pop("fn_kwargs")
else:
compile_kwargs = kwargs.pop("compile_kwargs", {})

compile_kwargs["profile"] = True
step_func = self.objective.step_function(
score=score, compile_kwargs=compile_kwargs, **kwargs
)
try:
for _ in track(range(n)):
step_func()
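
With the shim above, profiling works under either spelling; passing `fn_kwargs` still functions but emits a `DeprecationWarning`. A sketch assuming a toy model:

```python
import pymc as pm

with pm.Model():
    pm.Normal("x", 0, 1)
    advi = pm.ADVI()

# run_profiling injects "profile": True itself; any extra compile
# options now ride along via compile_kwargs.
profile = advi.run_profiling(n=100, compile_kwargs={"mode": "FAST_RUN"})
profile.summary()
```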
@@ -134,7 +143,7 @@
Add custom updates to resulting updates
total_grad_norm_constraint: `float`
Bounds gradient norm, prevents exploding gradient problem
fn_kwargs: `dict`
compile_kwargs: `dict`
Add kwargs to pytensor.function (e.g. `{'profile': True}`)
more_replacements: `dict`
Apply custom replacements before calculating gradients
@@ -729,7 +738,7 @@
Add custom updates to resulting updates
total_grad_norm_constraint: `float`
Bounds gradient norm, prevents exploding gradient problem
fn_kwargs: `dict`
compile_kwargs: `dict`
Add kwargs to pytensor.function (e.g. `{'profile': True}`)
more_replacements: `dict`
Apply custom replacements before calculating gradients
49 changes: 38 additions & 11 deletions pymc/variational/opvi.py
@@ -61,6 +61,8 @@

from pytensor.graph.basic import Variable
from pytensor.graph.replace import graph_replace
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import unbroadcast

import pymc as pm
@@ -74,7 +76,6 @@
SeedSequenceSeed,
compile,
find_rng_nodes,
identity,
reseed_rngs,
)
from pymc.util import (
@@ -332,6 +333,7 @@
more_replacements=None,
total_grad_norm_constraint=None,
score=False,
compile_kwargs=None,
fn_kwargs=None,
):
R"""Step function that should be called on each optimization step.
@@ -362,17 +364,30 @@
Bounds gradient norm, prevents exploding gradient problem
score: `bool`
calculate loss on each step? Defaults to False for speed
fn_kwargs: `dict`
compile_kwargs: `dict`
Add kwargs to pytensor.function (e.g. `{'profile': True}`)
fn_kwargs: dict
arbitrary kwargs passed to `pytensor.function`

.. warning:: `fn_kwargs` is deprecated and will be removed in future versions

more_replacements: `dict`
Apply custom replacements before calculating gradients

Returns
-------
`pytensor.function`
"""
if fn_kwargs is None:
fn_kwargs = {}
if fn_kwargs is not None:
warnings.warn(
"`fn_kwargs` is deprecated and will be removed in future versions. Use "
"`compile_kwargs` instead.",
DeprecationWarning,
)
compile_kwargs = fn_kwargs

if compile_kwargs is None:
compile_kwargs = {}
if score and not self.op.returns_loss:
raise NotImplementedError(f"{self.op} does not have loss")
updates = self.updates(
@@ -388,14 +403,14 @@
)
seed = self.approx.rng.randint(2**30, dtype=np.int64)
if score:
step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs)
step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs)
else:
step_fn = compile([], [], updates=updates, random_seed=seed, **fn_kwargs)
step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs)
return step_fn
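
For callers compiling the step function directly, a sketch of the new argument in use (`advi` as in the earlier sketch; `score=True` makes the compiled function return the loss):

```python
step = advi.objective.step_function(
    score=True,
    compile_kwargs={"profile": True},  # forwarded to the compile call
)
loss = step()  # one optimization step; returns the current loss
```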

@pytensor.config.change_flags(compute_test_value="off")
def score_function(
self, sc_n_mc=None, more_replacements=None, fn_kwargs=None
self, sc_n_mc=None, more_replacements=None, compile_kwargs=None, fn_kwargs=None
): # pragma: no cover
R"""Compile scoring function that operates which takes no inputs and returns Loss.

@@ -405,22 +420,34 @@
number of scoring MC samples
more_replacements:
Apply custom replacements before compiling a function
compile_kwargs: `dict`
arbitrary kwargs passed to `pytensor.function`
fn_kwargs: `dict`
arbitrary kwargs passed to `pytensor.function`

.. warning:: `fn_kwargs` is deprecated and will be removed in future versions

Returns
-------
pytensor.function
"""
if fn_kwargs is None:
fn_kwargs = {}
if fn_kwargs is not None:
warnings.warn(
"`fn_kwargs` is deprecated and will be removed in future versions. Use "
"`compile_kwargs` instead",
DeprecationWarning,
)
compile_kwargs = fn_kwargs

if compile_kwargs is None:
compile_kwargs = {}
if not self.op.returns_loss:
raise NotImplementedError(f"{self.op} does not have loss")
if more_replacements is None:
more_replacements = {}
loss = self(sc_n_mc, more_replacements=more_replacements)
seed = self.approx.rng.randint(2**30, dtype=np.int64)
return compile([], loss, random_seed=seed, **fn_kwargs)
return compile([], loss, random_seed=seed, **compile_kwargs)
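
Analogous usage for `score_function`, sketched with the same illustrative `advi` object:

```python
score_fn = advi.objective.score_function(
    sc_n_mc=100,  # number of MC samples used for scoring
    compile_kwargs={"mode": "FAST_COMPILE"},
)
print(score_fn())  # current loss estimate
```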

@pytensor.config.change_flags(compute_test_value="off")
def __call__(self, nmc, **kwargs):
@@ -451,7 +478,7 @@
require_logq = True
objective_class = ObjectiveFunction
supports_aevb = property(lambda self: not self.approx.any_histograms)
T = identity
T = Elemwise(scalar_identity)

def __init__(self, approx):
self.approx = approx