Skip to content

Commit f00e814

Browse files
committed
refactor cloning routines
1 parent 573162a commit f00e814

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

pymc3/variational/approximations.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def mean(self):
161161
@node_property
162162
def cov(self):
163163
L = self.L
164-
if self.batched:
164+
if self.batched:
165165
return at.batched_dot(L, L.swapaxes(-1, -2))
166166
else:
167167
return L.dot(L.T)
@@ -187,6 +187,7 @@ def symbolic_logq_not_scaled(self):
187187
z0 = self.symbolic_initial
188188
if self.batched:
189189
raise NotImplementedError
190+
190191
def logq(z_b, mu_b, L_b):
191192
return pm.MvNormal.dist(mu=mu_b, chol=L_b).logp(z_b)
192193

@@ -199,7 +200,6 @@ def logq(z_b, mu_b, L_b):
199200
logdet = at.sum(at.diagonal(self.L, 0, self.L.ndim - 2, self.L.ndim - 1), axis=-1)
200201
logq = pm.Normal.logp(z0, 0, 1) - logdet
201202
return logq.sum(range(1, logq.ndim))
202-
203203

204204
@node_property
205205
def symbolic_random(self):

pymc3/variational/opvi.py

+29-16
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858
import pymc3 as pm
5959

60-
from pymc3.aesaraf import at_rng, identity
60+
from pymc3.aesaraf import at_rng, identity, rvs_to_value_vars
6161
from pymc3.backends import NDArray
6262
from pymc3.model import modelcontext
6363
from pymc3.util import (
@@ -851,8 +851,9 @@ def __init__(
851851
self.group = group
852852
self.user_params = params
853853
self._user_params = None
854-
self.replacements = dict()
855-
self.ordering = dict()
854+
self.replacements = collections.OrderedDict()
855+
self.value_replacements = collections.OrderedDict()
856+
self.ordering = collections.OrderedDict()
856857
# save this stuff to use in __init_group__ later
857858
self._kwargs = kwargs
858859
if self.group is not None:
@@ -958,7 +959,8 @@ def __init_group__(self, group):
958959

959960
# 1) we need initial point (transformed space)
960961
model_initial_point = self.model.initial_point
961-
962+
_, replace_to_value_vars = rvs_to_value_vars(self.group, apply_transforms=True)
963+
self.value_replacements.update(replace_to_value_vars)
962964
# 2) we'll work with a single group, a subset of the model
963965
# here we need to create a mapping to replace value_vars with slices from the approximation
964966
start_idx = 0
@@ -989,14 +991,14 @@ def __init_group__(self, group):
989991
dtype = test_var.dtype
990992
size = test_var.size
991993
# TODO: There was self.ordering used in other util funcitons
992-
vr = self.input[..., start_idx:start_idx+size].reshape(shape).astype(dtype)
994+
vr = self.input[..., start_idx : start_idx + size].reshape(shape).astype(dtype)
993995
vr.name = value_var.name + "_vi_replacement"
994996
self.replacements[value_var] = vr
995997
self.ordering[value_var.name] = (
996998
value_var.name,
997-
slice(start_idx, start_idx+size),
999+
slice(start_idx, start_idx + size),
9981000
shape,
999-
dtype
1001+
dtype,
10001002
)
10011003
start_idx += size
10021004

@@ -1166,6 +1168,7 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
11661168

11671169
def to_flat_input(self, node):
11681170
"""*Dev* - replace vars with flattened view stored in `self.inputs`"""
1171+
node = aesara.clone_replace(node, self.value_replacements)
11691172
return aesara.clone_replace(node, self.replacements)
11701173

11711174
def symbolic_sample_over_posterior(self, node):
@@ -1468,6 +1471,13 @@ def datalogp_norm(self):
14681471
"""*Dev* - normalized :math:`E_{q}(data term)`"""
14691472
return self.datalogp / self.symbolic_normalizing_constant
14701473

1474+
@property
1475+
def value_replacements(self):
1476+
"""*Dev* - all replacements from groups to replace PyMC random variables with approximation"""
1477+
return collections.OrderedDict(
1478+
itertools.chain.from_iterable(g.value_replacements.items() for g in self.groups)
1479+
)
1480+
14711481
@property
14721482
def replacements(self):
14731483
"""*Dev* - all replacements from groups to replace PyMC random variables with approximation"""
@@ -1528,28 +1538,30 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
15281538
try_to_set_test_value(_node, node, s)
15291539
return node
15301540

1531-
def to_flat_input(self, node):
1541+
def to_flat_input(self, node, more_replacements=None):
15321542
"""*Dev* - replace vars with flattened view stored in `self.inputs`"""
1543+
more_replacements = more_replacements or {}
1544+
node = aesara.clone_replace(node, {**self.value_replacements, **more_replacements})
15331545
return aesara.clone_replace(node, self.replacements)
15341546

1535-
def symbolic_sample_over_posterior(self, node):
1547+
def symbolic_sample_over_posterior(self, node, more_replacements=None):
15361548
"""*Dev* - performs sampling of node applying independent samples from posterior each time.
15371549
Note that it is done symbolically and this node needs :func:`set_size_and_deterministic` call
15381550
"""
1539-
node = self.to_flat_input(node)
1551+
node = self.to_flat_input(node, more_replacements=more_replacements)
15401552

15411553
def sample(*post):
15421554
return aesara.clone_replace(node, dict(zip(self.inputs, post)))
15431555

15441556
nodes, _ = aesara.scan(sample, self.symbolic_randoms)
15451557
return nodes
15461558

1547-
def symbolic_single_sample(self, node):
1559+
def symbolic_single_sample(self, node, more_replacements=None):
15481560
"""*Dev* - performs sampling of node applying single sample from posterior.
15491561
Note that it is done symbolically and this node needs
15501562
:func:`set_size_and_deterministic` call with `size=1`
15511563
"""
1552-
node = self.to_flat_input(node)
1564+
node = self.to_flat_input(node, more_replacements=more_replacements)
15531565
post = [v[0] for v in self.symbolic_randoms]
15541566
inp = self.inputs
15551567
return aesara.clone_replace(node, dict(zip(inp, post)))
@@ -1585,12 +1597,13 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No
15851597
sampled node(s) with replacements
15861598
"""
15871599
node_in = node
1588-
node = aesara.clone_replace(node, more_replacements)
15891600
if size is None:
1590-
node_out = self.symbolic_single_sample(node)
1601+
node_out = self.symbolic_single_sample(node, more_replacements=more_replacements)
15911602
else:
1592-
node_out = self.symbolic_sample_over_posterior(node)
1593-
node_out = self.set_size_and_deterministic(node_out, size, deterministic, more_replacements)
1603+
node_out = self.symbolic_sample_over_posterior(
1604+
node, more_replacements=more_replacements
1605+
)
1606+
node_out = self.set_size_and_deterministic(node_out, size, deterministic)
15941607
try_to_set_test_value(node_in, node_out, size)
15951608
return node_out
15961609

0 commit comments

Comments
 (0)