5
5
"""
6
6
import itertools as itl
7
7
import logging
8
+ from typing import List
8
9
9
10
import numpy as np
10
11
import warnings
@@ -92,7 +93,8 @@ def _set_sampler_vars(self, sampler_vars):
92
93
93
94
self .sampler_vars = sampler_vars
94
95
95
- def setup (self , draws , chain , sampler_vars = None ):
96
+ # pylint: disable=unused-argument
97
+ def setup (self , draws , chain , sampler_vars = None ) -> None :
96
98
"""Perform chain-specific setup.
97
99
98
100
Parameters
@@ -542,7 +544,7 @@ def points(self, chains=None):
542
544
return itl .chain .from_iterable (self ._straces [chain ] for chain in chains )
543
545
544
546
545
- def merge_traces (mtraces ) :
547
+ def merge_traces (mtraces : List [ MultiTrace ]) -> MultiTrace :
546
548
"""Merge MultiTrace objects.
547
549
548
550
Parameters
@@ -552,17 +554,26 @@ def merge_traces(mtraces):
552
554
553
555
Raises
554
556
------
555
- A ValueError is raised if any traces have overlapping chain numbers.
557
+ A ValueError is raised if any traces have overlapping chain numbers,
558
+ or if chains are of different lengths.
556
559
557
560
Returns
558
561
-------
559
562
A MultiTrace instance with merged chains
560
563
"""
564
+ if len (mtraces ) == 0 :
565
+ raise ValueError ("Cannot merge an empty set of traces." )
561
566
base_mtrace = mtraces [0 ]
567
+ chain_len = len (base_mtrace )
568
+ # check base trace
569
+ if any ((len (st ) != chain_len for _ , st in base_mtrace ._straces .items ())):
570
+ raise ValueError ("Chains are of different lengths." )
562
571
for new_mtrace in mtraces [1 :]:
563
572
for new_chain , strace in new_mtrace ._straces .items ():
564
573
if new_chain in base_mtrace ._straces :
565
574
raise ValueError ("Chains are not unique." )
575
+ if len (strace ) != chain_len :
576
+ raise ValueError ("Chains are of different lengths." )
566
577
base_mtrace ._straces [new_chain ] = strace
567
578
base_mtrace ._report = merge_reports ([trace .report for trace in mtraces ])
568
579
return base_mtrace
0 commit comments