|
57 | 57 |
|
58 | 58 | import pymc3 as pm
|
59 | 59 |
|
60 |
| -from pymc3.aesaraf import at_rng, identity |
| 60 | +from pymc3.aesaraf import at_rng, identity, rvs_to_value_vars |
61 | 61 | from pymc3.backends import NDArray
|
62 | 62 | from pymc3.model import modelcontext
|
63 | 63 | from pymc3.util import (
|
@@ -851,8 +851,9 @@ def __init__(
|
851 | 851 | self.group = group
|
852 | 852 | self.user_params = params
|
853 | 853 | self._user_params = None
|
854 |
| - self.replacements = dict() |
855 |
| - self.ordering = dict() |
| 854 | + self.replacements = collections.OrderedDict() |
| 855 | + self.value_replacements = collections.OrderedDict() |
| 856 | + self.ordering = collections.OrderedDict() |
856 | 857 | # save this stuff to use in __init_group__ later
|
857 | 858 | self._kwargs = kwargs
|
858 | 859 | if self.group is not None:
|
@@ -958,7 +959,8 @@ def __init_group__(self, group):
|
958 | 959 |
|
959 | 960 | # 1) we need initial point (transformed space)
|
960 | 961 | model_initial_point = self.model.initial_point
|
961 |
| - |
| 962 | + _, replace_to_value_vars = rvs_to_value_vars(self.group, apply_transforms=True) |
| 963 | + self.value_replacements.update(replace_to_value_vars) |
962 | 964 | # 2) we'll work with a single group, a subset of the model
|
963 | 965 | # here we need to create a mapping to replace value_vars with slices from the approximation
|
964 | 966 | start_idx = 0
|
@@ -989,14 +991,14 @@ def __init_group__(self, group):
|
989 | 991 | dtype = test_var.dtype
|
990 | 992 | size = test_var.size
|
991 | 993 | # TODO: There was self.ordering used in other util funcitons
|
992 |
| - vr = self.input[..., start_idx:start_idx+size].reshape(shape).astype(dtype) |
| 994 | + vr = self.input[..., start_idx : start_idx + size].reshape(shape).astype(dtype) |
993 | 995 | vr.name = value_var.name + "_vi_replacement"
|
994 | 996 | self.replacements[value_var] = vr
|
995 | 997 | self.ordering[value_var.name] = (
|
996 | 998 | value_var.name,
|
997 |
| - slice(start_idx, start_idx+size), |
| 999 | + slice(start_idx, start_idx + size), |
998 | 1000 | shape,
|
999 |
| - dtype |
| 1001 | + dtype, |
1000 | 1002 | )
|
1001 | 1003 | start_idx += size
|
1002 | 1004 |
|
@@ -1166,6 +1168,7 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
|
1166 | 1168 |
|
1167 | 1169 | def to_flat_input(self, node):
|
1168 | 1170 | """*Dev* - replace vars with flattened view stored in `self.inputs`"""
|
| 1171 | + node = aesara.clone_replace(node, self.value_replacements) |
1169 | 1172 | return aesara.clone_replace(node, self.replacements)
|
1170 | 1173 |
|
1171 | 1174 | def symbolic_sample_over_posterior(self, node):
|
@@ -1468,6 +1471,13 @@ def datalogp_norm(self):
|
1468 | 1471 | """*Dev* - normalized :math:`E_{q}(data term)`"""
|
1469 | 1472 | return self.datalogp / self.symbolic_normalizing_constant
|
1470 | 1473 |
|
| 1474 | + @property |
| 1475 | + def value_replacements(self): |
| 1476 | + """*Dev* - all replacements from groups to replace PyMC random variables with approximation""" |
| 1477 | + return collections.OrderedDict( |
| 1478 | + itertools.chain.from_iterable(g.value_replacements.items() for g in self.groups) |
| 1479 | + ) |
| 1480 | + |
1471 | 1481 | @property
|
1472 | 1482 | def replacements(self):
|
1473 | 1483 | """*Dev* - all replacements from groups to replace PyMC random variables with approximation"""
|
@@ -1528,28 +1538,30 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
|
1528 | 1538 | try_to_set_test_value(_node, node, s)
|
1529 | 1539 | return node
|
1530 | 1540 |
|
1531 |
| - def to_flat_input(self, node): |
| 1541 | + def to_flat_input(self, node, more_replacements=None): |
1532 | 1542 | """*Dev* - replace vars with flattened view stored in `self.inputs`"""
|
| 1543 | + more_replacements = more_replacements or {} |
| 1544 | + node = aesara.clone_replace(node, {**self.value_replacements, **more_replacements}) |
1533 | 1545 | return aesara.clone_replace(node, self.replacements)
|
1534 | 1546 |
|
1535 |
| - def symbolic_sample_over_posterior(self, node): |
| 1547 | + def symbolic_sample_over_posterior(self, node, more_replacements=None): |
1536 | 1548 | """*Dev* - performs sampling of node applying independent samples from posterior each time.
|
1537 | 1549 | Note that it is done symbolically and this node needs :func:`set_size_and_deterministic` call
|
1538 | 1550 | """
|
1539 |
| - node = self.to_flat_input(node) |
| 1551 | + node = self.to_flat_input(node, more_replacements=more_replacements) |
1540 | 1552 |
|
1541 | 1553 | def sample(*post):
|
1542 | 1554 | return aesara.clone_replace(node, dict(zip(self.inputs, post)))
|
1543 | 1555 |
|
1544 | 1556 | nodes, _ = aesara.scan(sample, self.symbolic_randoms)
|
1545 | 1557 | return nodes
|
1546 | 1558 |
|
1547 |
| - def symbolic_single_sample(self, node): |
| 1559 | + def symbolic_single_sample(self, node, more_replacements=None): |
1548 | 1560 | """*Dev* - performs sampling of node applying single sample from posterior.
|
1549 | 1561 | Note that it is done symbolically and this node needs
|
1550 | 1562 | :func:`set_size_and_deterministic` call with `size=1`
|
1551 | 1563 | """
|
1552 |
| - node = self.to_flat_input(node) |
| 1564 | + node = self.to_flat_input(node, more_replacements=more_replacements) |
1553 | 1565 | post = [v[0] for v in self.symbolic_randoms]
|
1554 | 1566 | inp = self.inputs
|
1555 | 1567 | return aesara.clone_replace(node, dict(zip(inp, post)))
|
@@ -1585,12 +1597,13 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No
|
1585 | 1597 | sampled node(s) with replacements
|
1586 | 1598 | """
|
1587 | 1599 | node_in = node
|
1588 |
| - node = aesara.clone_replace(node, more_replacements) |
1589 | 1600 | if size is None:
|
1590 |
| - node_out = self.symbolic_single_sample(node) |
| 1601 | + node_out = self.symbolic_single_sample(node, more_replacements=more_replacements) |
1591 | 1602 | else:
|
1592 |
| - node_out = self.symbolic_sample_over_posterior(node) |
1593 |
| - node_out = self.set_size_and_deterministic(node_out, size, deterministic, more_replacements) |
| 1603 | + node_out = self.symbolic_sample_over_posterior( |
| 1604 | + node, more_replacements=more_replacements |
| 1605 | + ) |
| 1606 | + node_out = self.set_size_and_deterministic(node_out, size, deterministic) |
1594 | 1607 | try_to_set_test_value(node_in, node_out, size)
|
1595 | 1608 | return node_out
|
1596 | 1609 |
|
|
0 commit comments