diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py
index 03e64f94c1..950c6b8a9b 100644
--- a/pymc/smc/sampling.py
+++ b/pymc/smc/sampling.py
@@ -345,7 +345,8 @@ def _sample_smc_int(
     while smc.beta < 1:
         smc.update_beta_and_weights()
-        progress_dict[task_id] = {"stage": stage, "beta": smc.beta}
+        # Index by chain because task_id is None if no progressbar is present
+        progress_dict[chain] = {"stage": stage, "beta": smc.beta, "task_id": task_id}
         smc.resample()
         smc.tune()
@@ -378,6 +379,7 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
         disable=not progressbar,
     ) as progress:
         futures = []  # keep track of the jobs
+        _log = logging.getLogger(__name__)
         with multiprocessing.Manager() as manager:
             # this is the key - we share some state between our
             # main process and our worker functions
@@ -391,6 +393,8 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
                 for c in range(chains):  # iterate over the jobs we need to run
                     # set visible false so we don't have a lot of bars all at once:
                     task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0")
+                    if not progressbar:
+                        _log.info(f"Queueing Chain {c} Stage: 0 Beta: 0")
                     futures.append(
                         executor.submit(
                             _sample_smc_int,
@@ -406,17 +410,26 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
                 # monitor the progress:
                 done = []
                 remaining = futures
+                previous = {c: {} for c in range(chains)}

                 while len(remaining) > 0:
                     finished, remaining = wait(remaining, timeout=0.1)
                     done.extend(finished)
-                    for task_id, update_data in _progress.items():
+                    for chain, update_data in _progress.items():
                         stage = update_data["stage"]
                         beta = update_data["beta"]
+                        task_id = update_data["task_id"]
+                        # update the progress bar for this task:
                         progress.update(
                             status=f"Stage: {stage} Beta: {beta:.3f}",
                             task_id=task_id,
                             refresh=True,
                         )
+                        # use the logger if there is no progress bar and data has changed:
+                        if not progressbar:
+                            # only log if the stage has changed
+                            if previous[chain].get("stage", -1) != stage:
+                                _log.info(f"Chain: {chain} Stage: {stage} Beta: {beta:.3f}")
+                                previous[chain] = {"stage": stage, "beta": beta}

     return tuple(cloudpickle.loads(r.result()) for r in done)
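
For context, here is a minimal, self-contained sketch of the pattern the hunks above rely on: each worker writes its latest state into a multiprocessing.Manager dict keyed by chain, while the parent process polls the outstanding futures and, when no progress bar is shown, logs a line only when a chain's stage changes. This is not pymc code; fake_chain, N_CHAINS, and the stage loop are invented purely for illustration.

# Illustrative sketch only -- assumed names, not pymc's API.
import logging
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, wait

logging.basicConfig(level=logging.INFO)
_log = logging.getLogger(__name__)

N_CHAINS = 4  # hypothetical number of chains


def fake_chain(chain, progress_dict):
    # Stand-in for a real sampler loop: report beta rising from 0 to 1 in stages.
    beta = 0.0
    for stage in range(1, 5):
        beta = min(1.0, beta + 0.25)
        progress_dict[chain] = {"stage": stage, "beta": beta}
    return chain


def run(progressbar=False):
    previous = {c: {} for c in range(N_CHAINS)}
    with multiprocessing.Manager() as manager:
        progress_dict = manager.dict()  # shared between parent and workers
        with ProcessPoolExecutor(max_workers=2) as executor:
            futures = [executor.submit(fake_chain, c, progress_dict) for c in range(N_CHAINS)]
            done, remaining = [], futures
            while remaining:
                finished, remaining = wait(remaining, timeout=0.1)
                done.extend(finished)
                for chain, update in progress_dict.items():
                    # Log only when a chain's stage changes, mirroring the diff's throttling.
                    if not progressbar and previous[chain].get("stage") != update["stage"]:
                        _log.info(f"Chain: {chain} Stage: {update['stage']} Beta: {update['beta']:.3f}")
                        previous[chain] = dict(update)
    return [f.result() for f in done]


if __name__ == "__main__":
    run()

Running run(progressbar=False) prints one INFO line per chain per stage, which is the same throttled fallback the diff adds for the no-progress-bar case.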