Skip to content

Commit 7c8c41e

Browse files
rpgoldmantwiecki
authored andcommitted
Check chain lengths when merging.
Raise ValueError for mismatched chains. Added test to verify.
1 parent d9a3167 commit 7c8c41e

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

Diff for: pymc3/backends/base.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
import itertools as itl
77
import logging
8+
from typing import List
89

910
import numpy as np
1011
import warnings
@@ -92,7 +93,8 @@ def _set_sampler_vars(self, sampler_vars):
9293

9394
self.sampler_vars = sampler_vars
9495

95-
def setup(self, draws, chain, sampler_vars=None):
96+
# pylint: disable=unused-argument
97+
def setup(self, draws, chain, sampler_vars=None) -> None:
9698
"""Perform chain-specific setup.
9799
98100
Parameters
@@ -542,7 +544,7 @@ def points(self, chains=None):
542544
return itl.chain.from_iterable(self._straces[chain] for chain in chains)
543545

544546

545-
def merge_traces(mtraces):
547+
def merge_traces(mtraces: List[MultiTrace]) -> MultiTrace:
546548
"""Merge MultiTrace objects.
547549
548550
Parameters
@@ -552,17 +554,26 @@ def merge_traces(mtraces):
552554
553555
Raises
554556
------
555-
A ValueError is raised if any traces have overlapping chain numbers.
557+
A ValueError is raised if any traces have overlapping chain numbers,
558+
or if chains are of different lengths.
556559
557560
Returns
558561
-------
559562
A MultiTrace instance with merged chains
560563
"""
564+
if len(mtraces) == 0:
565+
raise ValueError("Cannot merge an empty set of traces.")
561566
base_mtrace = mtraces[0]
567+
chain_len = len(base_mtrace)
568+
# check base trace
569+
if any((len(st) != chain_len for _, st in base_mtrace._straces.items())):
570+
raise ValueError("Chains are of different lengths.")
562571
for new_mtrace in mtraces[1:]:
563572
for new_chain, strace in new_mtrace._straces.items():
564573
if new_chain in base_mtrace._straces:
565574
raise ValueError("Chains are not unique.")
575+
if len(strace) != chain_len:
576+
raise ValueError("Chains are of different lengths.")
566577
base_mtrace._straces[new_chain] = strace
567578
base_mtrace._report = merge_reports([trace.report for trace in mtraces])
568579
return base_mtrace

Diff for: pymc3/tests/test_ndarray_backend.py

+24
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,30 @@ def test_multitrace_nonunique(self):
112112
with pytest.raises(ValueError):
113113
base.MultiTrace([self.strace0, self.strace1])
114114

115+
def test_merge_traces_no_traces(self):
116+
with pytest.raises(ValueError):
117+
base.merge_traces([])
118+
119+
def test_merge_traces_diff_lengths(self):
120+
with self.model:
121+
strace0 = self.backend(self.name)
122+
strace0.setup(self.draws, 1)
123+
for i in range(self.draws):
124+
strace0.record(self.test_point)
125+
strace0.close()
126+
mtrace0 = base.MultiTrace([self.strace0])
127+
128+
with self.model:
129+
strace1 = self.backend(self.name)
130+
strace1.setup(2 * self.draws, 1)
131+
for i in range(2 * self.draws):
132+
strace1.record(self.test_point)
133+
strace1.close()
134+
mtrace1 = base.MultiTrace([strace1])
135+
136+
with pytest.raises(ValueError):
137+
base.merge_traces([mtrace0, mtrace1])
138+
115139
def test_merge_traces_nonunique(self):
116140
mtrace0 = base.MultiTrace([self.strace0])
117141
mtrace1 = base.MultiTrace([self.strace1])

0 commit comments

Comments
 (0)