Skip to content

Commit 1bd4fda

Browse files
committed
add test for start and start_sigma plus minor fixes
1 parent e9d7757 commit 1bd4fda

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

pymc/tests/test_variational_inference.py

+22
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,28 @@ def test_fit_oo(inference, fit_kwargs, simple_model_data):
571571
np.testing.assert_allclose(np.std(trace.posterior["mu"]), np.sqrt(1.0 / d), rtol=0.2)
572572

573573

574+
def test_fit_start(inference_spec, simple_model):
575+
mu_init = 17
576+
mu_sigma_init = 13
577+
578+
with simple_model:
579+
if type(inference_spec()) == ADVI:
580+
has_start_sigma = True
581+
else:
582+
has_start_sigma = False
583+
584+
kw = {"start": {"mu": mu_init}}
585+
if has_start_sigma:
586+
kw.update({"start_sigma": {"mu": mu_sigma_init}})
587+
588+
with simple_model:
589+
inference = inference_spec(**kw)
590+
trace = inference.fit(n=0).sample(10000)
591+
np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
592+
if has_start_sigma:
593+
np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
594+
595+
574596
def test_profile(inference):
575597
inference.run_profiling(n=100).summary()
576598

pymc/variational/approximations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def create_shared_params(self, start=None, start_sigma=None):
8888
def _prepare_start_sigma(self, start_sigma):
8989
rho = np.zeros((self.ddim,))
9090
if start_sigma is not None:
91-
for name, slice_, *_ in self.ordering.items():
91+
for name, slice_, *_ in self.ordering.values():
9292
sigma = start_sigma.get(name)
9393
if sigma is not None:
9494
rho[slice_] = np.log(np.exp(np.abs(sigma)) - 1.0)

pymc/variational/inference.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,9 @@ def _infmean(input_array):
257257
)
258258
)
259259
else:
260-
if n < 10:
260+
if n == 0:
261+
logger.info(f"Initialization only")
262+
elif n < 10:
261263
logger.info(f"Finished [100%]: Loss = {scores[-1]:,.5g}")
262264
else:
263265
avg_loss = _infmean(scores[max(0, i - 1000) : i + 1])
@@ -466,7 +468,7 @@ class FullRankADVI(KLqp):
466468
random_seed: None or int
467469
leave None to use package global RandomStream or other
468470
valid value to create instance specific one
469-
start: `Point`
471+
start: `dict[str, np.ndarray]` or `StartDict`
470472
starting point for inference
471473
472474
References
@@ -534,13 +536,11 @@ class SVGD(ImplicitGradient):
534536
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
535537
temperature: float
536538
parameter responsible for exploration, higher temperature gives more broad posterior estimate
537-
start: `dict`
539+
start: `dict[str, np.ndarray]` or `StartDict`
538540
initial point for inference
539541
random_seed: None or int
540542
leave None to use package global RandomStream or other
541543
valid value to create instance specific one
542-
start: `Point`
543-
starting point for inference
544544
kwargs: other keyword arguments passed to estimator
545545
546546
References
@@ -631,7 +631,11 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
631631
"is often **underestimated** when using temperature = 1."
632632
)
633633
if approx is None:
634-
approx = FullRank(model=kwargs.pop("model", None))
634+
approx = FullRank(
635+
model=kwargs.pop("model", None),
636+
random_seed=kwargs.pop("random_seed", None),
637+
start=kwargs.pop("start", None),
638+
)
635639
super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)
636640

637641
def fit(

0 commit comments

Comments
 (0)