-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Mixture random cleanup #3364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mixture random cleanup #3364
Changes from 3 commits
b6ae4b1
6da77ac
af7ea76
fe44a29
5e3db64
fae4c11
e8affd1
74ff181
86f69ea
46347b2
2200b46
d047a09
dd81ec1
a058bfe
6f75956
bf39de6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,17 @@ | ||
from collections.abc import Iterable | ||
import numpy as np | ||
import theano | ||
import theano.tensor as tt | ||
|
||
from pymc3.util import get_variable_name | ||
from ..math import logsumexp | ||
from .dist_math import bound, random_choice | ||
from .distribution import (Discrete, Distribution, draw_values, | ||
generate_samples, _DrawValuesContext, | ||
_DrawValuesContextBlocker, to_tuple) | ||
_DrawValuesContextBlocker, to_tuple, | ||
broadcast_distribution_samples) | ||
from .continuous import get_tau_sigma, Normal | ||
from ..theanof import _conversion_map | ||
|
||
|
||
def all_discrete(comp_dists): | ||
|
@@ -79,9 +83,9 @@ def __init__(self, w, comp_dists, *args, **kwargs): | |
defaults = kwargs.pop('defaults', []) | ||
|
||
if all_discrete(comp_dists): | ||
dtype = kwargs.pop('dtype', 'int64') | ||
default_dtype = _conversion_map[theano.config.floatX] | ||
else: | ||
dtype = kwargs.pop('dtype', 'float64') | ||
default_dtype = theano.config.floatX | ||
|
||
try: | ||
self.mean = (w * self._comp_means()).sum(axis=-1) | ||
|
@@ -90,6 +94,7 @@ def __init__(self, w, comp_dists, *args, **kwargs): | |
defaults.append('mean') | ||
except AttributeError: | ||
pass | ||
dtype = kwargs.pop('dtype', default_dtype) | ||
|
||
try: | ||
comp_modes = self._comp_modes() | ||
|
@@ -108,29 +113,37 @@ def comp_dists(self): | |
return self._comp_dists | ||
|
||
@comp_dists.setter | ||
def comp_dists(self, _comp_dists): | ||
self._comp_dists = _comp_dists | ||
# Tests if the comp_dists can call random with non None size | ||
with _DrawValuesContextBlocker(): | ||
if isinstance(self.comp_dists, (list, tuple)): | ||
try: | ||
[comp_dist.random(size=23) | ||
for comp_dist in self.comp_dists] | ||
self._comp_dists_vect = True | ||
except Exception: | ||
# The comp_dists cannot call random with non None size or | ||
# without knowledge of the point so we assume that we will | ||
# have to iterate calls to random to get the correct size | ||
self._comp_dists_vect = False | ||
else: | ||
try: | ||
self.comp_dists.random(size=23) | ||
self._comp_dists_vect = True | ||
except Exception: | ||
# The comp_dists cannot call random with non None size or | ||
# without knowledge of the point so we assume that we will | ||
# have to iterate calls to random to get the correct size | ||
self._comp_dists_vect = False | ||
def comp_dists(self, comp_dists): | ||
if isinstance(comp_dists, Distribution): | ||
self._comp_dists = comp_dists | ||
self._comp_dist_shapes = to_tuple(comp_dists.shape) | ||
self._broadcast_shape = self._comp_dist_shapes | ||
self.is_multidim_comp = True | ||
elif isinstance(comp_dists, Iterable): | ||
if not all((isinstance(comp_dist, Distribution) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this check should be moved to the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ok, I can do that. I thought that it would be a good check to make whenever |
||
for comp_dist in comp_dists)): | ||
raise TypeError('Supplied Mixture comp_dists must be a ' | ||
'Distribution or an iterable of ' | ||
'Distributions.') | ||
self._comp_dists = comp_dists | ||
# Now we check the comp_dists distribution shape, see what | ||
# the broadcast shape would be. This shape will be the dist_shape | ||
# used by generate samples (the shape of a single random sample) | ||
# from the mixture | ||
self._comp_dist_shapes = [to_tuple(d.shape) for d in comp_dists] | ||
# All component distributions must broadcast with each other | ||
try: | ||
self._broadcast_shape = np.broadcast( | ||
*[np.empty(shape) for shape in self._comp_dist_shapes] | ||
).shape | ||
except Exception: | ||
raise TypeError('Supplied comp_dists shapes do not broadcast ' | ||
'with each other. comp_dists shapes are: ' | ||
'{}'.format(self._comp_dist_shapes)) | ||
self.is_multidim_comp = False | ||
else: | ||
raise TypeError('Cannot handle supplied comp_dist type {}' | ||
.format(type(comp_dists))) | ||
|
||
def _comp_logp(self, value): | ||
comp_dists = self.comp_dists | ||
|
@@ -160,35 +173,100 @@ def _comp_modes(self): | |
for comp_dist in self.comp_dists], | ||
axis=1)) | ||
|
||
def _comp_samples(self, point=None, size=None): | ||
if self._comp_dists_vect or size is None: | ||
try: | ||
return self.comp_dists.random(point=point, size=size) | ||
except AttributeError: | ||
samples = np.array([comp_dist.random(point=point, size=size) | ||
for comp_dist in self.comp_dists]) | ||
samples = np.moveaxis(samples, 0, samples.ndim - 1) | ||
def _comp_samples(self, point=None, size=None, | ||
comp_dist_shapes=None, | ||
broadcast_shape=None): | ||
if self.is_multidim_comp: | ||
samples = self._comp_dists.random(point=point, size=size) | ||
else: | ||
# We must iterate the calls to random manually | ||
size = to_tuple(size) | ||
_size = int(np.prod(size)) | ||
try: | ||
samples = np.array([self.comp_dists.random(point=point, | ||
size=None) | ||
for _ in range(_size)]) | ||
samples = np.reshape(samples, size + samples.shape[1:]) | ||
except AttributeError: | ||
samples = np.array([[comp_dist.random(point=point, size=None) | ||
for _ in range(_size)] | ||
for comp_dist in self.comp_dists]) | ||
samples = np.moveaxis(samples, 0, samples.ndim - 1) | ||
samples = np.reshape(samples, size + samples[1:]) | ||
|
||
if comp_dist_shapes is None: | ||
comp_dist_shapes = self._comp_dist_shapes | ||
if broadcast_shape is None: | ||
broadcast_shape = self._sample_shape | ||
samples = [] | ||
for dist_shape, comp_dist in zip(comp_dist_shapes, | ||
self.comp_dists): | ||
def generator(*args, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You are not joking that this is a horrible hack... I think it might be better to overwrite the comp_dist.random method (it's at least more explicit) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You mean just changing the |
||
# The distribution random methods use the size argument | ||
# differently from scipy.*.rvs, and generate_samples | ||
# follows the latter usage pattern. For this reason we | ||
# decorate (horribly hack) the size kwarg of | ||
# comp_dist.random. We also have to disable pylint W0640 | ||
# because comp_dist is changed at each iteration of the | ||
# for loop, and this generator function must be defined | ||
# for each comp_dist. | ||
# pylint: disable=W0640 | ||
if len(args) > 2: | ||
args[1] = size | ||
else: | ||
kwargs['size'] = size | ||
return comp_dist.random(*args, **kwargs) | ||
sample = generate_samples( | ||
generator=generator, | ||
dist_shape=dist_shape, | ||
broadcast_shape=broadcast_shape, | ||
point=point, | ||
size=size, | ||
) | ||
samples.append(sample) | ||
samples = np.array( | ||
broadcast_distribution_samples(samples, size=size) | ||
) | ||
# In the logp we assume the last axis holds the mixture components | ||
# so we move the axis to the last dimension | ||
samples = np.moveaxis(samples, 0, -1) | ||
if samples.shape[-1] == 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This if statement is legacy code that I don't really understand. Why should we do this test, wouldn't it be done by There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't remember as well, but at some point there is an error with shape = (100, 1) and shape = (100, ) - not even sure we have a test for that. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So error happens when we try to call |
||
return samples[..., 0] | ||
else: | ||
return samples | ||
|
||
def infer_comp_dist_shapes(self, point=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe it is better to call this once in There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The problem is that the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I see, good point. |
||
if self.is_multidim_comp: | ||
if len(self._comp_dist_shapes) > 0: | ||
comp_dist_shapes = self._comp_dist_shapes | ||
else: | ||
# Happens when the distribution is a scalar or when it was not | ||
# given a shape. In these cases we try to draw a single value | ||
# to check its shape, we use the provided point dictionary | ||
# hoping that it can circumvent the Flat and HalfFlat | ||
# undrawable distributions. | ||
with _DrawValuesContextBlocker(): | ||
test_sample = self._comp_dists.random(point=point, | ||
size=None) | ||
comp_dist_shapes = test_sample.shape | ||
broadcast_shape = comp_dist_shapes | ||
includes_mixture = True | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am not completely sure what is the function of There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You're right. Yes, I'll add a docstring. |
||
else: | ||
# Now we check the comp_dists distribution shape, see what | ||
# the broadcast shape would be. This shape will be the dist_shape | ||
# used by generate samples (the shape of a single random sample) | ||
# from the mixture | ||
comp_dist_shapes = [] | ||
for dist_shape, comp_dist in zip(self._comp_dist_shapes, | ||
self._comp_dists): | ||
if dist_shape == tuple(): | ||
# Happens when the distribution is a scalar or when it was | ||
# not given a shape. In these cases we try to draw a single | ||
# value to check its shape, we use the provided point | ||
# dictionary hoping that it can circumvent the Flat and | ||
# HalfFlat undrawable distributions. | ||
with _DrawValuesContextBlocker(): | ||
test_sample = comp_dist.random(point=point, | ||
size=None) | ||
dist_shape = test_sample.shape | ||
comp_dist_shapes.append(dist_shape) | ||
# All component distributions must broadcast with each other | ||
try: | ||
broadcast_shape = np.broadcast( | ||
*[np.empty(shape) for shape in comp_dist_shapes] | ||
).shape | ||
except Exception: | ||
raise TypeError('Inferred comp_dist shapes do not broadcast ' | ||
'with each other. comp_dists inferred shapes ' | ||
'are: {}'.format(comp_dist_shapes)) | ||
includes_mixture = False | ||
return comp_dist_shapes, broadcast_shape, includes_mixture | ||
|
||
def logp(self, value): | ||
w = self.w | ||
|
||
|
@@ -203,10 +281,9 @@ def random(self, point=None, size=None): | |
with _DrawValuesContext() as draw_context: | ||
# We first need to check w and comp_tmp shapes and re compute size | ||
w = draw_values([self.w], point=point, size=size)[0] | ||
with _DrawValuesContextBlocker(): | ||
# We don't want to store the values drawn here in the context | ||
# because they wont have the correct size | ||
comp_tmp = self._comp_samples(point=point, size=None) | ||
comp_dist_shapes, broadcast_shape, includes_mixture = ( | ||
self.infer_comp_dist_shapes(point=point) | ||
) | ||
|
||
# When size is not None, it's hard to tell the w parameter shape | ||
if size is not None and w.shape[:len(size)] == size: | ||
|
@@ -215,8 +292,12 @@ def random(self, point=None, size=None): | |
w_shape = w.shape | ||
|
||
# Try to determine parameter shape and dist_shape | ||
param_shape = np.broadcast(np.empty(w_shape), | ||
comp_tmp).shape | ||
if includes_mixture: | ||
param_shape = np.broadcast(np.empty(w_shape), | ||
np.empty(broadcast_shape)).shape | ||
else: | ||
param_shape = np.broadcast(np.empty(w_shape), | ||
np.empty(broadcast_shape + (1,))).shape | ||
if np.asarray(self.shape).size != 0: | ||
dist_shape = np.broadcast(np.empty(self.shape), | ||
np.empty(param_shape[:-1])).shape | ||
|
@@ -259,7 +340,11 @@ def random(self, point=None, size=None): | |
else: | ||
output_size = int(np.prod(dist_shape) * param_shape[-1]) | ||
# Get the size we need for the mixture's random call | ||
mixture_size = int(output_size // np.prod(comp_tmp.shape)) | ||
if includes_mixture: | ||
mixture_size = int(output_size // np.prod(broadcast_shape)) | ||
else: | ||
mixture_size = int(output_size // | ||
(np.prod(broadcast_shape) * param_shape[-1])) | ||
if mixture_size == 1 and _size is None: | ||
mixture_size = None | ||
|
||
|
@@ -277,11 +362,23 @@ def random(self, point=None, size=None): | |
size=size) | ||
# Sample from the mixture | ||
with draw_context: | ||
mixed_samples = self._comp_samples(point=point, | ||
size=mixture_size) | ||
w_samples = w_samples.flatten() | ||
mixed_samples = self._comp_samples( | ||
point=point, | ||
size=mixture_size, | ||
broadcast_shape=broadcast_shape, | ||
comp_dist_shapes=comp_dist_shapes, | ||
) | ||
# Test that the mixture has the same number of "samples" as w | ||
if w_samples.size != (mixed_samples.size // w.shape[-1]): | ||
raise ValueError('Inconsistent number of samples from the ' | ||
'mixture and mixture weights. Drew {} mixture ' | ||
'weights elements, and {} samples from the ' | ||
'mixture components.'. | ||
format(w_samples.size, | ||
mixed_samples.size // w.shape[-1])) | ||
# Semiflatten the mixture to be able to zip it with w_samples | ||
mixed_samples = np.reshape(mixed_samples, (-1, comp_tmp.shape[-1])) | ||
w_samples = w_samples.flatten() | ||
mixed_samples = np.reshape(mixed_samples, (-1, w.shape[-1])) | ||
# Select the samples from the mixture | ||
samples = np.array([mixed[choice] for choice, mixed in | ||
zip(w_samples, mixed_samples)]) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -267,7 +267,7 @@ def random(self, point=None, size=None): | |
else: | ||
std_norm_shape = mu.shape | ||
standard_normal = np.random.standard_normal(std_norm_shape) | ||
return mu + np.tensordot(standard_normal, chol, axes=[[-1], [-1]]) | ||
return mu + np.einsum('...ij,...j->...i', chol, standard_normal) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nice! There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks! I had not understood |
||
else: | ||
mu, tau = draw_values([self.mu, self.tau], point=point, size=size) | ||
if mu.shape[-1] != tau[0].shape[-1]: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name is a bit confusing - this is only to distinguish between the comp being a list or a distribution, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're totally right. I'll change this