diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py
index 6fd44b0382..f665d5931c 100644
--- a/pymc/pytensorf.py
+++ b/pymc/pytensorf.py
@@ -22,7 +22,6 @@
 import pytensor.tensor as pt
 import scipy.sparse as sps
 
-from pytensor import scalar
 from pytensor.compile import Function, Mode, get_mode
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import grad
@@ -415,31 +414,6 @@ def hessian_diag(f, vars=None, negate_output=True):
     return empty_gradient
 
 
-class IdentityOp(scalar.UnaryScalarOp):
-    @staticmethod
-    def st_impl(x):
-        return x
-
-    def impl(self, x):
-        return x
-
-    def grad(self, inp, grads):
-        return grads
-
-    def c_code(self, node, name, inp, out, sub):
-        return f"{out[0]} = {inp[0]};"
-
-    def __eq__(self, other):
-        return isinstance(self, type(other))
-
-    def __hash__(self):
-        return hash(type(self))
-
-
-scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
-identity = Elemwise(scalar_identity, name="identity")
-
-
 def make_shared_replacements(point, vars, model):
     """
     Make shared replacements for all *other* variables than the ones passed.
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index fce64e3b38..b2d643a5f1 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -1553,6 +1553,7 @@ def init_nuts(
             callbacks=cb,
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
+            compile_kwargs=compile_kwargs,
         )
         approx_sample = approx.sample(
             draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
@@ -1566,6 +1567,7 @@
         potential = quadpotential.QuadPotentialDiagAdapt(
             n, mean, cov, weight, rng=random_seed_list[0]
         )
+
     elif init == "advi":
         approx = pm.fit(
             random_seed=random_seed_list[0],
@@ -1575,6 +1577,7 @@
             callbacks=cb,
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
+            compile_kwargs=compile_kwargs,
         )
         approx_sample = approx.sample(
             draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
@@ -1592,6 +1595,7 @@
             callbacks=cb,
             progressbar=progressbar,
             obj_optimizer=pm.adagrad_window,
+            compile_kwargs=compile_kwargs,
        )
         approx_sample = approx.sample(
             draws=chains, random_seed=random_seed_list[0], return_inferencedata=False
diff --git a/pymc/variational/inference.py b/pymc/variational/inference.py
index 3e2c07788f..29800e0541 100644
--- a/pymc/variational/inference.py
+++ b/pymc/variational/inference.py
@@ -82,9 +82,18 @@ def _maybe_score(self, score):
 
     def run_profiling(self, n=1000, score=None, **kwargs):
         score = self._maybe_score(score)
-        fn_kwargs = kwargs.pop("fn_kwargs", {})
-        fn_kwargs["profile"] = True
-        step_func = self.objective.step_function(score=score, fn_kwargs=fn_kwargs, **kwargs)
+        if "fn_kwargs" in kwargs:
+            warnings.warn(
+                "fn_kwargs is deprecated, please use compile_kwargs instead", DeprecationWarning
+            )
+            compile_kwargs = kwargs.pop("fn_kwargs")
+        else:
+            compile_kwargs = kwargs.pop("compile_kwargs", {})
+
+        compile_kwargs["profile"] = True
+        step_func = self.objective.step_function(
+            score=score, compile_kwargs=compile_kwargs, **kwargs
+        )
         try:
             for _ in track(range(n)):
                 step_func()
@@ -134,7 +143,7 @@
             Add custom updates to resulting updates
         total_grad_norm_constraint: `float`
             Bounds gradient norm, prevents exploding gradient problem
-        fn_kwargs: `dict`
+        compile_kwargs: `dict`
             Add kwargs to pytensor.function (e.g. `{'profile': True}`)
         more_replacements: `dict`
             Apply custom replacements before calculating gradients
@@ -729,7 +738,7 @@
         Add custom updates to resulting updates
     total_grad_norm_constraint: `float`
         Bounds gradient norm, prevents exploding gradient problem
-    fn_kwargs: `dict`
+    compile_kwargs: `dict`
         Add kwargs to pytensor.function (e.g. `{'profile': True}`)
     more_replacements: `dict`
         Apply custom replacements before calculating gradients
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index 9829ea2c35..034e2fed87 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -61,6 +61,8 @@
 
 from pytensor.graph.basic import Variable
 from pytensor.graph.replace import graph_replace
+from pytensor.scalar.basic import identity as scalar_identity
+from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.shape import unbroadcast
 
 import pymc as pm
@@ -74,7 +76,6 @@
     SeedSequenceSeed,
     compile,
     find_rng_nodes,
-    identity,
     reseed_rngs,
 )
 from pymc.util import (
@@ -332,6 +333,7 @@ def step_function(
         more_replacements=None,
         total_grad_norm_constraint=None,
         score=False,
+        compile_kwargs=None,
         fn_kwargs=None,
     ):
         R"""Step function that should be called on each optimization step.
@@ -362,8 +364,13 @@
             Bounds gradient norm, prevents exploding gradient problem
         score: `bool`
             calculate loss on each step? Defaults to False for speed
-        fn_kwargs: `dict`
+        compile_kwargs: `dict`
             Add kwargs to pytensor.function (e.g. `{'profile': True}`)
+        fn_kwargs: dict
+            arbitrary kwargs passed to `pytensor.function`
+
+            .. warning:: `fn_kwargs` is deprecated and will be removed in future versions
+
         more_replacements: `dict`
             Apply custom replacements before calculating gradients
 
@@ -371,8 +378,16 @@
         -------
         `pytensor.function`
         """
-        if fn_kwargs is None:
-            fn_kwargs = {}
+        if fn_kwargs is not None:
+            warnings.warn(
+                "`fn_kwargs` is deprecated and will be removed in future versions. Use "
+                "`compile_kwargs` instead.",
+                DeprecationWarning,
+            )
+            compile_kwargs = fn_kwargs
+
+        if compile_kwargs is None:
+            compile_kwargs = {}
         if score and not self.op.returns_loss:
             raise NotImplementedError(f"{self.op} does not have loss")
         updates = self.updates(
@@ -388,14 +403,14 @@
         )
         seed = self.approx.rng.randint(2**30, dtype=np.int64)
         if score:
-            step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs)
+            step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs)
         else:
-            step_fn = compile([], [], updates=updates, random_seed=seed, **fn_kwargs)
+            step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs)
         return step_fn
 
     @pytensor.config.change_flags(compute_test_value="off")
     def score_function(
-        self, sc_n_mc=None, more_replacements=None, fn_kwargs=None
+        self, sc_n_mc=None, more_replacements=None, compile_kwargs=None, fn_kwargs=None
     ):  # pragma: no cover
         R"""Compile scoring function that operates which takes no inputs and returns Loss.
 
@@ -405,22 +420,34 @@
             number of scoring MC samples
         more_replacements:
             Apply custom replacements before compiling a function
+        compile_kwargs: `dict`
+            arbitrary kwargs passed to `pytensor.function`
         fn_kwargs: `dict`
             arbitrary kwargs passed to `pytensor.function`
 
+            .. warning:: `fn_kwargs` is deprecated and will be removed in future versions
+
         Returns
         -------
         pytensor.function
         """
-        if fn_kwargs is None:
-            fn_kwargs = {}
+        if fn_kwargs is not None:
+            warnings.warn(
+                "`fn_kwargs` is deprecated and will be removed in future versions. Use "
+                "`compile_kwargs` instead",
+                DeprecationWarning,
+            )
+            compile_kwargs = fn_kwargs
+
+        if compile_kwargs is None:
+            compile_kwargs = {}
         if not self.op.returns_loss:
             raise NotImplementedError(f"{self.op} does not have loss")
         if more_replacements is None:
             more_replacements = {}
         loss = self(sc_n_mc, more_replacements=more_replacements)
         seed = self.approx.rng.randint(2**30, dtype=np.int64)
-        return compile([], loss, random_seed=seed, **fn_kwargs)
+        return compile([], loss, random_seed=seed, **compile_kwargs)
 
     @pytensor.config.change_flags(compute_test_value="off")
     def __call__(self, nmc, **kwargs):
@@ -451,7 +478,7 @@ class Operator:
     require_logq = True
     objective_class = ObjectiveFunction
     supports_aevb = property(lambda self: not self.approx.any_histograms)
-    T = identity
+    T = Elemwise(scalar_identity)
 
     def __init__(self, approx):
         self.approx = approx
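
Usage note: a minimal sketch of the renamed argument, assuming a toy model (the model, `n=100`, and the profiler flag are illustrative; only the `compile_kwargs`/`fn_kwargs` names and the `{'profile': True}` example come from the patch's docstrings):

```python
import pymc as pm

with pm.Model():
    pm.Normal("x", 0.0, 1.0)

    # New spelling: the dict is forwarded to pytensor.function when the
    # ADVI step function is compiled, e.g. to enable the profiler.
    approx = pm.fit(n=100, method="advi", compile_kwargs={"profile": True})

    # Old spelling still runs, but step_function() now emits a
    # DeprecationWarning and reroutes the dict to compile_kwargs.
    approx = pm.fit(n=100, method="advi", fn_kwargs={"profile": True})
```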
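The `IdentityOp` removed from pymc/pytensorf.py essentially duplicated a scalar op PyTensor already ships; `Operator.T` now builds the elementwise version inline. A small sketch, using only the imports the patch itself adds, checking that the replacement behaves like the old `identity` (passes its input through and has a unit gradient):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.tensor.elemwise import Elemwise

# Same construction as the new `Operator.T` in opvi.py.
identity = Elemwise(scalar_identity)

x = pt.vector("x", dtype="float64")
y = identity(x)

fn = pytensor.function([x], [y, pytensor.grad(y.sum(), x)])
out, g = fn(np.array([1.0, -2.0, 3.0]))
print(out)  # [ 1. -2.  3.] -- input passed through unchanged
print(g)    # [ 1.  1.  1.] -- the identity has a unit gradient
```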