diff --git a/.github/workflows/arviz_compat.yml b/.github/workflows/arviz_compat.yml index be5450e003..4d1046bd22 100644 --- a/.github/workflows/arviz_compat.yml +++ b/.github/workflows/arviz_compat.yml @@ -9,9 +9,10 @@ jobs: pytest: strategy: matrix: - os: [ubuntu-18.04] + os: [ubuntu-latest, macos-latest] floatx: [float64] test-subset: + - pymc3/tests/test_distributions_random.py - pymc3/tests/test_sampling.py fail-fast: false runs-on: ${{ matrix.os }} diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml new file mode 100644 index 0000000000..83add15839 --- /dev/null +++ b/.github/workflows/windows.yml @@ -0,0 +1,43 @@ +name: windows + +on: + pull_request: + push: + branches: [master] + +jobs: + pytest: + strategy: + matrix: + os: [windows-latest] + floatx: [float64] + test-subset: + - pymc3/tests/test_distributions_random.py + - pymc3/tests/test_sampling.py + runs-on: ${{ matrix.os }} + env: + TEST_SUBSET: ${{ matrix.test-subset }} + THEANO_FLAGS: floatX=${{ matrix.floatx }},gcc.cxxflags='-march=core2' + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@v2 + - name: Cache conda + uses: actions/cache@v1 + env: + # Increase this value to reset cache if conda-envs/environment-dev-py37.yml has not changed + CACHE_NUMBER: 0 + with: + path: ~/conda_pkgs_dir + key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ + hashFiles('conda-envs/environment-dev-py37.yml') }} + - uses: conda-incubator/setup-miniconda@v2 + with: + activate-environment: pymc3-dev-py37 + channel-priority: strict + environment-file: conda-envs/environment-dev-py37.yml + use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! + - run: | + conda activate pymc3-dev-py37 + python -m pytest -vv --cov=pymc3 --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET diff --git a/pymc3/distributions/posterior_predictive.py b/pymc3/distributions/posterior_predictive.py index 7a4ab43d30..8efec54af7 100644 --- a/pymc3/distributions/posterior_predictive.py +++ b/pymc3/distributions/posterior_predictive.py @@ -42,7 +42,7 @@ ) from ..exceptions import IncorrectArgumentsError from ..vartypes import theano_constant -from ..util import dataset_to_point_dict, chains_and_samples, get_var_name +from ..util import dataset_to_point_list, chains_and_samples, get_var_name # Failing tests: # test_mixture_random_shape::test_mixture_random_shape @@ -209,10 +209,10 @@ def fast_sample_posterior_predictive( if isinstance(trace, InferenceData): nchains, ndraws = chains_and_samples(trace) - trace = dataset_to_point_dict(trace.posterior) + trace = dataset_to_point_list(trace.posterior) elif isinstance(trace, Dataset): nchains, ndraws = chains_and_samples(trace) - trace = dataset_to_point_dict(trace) + trace = dataset_to_point_list(trace) elif isinstance(trace, MultiTrace): nchains = trace.nchains ndraws = len(trace) diff --git a/pymc3/memoize.py b/pymc3/memoize.py index 349d03188f..55a540ca8d 100644 --- a/pymc3/memoize.py +++ b/pymc3/memoize.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -import pickle +import dill import collections from .util import biwrap @@ -23,7 +23,16 @@ @biwrap def memoize(obj, bound=False): """ - An expensive memoizer that works with unhashables + Decorator to apply memoization to expensive functions. + It uses a custom `hashable` helper function to hash typically unhashable Python objects. + + Parameters + ---------- + obj : callable + the function to apply the caching to + bound : bool + indicates if the [obj] is a bound method (self as first argument) + For bound methods, the cache is kept in a `_cache` attribute on [self]. """ # this is declared not to be a bound method, so just attach new attr to obj if not bound: @@ -40,7 +49,7 @@ def memoizer(*args, **kwargs): key = (hashable(args[1:]), hashable(kwargs)) if not hasattr(args[0], "_cache"): setattr(args[0], "_cache", collections.defaultdict(dict)) - # do not add to cache regestry + # do not add to cache registry cache = getattr(args[0], "_cache")[obj.__name__] if key not in cache: cache[key] = obj(*args, **kwargs) @@ -75,19 +84,26 @@ def __setstate__(self, state): self.__dict__.update(state) -def hashable(a): +def hashable(a) -> int: """ - Turn some unhashable objects into hashable ones. + Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function. + Lists and tuples are hashed based on their elements. """ if isinstance(a, dict): - return hashable(tuple((hashable(a1), hashable(a2)) for a1, a2 in a.items())) + # first hash the keys and values with hashable + # then hash the tuple of int-tuples with the builtin + return hash(tuple((hashable(k), hashable(v)) for k, v in a.items())) + if isinstance(a, (tuple, list)): + # lists are mutable and not hashable by default + # for memoization, we need the hash to depend on the items + return hash(tuple(hashable(i) for i in a)) try: return hash(a) except TypeError: pass # Not hashable >>> try: - return hash(pickle.dumps(a)) + return hash(dill.dumps(a)) except Exception: if hasattr(a, "__dict__"): return hashable(a.__dict__) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 1dfa82c5d7..96b2b6df18 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -56,7 +56,7 @@ get_untransformed_name, is_transformed_name, get_default_varnames, - dataset_to_point_dict, + dataset_to_point_list, chains_and_samples, ) from .vartypes import discrete_types @@ -1642,9 +1642,9 @@ def sample_posterior_predictive( _trace: Union[MultiTrace, PointList] if isinstance(trace, InferenceData): - _trace = dataset_to_point_dict(trace.posterior) + _trace = dataset_to_point_list(trace.posterior) elif isinstance(trace, xarray.Dataset): - _trace = dataset_to_point_dict(trace) + _trace = dataset_to_point_list(trace) else: _trace = trace @@ -1780,10 +1780,10 @@ def sample_posterior_predictive_w( n_samples = [ trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces ] - traces = [dataset_to_point_dict(trace.posterior) for trace in traces] + traces = [dataset_to_point_list(trace.posterior) for trace in traces] elif isinstance(traces[0], xarray.Dataset): n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces] - traces = [dataset_to_point_dict(trace) for trace in traces] + traces = [dataset_to_point_list(trace) for trace in traces] else: n_samples = [len(i) * i.nchains for i in traces] diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index a789674095..2b9358787c 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -21,6 +21,7 @@ from scipy import linalg import numpy.random as nr import theano +import sys import pymc3 as pm from pymc3.distributions.dist_math import clipped_beta_rvs @@ -713,6 +714,10 @@ def test_half_flat(self): def test_binomial(self): pymc3_random_discrete(pm.Binomial, {"n": Nat, "p": Unit}, ref_rand=st.binom.rvs) + @pytest.mark.xfail( + sys.platform.startswith("win"), + reason="Known issue: https://github.com/pymc-devs/pymc3/pull/4269", + ) def test_beta_binomial(self): pymc3_random_discrete( pm.BetaBinomial, {"n": Nat, "alpha": Rplus, "beta": Rplus}, ref_rand=self._beta_bin diff --git a/pymc3/tests/test_memo.py b/pymc3/tests/test_memo.py index aa8fdb265c..4d1955a442 100644 --- a/pymc3/tests/test_memo.py +++ b/pymc3/tests/test_memo.py @@ -11,21 +11,57 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np +import pymc3 as pm -from pymc3.memoize import memoize +from pymc3 import memoize -def getmemo(): - @memoize - def f(a, b=("a")): - return str(a) + str(b) +def test_memo(): + def fun(inputs, suffix="_a"): + return str(inputs) + str(suffix) - return f + inputs = ["i1", "i2"] + assert fun(inputs) == "['i1', 'i2']_a" + assert fun(inputs, "_b") == "['i1', 'i2']_b" + funmem = memoize.memoize(fun) + assert hasattr(fun, "cache") + assert isinstance(fun.cache, dict) + assert len(fun.cache) == 0 + + # call the memoized function with a list input + # and check the size of the cache! + assert funmem(inputs) == "['i1', 'i2']_a" + assert funmem(inputs) == "['i1', 'i2']_a" + assert len(fun.cache) == 1 + assert funmem(inputs, "_b") == "['i1', 'i2']_b" + assert funmem(inputs, "_b") == "['i1', 'i2']_b" + assert len(fun.cache) == 2 + + # add items to the inputs list (the list instance remains identical !!) + inputs.append("i3") + assert funmem(inputs) == "['i1', 'i2', 'i3']_a" + assert funmem(inputs) == "['i1', 'i2', 'i3']_a" + assert len(fun.cache) == 3 -def test_memo(): - f = getmemo() - assert f("x", ["y", "z"]) == "x['y', 'z']" - assert f("x", ["a", "z"]) == "x['a', 'z']" - assert f("x", ["y", "z"]) == "x['y', 'z']" +def test_hashing_of_rv_tuples(): + obs = np.random.normal(-1, 0.1, size=10) + with pm.Model() as pmodel: + mu = pm.Normal("mu", 0, 1) + sd = pm.Gamma("sd", 1, 2) + dd = pm.DensityDist( + "dd", + pm.Normal.dist(mu, sd).logp, + random=pm.Normal.dist(mu, sd).random, + observed=obs, + ) + for freerv in [mu, sd, dd] + pmodel.free_RVs: + for structure in [ + freerv, + {"alpha": freerv, "omega": None}, + [freerv, []], + (freerv, []), + ]: + assert isinstance(memoize.hashable(structure), int) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index c8c6f13483..b2dbd309bc 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -31,7 +31,6 @@ import pytest -@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32") @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32") class TestSample(SeededTest): def setup_method(self): @@ -931,7 +930,6 @@ def test_shared(self): assert gen2["y"].shape == (draws, n2) def test_density_dist(self): - obs = np.random.normal(-1, 0.1, size=10) with pm.Model(): mu = pm.Normal("mu", 0, 1) diff --git a/pymc3/util.py b/pymc3/util.py index 310cf6a524..c18296b17f 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -15,6 +15,7 @@ import re import functools from typing import List, Dict, Tuple, Union +import warnings import numpy as np import xarray @@ -258,6 +259,14 @@ def enhanced(*args, **kwargs): # FIXME: this function is poorly named, because it returns a LIST of # points, not a dictionary of points. def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]: + warnings.warn( + "dataset_to_point_dict was renamed to dataset_to_point_list and will be removed!", + DeprecationWarning, + ) + return dataset_to_point_list(ds) + + +def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]: # grab posterior samples for each variable _samples: Dict[str, np.ndarray] = {vn: ds[vn].values for vn in ds.keys()} # make dicts