Skip to content

Commit e419d53

Browse files
lucianopazricardoV94
authored andcommitted
Check shared variable values to determine volatility in posterior_predictive_sampling
1 parent 4836bc1 commit e419d53

File tree

3 files changed

+277
-34
lines changed

3 files changed

+277
-34
lines changed

Diff for: pymc/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,7 @@ def add_coord(
11491149
length = len(values)
11501150
if not isinstance(length, Variable):
11511151
if mutable:
1152-
length = aesara.shared(length)
1152+
length = aesara.shared(length, name=name)
11531153
else:
11541154
length = aesara.tensor.constant(length)
11551155
self._dim_lengths[name] = length

Diff for: pymc/sampling.py

+62-6
Original file line numberDiff line numberDiff line change
@@ -1621,6 +1621,8 @@ def compile_forward_sampling_function(
16211621
vars_in_trace: List[Variable],
16221622
basic_rvs: Optional[List[Variable]] = None,
16231623
givens_dict: Optional[Dict[Variable, Any]] = None,
1624+
constant_data: Optional[Dict[str, np.ndarray]] = None,
1625+
constant_coords: Optional[Set[str]] = None,
16241626
**kwargs,
16251627
) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]:
16261628
"""Compile a function to draw samples, conditioned on the values of some variables.
@@ -1634,18 +1636,18 @@ def compile_forward_sampling_function(
16341636
compiled function or after inference has been run. These variables are:
16351637
16361638
- Variables in the outputs list
1637-
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
1639+
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``, and whose values changed with respect to what they were at inference time
16381640
- Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
16391641
- Variables that are keys in the ``givens_dict``
16401642
- Variables that have volatile inputs
16411643
16421644
Concretely, this function can be used to compile a function to sample from the
16431645
posterior predictive distribution of a model that has variables that are conditioned
1644-
on ``MutableData`` instances. The variables that depend on the mutable data will be
1645-
considered volatile, and as such, they wont be included as inputs into the compiled function.
1646-
This means that if they have values stored in the posterior, these values will be ignored
1647-
and new values will be computed (in the case of deterministics and potentials) or sampled
1648-
(in the case of random variables).
1646+
on ``MutableData`` instances. The variables that depend on the mutable data that have changed
1647+
will be considered volatile, and as such, they wont be included as inputs into the compiled
1648+
function. This means that if they have values stored in the posterior, these values will be
1649+
ignored and new values will be computed (in the case of deterministics and potentials) or
1650+
sampled (in the case of random variables).
16491651
16501652
This function also enables a way to impute values for any variable in the computational
16511653
graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
@@ -1672,6 +1674,25 @@ def compile_forward_sampling_function(
16721674
A dictionary that maps tensor variables to the values that should be used to replace them
16731675
in the compiled function. The types of the key and value should match or an error will be
16741676
raised during compilation.
1677+
constant_data : Optional[Dict[str, numpy.ndarray]]
1678+
A dictionary that maps the names of ``MutableData`` or ``ConstantData`` instances to their
1679+
corresponding values at inference time. If a model was created with ``MutableData``, these
1680+
are stored as ``SharedVariable`` with the name of the data variable and a value equal to
1681+
the initial data. At inference time, this information is stored in ``InferenceData``
1682+
objects under the ``constant_data`` group, which allows us to check whether a
1683+
``SharedVariable`` instance changed its values after inference or not. If the values have
1684+
changed, then the ``SharedVariable`` is assumed to be volatile. If it has not changed, then
1685+
the ``SharedVariable`` is assumed to not be volatile. If a ``SharedVariable`` is not found
1686+
in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile.
1687+
Setting ``constant_data`` to ``None`` is equivalent to passing an empty dictionary.
1688+
constant_coords : Optional[Set[str]]
1689+
A set with the names of the mutable coordinates that have not changed their shape after
1690+
inference. If a model was created with mutable coordinates, these are stored as
1691+
``SharedVariable`` with the name of the coordinate and a value equal to the length of said
1692+
coordinate. This set let's us check if a ``SharedVariable`` is a mutated coordinate, in
1693+
which case, it is considered volatile. If a ``SharedVariable`` is not found
1694+
in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile.
1695+
Setting ``constant_coords`` to ``None`` is equivalent to passing an empty set.
16751696
16761697
Returns
16771698
-------
@@ -1687,6 +1708,20 @@ def compile_forward_sampling_function(
16871708
if basic_rvs is None:
16881709
basic_rvs = []
16891710

1711+
if constant_data is None:
1712+
constant_data = {}
1713+
if constant_coords is None:
1714+
constant_coords = set()
1715+
1716+
# We define a helper function to check if shared values match to an array
1717+
def shared_value_matches(var):
1718+
try:
1719+
old_array_value = constant_data[var.name]
1720+
except KeyError:
1721+
return var.name in constant_coords
1722+
current_shared_value = var.get_value(borrow=True)
1723+
return np.array_equal(old_array_value, current_shared_value)
1724+
16901725
# We need a function graph to walk the clients and propagate the volatile property
16911726
fg = FunctionGraph(outputs=outputs, clone=False)
16921727

@@ -1702,6 +1737,7 @@ def compile_forward_sampling_function(
17021737
or ( # SharedVariables, except RandomState/Generators
17031738
isinstance(node, SharedVariable)
17041739
and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
1740+
and not shared_value_matches(node)
17051741
)
17061742
or ( # Basic RVs that are not in the trace
17071743
node in basic_rvs and node not in vars_in_trace
@@ -1835,16 +1871,24 @@ def sample_posterior_predictive(
18351871
idata_kwargs = {}
18361872
else:
18371873
idata_kwargs = idata_kwargs.copy()
1874+
constant_data: Dict[str, np.ndarray] = {}
1875+
trace_coords: Dict[str, np.ndarray] = {}
18381876
if "coords" not in idata_kwargs:
18391877
idata_kwargs["coords"] = {}
18401878
if isinstance(trace, InferenceData):
18411879
idata_kwargs["coords"].setdefault("draw", trace["posterior"]["draw"])
18421880
idata_kwargs["coords"].setdefault("chain", trace["posterior"]["chain"])
1881+
_constant_data = getattr(trace, "constant_data", None)
1882+
if _constant_data is not None:
1883+
trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
1884+
constant_data.update({str(k): v.data for k, v in _constant_data.items()})
1885+
trace_coords.update({str(k): v.data for k, v in trace["posterior"].coords.items()})
18431886
_trace = dataset_to_point_list(trace["posterior"])
18441887
nchain, len_trace = chains_and_samples(trace)
18451888
elif isinstance(trace, xarray.Dataset):
18461889
idata_kwargs["coords"].setdefault("draw", trace["draw"])
18471890
idata_kwargs["coords"].setdefault("chain", trace["chain"])
1891+
trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
18481892
_trace = dataset_to_point_list(trace)
18491893
nchain, len_trace = chains_and_samples(trace)
18501894
elif isinstance(trace, MultiTrace):
@@ -1901,6 +1945,16 @@ def sample_posterior_predictive(
19011945
stacklevel=2,
19021946
)
19031947

1948+
constant_coords = set()
1949+
for dim, coord in trace_coords.items():
1950+
current_coord = model.coords.get(dim, None)
1951+
if (
1952+
current_coord is not None
1953+
and len(coord) == len(current_coord)
1954+
and np.all(coord == current_coord)
1955+
):
1956+
constant_coords.add(dim)
1957+
19041958
if var_names is not None:
19051959
vars_ = [model[x] for x in var_names]
19061960
else:
@@ -1935,6 +1989,8 @@ def sample_posterior_predictive(
19351989
basic_rvs=model.basic_RVs,
19361990
givens_dict=None,
19371991
random_seed=random_seed,
1992+
constant_data=constant_data,
1993+
constant_coords=constant_coords,
19381994
**compile_kwargs,
19391995
)
19401996
sampler_fn = point_wrapper(_sampler_fn)

0 commit comments

Comments
 (0)