Mixture random cleanup #3364

Merged (16 commits, Feb 16, 2019). The diff shown below reflects 3 of the 16 commits.
38 changes: 24 additions & 14 deletions pymc3/distributions/distribution.py
@@ -254,7 +254,7 @@ class _DrawValuesContextBlocker(_DrawValuesContext, metaclass=InitContextMeta):
"""
def __new__(cls, *args, **kwargs):
# resolves the parent instance
instance = super(_DrawValuesContextBlocker, cls).__new__(cls)
instance = super().__new__(cls)
instance._parent = None
return instance

@@ -639,20 +639,30 @@ def generate_samples(generator, *args, **kwargs):
samples = generator(size=broadcast_shape, *args, **kwargs)
elif dist_shape == broadcast_shape:
samples = generator(size=size_tup + dist_shape, *args, **kwargs)
elif len(dist_shape) == 0 and size_tup and broadcast_shape[:len(size_tup)] == size_tup:
# Input's dist_shape is scalar, but it has size repetitions.
# So now the size matches but we have to manually broadcast to
# the right dist_shape
samples = [generator(*args, **kwargs)]
if samples[0].shape == broadcast_shape:
samples = samples[0]
elif len(dist_shape) == 0 and size_tup and broadcast_shape:
# There is no dist_shape (scalar distribution) but the parameters
# broadcast shape and size_tup determine the size to provide to
# the generator
if broadcast_shape[:len(size_tup)] == size_tup:
# Input's dist_shape is scalar, but it has size repetitions.
# So now the size matches but we have to manually broadcast to
# the right dist_shape
samples = [generator(*args, **kwargs)]
if samples[0].shape == broadcast_shape:
samples = samples[0]
else:
suffix = broadcast_shape[len(size_tup):] + dist_shape
samples.extend([generator(*args, **kwargs).
reshape(broadcast_shape)[..., np.newaxis]
for _ in range(np.prod(suffix,
dtype=int) - 1)])
samples = np.hstack(samples).reshape(size_tup + suffix)
else:
suffix = broadcast_shape[len(size_tup):] + dist_shape
samples.extend([generator(*args, **kwargs).
reshape(broadcast_shape)[..., np.newaxis]
for _ in range(np.prod(suffix,
dtype=int) - 1)])
samples = np.hstack(samples).reshape(size_tup + suffix)
# The parameter shape is given, but we have to concatenate it
# with the size tuple
samples = generator(size=size_tup + broadcast_shape,
*args,
**kwargs)
else:
samples = None
# Args have been broadcast correctly, can just ask for the right shape out
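To see what the new branch does for a scalar distribution, here is a minimal, self-contained sketch of the size handling, assuming a toy numpy generator (toy_generator, size_tup, broadcast_shape and dist_shape are illustrative stand-ins for the variables above, not part of the PR):

import numpy as np

# Toy stand-in for a distribution sampler whose parameters broadcast to (2, 3).
def toy_generator(size=None, mu=np.zeros((2, 3))):
    return np.random.normal(loc=mu, size=size)

size_tup = (5,)            # requested i.i.d. repetitions
broadcast_shape = (2, 3)   # shape implied by the broadcast parameters
dist_shape = ()            # scalar distribution: no explicit shape given

# The new else-branch concatenates size_tup with broadcast_shape and asks the
# generator for everything in a single call.
samples = toy_generator(size=size_tup + broadcast_shape)
assert samples.shape == (5, 2, 3)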
217 changes: 157 additions & 60 deletions pymc3/distributions/mixture.py
@@ -1,13 +1,17 @@
from collections.abc import Iterable
import numpy as np
import theano
import theano.tensor as tt

from pymc3.util import get_variable_name
from ..math import logsumexp
from .dist_math import bound, random_choice
from .distribution import (Discrete, Distribution, draw_values,
generate_samples, _DrawValuesContext,
_DrawValuesContextBlocker, to_tuple)
_DrawValuesContextBlocker, to_tuple,
broadcast_distribution_samples)
from .continuous import get_tau_sigma, Normal
from ..theanof import _conversion_map


def all_discrete(comp_dists):
@@ -79,9 +83,9 @@ def __init__(self, w, comp_dists, *args, **kwargs):
defaults = kwargs.pop('defaults', [])

if all_discrete(comp_dists):
dtype = kwargs.pop('dtype', 'int64')
default_dtype = _conversion_map[theano.config.floatX]
else:
dtype = kwargs.pop('dtype', 'float64')
default_dtype = theano.config.floatX

try:
self.mean = (w * self._comp_means()).sum(axis=-1)
@@ -90,6 +94,7 @@ def __init__(self, w, comp_dists, *args, **kwargs):
defaults.append('mean')
except AttributeError:
pass
dtype = kwargs.pop('dtype', default_dtype)

try:
comp_modes = self._comp_modes()
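The point of the dtype change above is to stop hard-coding 'int64'/'float64' and instead pick a default whose precision follows theano.config.floatX. A rough sketch of the selection logic; the mapping below is only illustrative (the real _conversion_map is defined in pymc3.theanof), and default_mixture_dtype is a hypothetical helper:

# Illustrative mapping from a float precision to a matching integer dtype.
_conversion_map = {'float64': 'int64', 'float32': 'int32', 'float16': 'int16'}

def default_mixture_dtype(all_components_discrete, floatX='float32'):
    # Discrete mixtures default to the int dtype matching floatX's precision,
    # continuous mixtures default to floatX itself; an explicit dtype kwarg
    # still overrides either default.
    return _conversion_map[floatX] if all_components_discrete else floatX

assert default_mixture_dtype(True, 'float32') == 'int32'
assert default_mixture_dtype(False, 'float64') == 'float64'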
@@ -108,29 +113,37 @@ def comp_dists(self):
return self._comp_dists

@comp_dists.setter
def comp_dists(self, _comp_dists):
self._comp_dists = _comp_dists
# Tests if the comp_dists can call random with non None size
with _DrawValuesContextBlocker():
if isinstance(self.comp_dists, (list, tuple)):
try:
[comp_dist.random(size=23)
for comp_dist in self.comp_dists]
self._comp_dists_vect = True
except Exception:
# The comp_dists cannot call random with non None size or
# without knowledge of the point so we assume that we will
# have to iterate calls to random to get the correct size
self._comp_dists_vect = False
else:
try:
self.comp_dists.random(size=23)
self._comp_dists_vect = True
except Exception:
# The comp_dists cannot call random with non None size or
# without knowledge of the point so we assume that we will
# have to iterate calls to random to get the correct size
self._comp_dists_vect = False
def comp_dists(self, comp_dists):
if isinstance(comp_dists, Distribution):
self._comp_dists = comp_dists
self._comp_dist_shapes = to_tuple(comp_dists.shape)
self._broadcast_shape = self._comp_dist_shapes
self.is_multidim_comp = True
Member: The name is a bit confusing - this is only to distinguish between the comp being a list or a distribution, right?

Member Author: You're totally right. I'll change this.

elif isinstance(comp_dists, Iterable):
if not all((isinstance(comp_dist, Distribution)
Member: I think this check should be moved to __init__; maybe it could also save a few lines in the if/else here.

Member Author: Ok, I can do that. I thought it would be a good check to make whenever comp_dists' value was set. I know we only really change comp_dists during __init__, but maybe some users could change its value in derived classes or the like.

for comp_dist in comp_dists)):
raise TypeError('Supplied Mixture comp_dists must be a '
'Distribution or an iterable of '
'Distributions.')
self._comp_dists = comp_dists
# Now we check the comp_dists distribution shape, see what
# the broadcast shape would be. This shape will be the dist_shape
# used by generate samples (the shape of a single random sample)
# from the mixture
self._comp_dist_shapes = [to_tuple(d.shape) for d in comp_dists]
# All component distributions must broadcast with each other
try:
self._broadcast_shape = np.broadcast(
*[np.empty(shape) for shape in self._comp_dist_shapes]
).shape
except Exception:
raise TypeError('Supplied comp_dists shapes do not broadcast '
'with each other. comp_dists shapes are: '
'{}'.format(self._comp_dist_shapes))
self.is_multidim_comp = False
else:
raise TypeError('Cannot handle supplied comp_dist type {}'
.format(type(comp_dists)))
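
The broadcast check in the setter is easy to exercise on its own; a short numpy sketch with arbitrarily chosen component shapes:

import numpy as np

comp_dist_shapes = [(3,), (2, 3), ()]    # e.g. from to_tuple(d.shape)
# np.empty allocates nothing meaningful here; only the shapes matter.
broadcast_shape = np.broadcast(*[np.empty(s) for s in comp_dist_shapes]).shape
assert broadcast_shape == (2, 3)

# Non-broadcastable shapes raise, which the setter reports as a TypeError:
try:
    np.broadcast(np.empty((3,)), np.empty((4,)))
except ValueError:
    pass  # shape mismatch -> 'comp_dists shapes do not broadcast with each other'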

def _comp_logp(self, value):
comp_dists = self.comp_dists
@@ -160,35 +173,100 @@ def _comp_modes(self):
for comp_dist in self.comp_dists],
axis=1))

def _comp_samples(self, point=None, size=None):
if self._comp_dists_vect or size is None:
try:
return self.comp_dists.random(point=point, size=size)
except AttributeError:
samples = np.array([comp_dist.random(point=point, size=size)
for comp_dist in self.comp_dists])
samples = np.moveaxis(samples, 0, samples.ndim - 1)
def _comp_samples(self, point=None, size=None,
comp_dist_shapes=None,
broadcast_shape=None):
if self.is_multidim_comp:
samples = self._comp_dists.random(point=point, size=size)
else:
# We must iterate the calls to random manually
size = to_tuple(size)
_size = int(np.prod(size))
try:
samples = np.array([self.comp_dists.random(point=point,
size=None)
for _ in range(_size)])
samples = np.reshape(samples, size + samples.shape[1:])
except AttributeError:
samples = np.array([[comp_dist.random(point=point, size=None)
for _ in range(_size)]
for comp_dist in self.comp_dists])
samples = np.moveaxis(samples, 0, samples.ndim - 1)
samples = np.reshape(samples, size + samples[1:])

if comp_dist_shapes is None:
comp_dist_shapes = self._comp_dist_shapes
if broadcast_shape is None:
broadcast_shape = self._sample_shape
samples = []
for dist_shape, comp_dist in zip(comp_dist_shapes,
self.comp_dists):
def generator(*args, **kwargs):
Member: You are not joking that this is a horrible hack... I think it might be better to override the comp_dist.random method (it's at least more explicit).

Member Author: You mean changing the comp_dists[i].random method somewhere else? Maybe I could add a _generators attribute in the comp_dists setter, with a wrapper of comp_dists[i].random that also takes a _raw_shape parameter. I'll try it out.

# The distribution random methods use the size argument
# differently from scipy.*.rvs, and generate_samples
# follows the latter usage pattern. For this reason we
# decorate (horribly hack) the size kwarg of
# comp_dist.random. We also have to disable pylint W0640
# because comp_dist is changed at each iteration of the
# for loop, and this generator function must be defined
# for each comp_dist.
# pylint: disable=W0640
if len(args) > 2:
args[1] = size
else:
kwargs['size'] = size
return comp_dist.random(*args, **kwargs)
sample = generate_samples(
generator=generator,
dist_shape=dist_shape,
broadcast_shape=broadcast_shape,
point=point,
size=size,
)
samples.append(sample)
samples = np.array(
broadcast_distribution_samples(samples, size=size)
)
# In the logp we assume the last axis holds the mixture components
# so we move the axis to the last dimension
samples = np.moveaxis(samples, 0, -1)
if samples.shape[-1] == 1:
Member Author: This if statement is legacy code that I don't really understand. Why should we do this test? Wouldn't it be handled by Mixture.random?

Member: I don't remember either, but at some point there was an error with shape = (100, 1) versus shape = (100,) - not even sure we have a test for that.

Member Author: So the error happens when we call random from the mixture and comp_dists has shape (100, 1) but the Mixture has shape (100,)? I'm not sure how the error arises. Did it happen when comp_dists was a multidimensional distribution or when it was a list of distributions? I'd appreciate any thoughts on extra tests to write.

return samples[..., 0]
else:
return samples

def infer_comp_dist_shapes(self, point=None):
Member: Maybe it is better to call this once in __init__ and save the outputs as properties of the mixture class? What do you think?

Member Author: The problem is that the comp_dist shapes may depend on the values passed in the point dictionary. That's the case for sample_posterior_predictive, at least.

Member: I see, good point.

if self.is_multidim_comp:
if len(self._comp_dist_shapes) > 0:
comp_dist_shapes = self._comp_dist_shapes
else:
# Happens when the distribution is a scalar or when it was not
# given a shape. In these cases we try to draw a single value
# to check its shape, we use the provided point dictionary
# hoping that it can circumvent the Flat and HalfFlat
# undrawable distributions.
with _DrawValuesContextBlocker():
test_sample = self._comp_dists.random(point=point,
size=None)
comp_dist_shapes = test_sample.shape
broadcast_shape = comp_dist_shapes
includes_mixture = True
Member: I'm not completely sure what the function of includes_mixture is here - I imagine it has something to do with making sure the broadcasting of w is correct. Could you add a small docstring to this function?

Member Author: You're right, I'll add a docstring. includes_mixture maybe was not the best choice of name, and may be redundant. It's a bool indicating whether the comp_dists' last axis holds the mixture components or not, but that's the same as testing whether comp_dists is a distribution or a list. Hmm, I'll see how to avoid depending on it.

else:
# Now we check the comp_dists distribution shape, see what
# the broadcast shape would be. This shape will be the dist_shape
# used by generate samples (the shape of a single random sample)
# from the mixture
comp_dist_shapes = []
for dist_shape, comp_dist in zip(self._comp_dist_shapes,
self._comp_dists):
if dist_shape == tuple():
# Happens when the distribution is a scalar or when it was
# not given a shape. In these cases we try to draw a single
# value to check its shape, we use the provided point
# dictionary hoping that it can circumvent the Flat and
# HalfFlat undrawable distributions.
with _DrawValuesContextBlocker():
test_sample = comp_dist.random(point=point,
size=None)
dist_shape = test_sample.shape
comp_dist_shapes.append(dist_shape)
# All component distributions must broadcast with each other
try:
broadcast_shape = np.broadcast(
*[np.empty(shape) for shape in comp_dist_shapes]
).shape
except Exception:
raise TypeError('Inferred comp_dist shapes do not broadcast '
'with each other. comp_dists inferred shapes '
'are: {}'.format(comp_dist_shapes))
includes_mixture = False
return comp_dist_shapes, broadcast_shape, includes_mixture
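
The reason infer_comp_dist_shapes takes the point dictionary (rather than being computed once in __init__) is that a component's shape may only be determined by the values supplied at sampling time, as in sample_posterior_predictive. A toy numpy illustration of the dist_shape == tuple() branch; random_fn stands in for comp_dist.random and is not part of the PR:

import numpy as np

def random_fn(point=None, size=None):
    # The shape of the draw depends on the point values, as for pymc3
    # distributions whose parameters are themselves variables.
    mu = point.get('mu', 0.0) if point else 0.0
    return np.random.normal(loc=mu, size=size)

# No declared shape: draw a single test sample and read its shape off.
test_sample = random_fn(point={'mu': np.zeros(7)}, size=None)
inferred_dist_shape = test_sample.shape
assert inferred_dist_shape == (7,)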

def logp(self, value):
w = self.w

@@ -203,10 +281,9 @@ def random(self, point=None, size=None):
with _DrawValuesContext() as draw_context:
# We first need to check w and comp_tmp shapes and re compute size
w = draw_values([self.w], point=point, size=size)[0]
with _DrawValuesContextBlocker():
# We don't want to store the values drawn here in the context
# because they wont have the correct size
comp_tmp = self._comp_samples(point=point, size=None)
comp_dist_shapes, broadcast_shape, includes_mixture = (
self.infer_comp_dist_shapes(point=point)
)

# When size is not None, it's hard to tell the w parameter shape
if size is not None and w.shape[:len(size)] == size:
@@ -215,8 +292,12 @@ def random(self, point=None, size=None):
w_shape = w.shape

# Try to determine parameter shape and dist_shape
param_shape = np.broadcast(np.empty(w_shape),
comp_tmp).shape
if includes_mixture:
param_shape = np.broadcast(np.empty(w_shape),
np.empty(broadcast_shape)).shape
else:
param_shape = np.broadcast(np.empty(w_shape),
np.empty(broadcast_shape + (1,))).shape
if np.asarray(self.shape).size != 0:
dist_shape = np.broadcast(np.empty(self.shape),
np.empty(param_shape[:-1])).shape
@@ -259,7 +340,11 @@ def random(self, point=None, size=None):
else:
output_size = int(np.prod(dist_shape) * param_shape[-1])
# Get the size we need for the mixture's random call
mixture_size = int(output_size // np.prod(comp_tmp.shape))
if includes_mixture:
mixture_size = int(output_size // np.prod(broadcast_shape))
else:
mixture_size = int(output_size //
(np.prod(broadcast_shape) * param_shape[-1]))
if mixture_size == 1 and _size is None:
mixture_size = None

@@ -277,11 +362,23 @@ def random(self, point=None, size=None):
size=size)
# Sample from the mixture
with draw_context:
mixed_samples = self._comp_samples(point=point,
size=mixture_size)
w_samples = w_samples.flatten()
mixed_samples = self._comp_samples(
point=point,
size=mixture_size,
broadcast_shape=broadcast_shape,
comp_dist_shapes=comp_dist_shapes,
)
# Test that the mixture has the same number of "samples" as w
if w_samples.size != (mixed_samples.size // w.shape[-1]):
raise ValueError('Inconsistent number of samples from the '
'mixture and mixture weights. Drew {} mixture '
'weights elements, and {} samples from the '
'mixture components.'.
format(w_samples.size,
mixed_samples.size // w.shape[-1]))
# Semiflatten the mixture to be able to zip it with w_samples
mixed_samples = np.reshape(mixed_samples, (-1, comp_tmp.shape[-1]))
w_samples = w_samples.flatten()
mixed_samples = np.reshape(mixed_samples, (-1, w.shape[-1]))
# Select the samples from the mixture
samples = np.array([mixed[choice] for choice, mixed in
zip(w_samples, mixed_samples)])
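The final selection step pairs one categorical draw from w with one row of component draws; a self-contained numpy sketch of that semiflatten-and-index pattern (sizes are arbitrary):

import numpy as np

n_draws, n_comp = 6, 3
# One draw from every component for every output element; the last axis
# indexes the mixture components, matching the convention in _comp_samples.
mixed_samples = np.random.normal(size=(n_draws, n_comp))
# One chosen component per output element, e.g. from random_choice(p=w).
w_samples = np.random.choice(n_comp, size=n_draws).flatten()

# Semiflatten to (-1, n_comp) and pick the chosen component row by row.
mixed_samples = np.reshape(mixed_samples, (-1, n_comp))
samples = np.array([mixed[choice]
                    for choice, mixed in zip(w_samples, mixed_samples)])
assert samples.shape == (n_draws,)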
2 changes: 1 addition & 1 deletion pymc3/distributions/multivariate.py
@@ -267,7 +267,7 @@ def random(self, point=None, size=None):
else:
std_norm_shape = mu.shape
standard_normal = np.random.standard_normal(std_norm_shape)
return mu + np.tensordot(standard_normal, chol, axes=[[-1], [-1]])
return mu + np.einsum('...ij,...j->...i', chol, standard_normal)
Member: nice!

Member Author: Thanks! I hadn't understood tensordot's broadcasting rules well, and I noticed it started giving weirdly shaped outputs; einsum was the most intuitive way to get the right broadcasting.

else:
mu, tau = draw_values([self.mu, self.tau], point=point, size=size)
if mu.shape[-1] != tau[0].shape[-1]:
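The multivariate normal change is about batch broadcasting: np.einsum('...ij,...j->...i', chol, z) does one matrix-vector product per leading batch index, whereas the old tensordot call contracts only the named axes and keeps both batch axes, producing cross terms between unrelated batch entries. A quick numpy check (shapes are illustrative):

import numpy as np

batch, k = 4, 3
chol = np.tril(np.random.rand(batch, k, k))   # one Cholesky factor per batch
z = np.random.standard_normal((batch, k))     # one standard-normal vector per batch

out = np.einsum('...ij,...j->...i', chol, z)
assert out.shape == (batch, k)
# Each batch element is an ordinary matrix-vector product:
assert np.allclose(out[0], chol[0] @ z[0])

# tensordot keeps both batch axes, so the result pairs every z with every chol:
old = np.tensordot(z, chol, axes=[[-1], [-1]])
assert old.shape == (batch, batch, k)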
5 changes: 3 additions & 2 deletions pymc3/tests/test_sampling.py
@@ -223,8 +223,9 @@ def test_normal_scalar(self):
ppc = pm.sample_posterior_predictive(trace, samples=1000, vars=[a])
assert 'a' in ppc
assert ppc['a'].shape == (1000,)
_, pval = stats.kstest(ppc['a'],
stats.norm(loc=0, scale=np.sqrt(2)).cdf)
# mu's standard deviation may have changed thanks to a's observed
_, pval = stats.kstest(ppc['a'] - trace['mu'],
stats.norm(loc=0, scale=1).cdf)
assert pval > 0.001

with model:
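The test change reflects that, conditional on a posterior draw of mu, the posterior predictive a is Normal(mu, 1), so a - mu is standard normal no matter how much the observations have tightened mu's posterior; comparing a itself against N(0, sqrt(2)) is only valid under the prior. A minimal sketch of the corrected check, assuming that model structure (the arrays below are simulated stand-ins for trace['mu'] and ppc['a']):

import numpy as np
from scipy import stats

mu_draws = np.random.normal(loc=0.3, scale=0.1, size=1000)   # posterior draws of mu
ppc_a = np.random.normal(loc=mu_draws, scale=1.0)            # a | mu ~ Normal(mu, 1)

# ppc_a alone has variance 1 + Var[mu], but the residual is exactly standard normal.
_, pval = stats.kstest(ppc_a - mu_draws, stats.norm(loc=0, scale=1).cdf)
assert pval > 0.001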