diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index c878d8ce..116e11b4 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -29,13 +29,7 @@ import pandas as pd from tqdm.auto import tqdm -from cmdstanpy import ( - _CMDSTAN_REFRESH, - _CMDSTAN_SAMPLING, - _CMDSTAN_WARMUP, - _TMPDIR, - compilation, -) +from cmdstanpy import _CMDSTAN_SAMPLING, _CMDSTAN_WARMUP, _TMPDIR, compilation from cmdstanpy.cmdstan_args import ( CmdStanArgs, GenerateQuantitiesArgs, @@ -1069,9 +1063,6 @@ def sample( iter_total += _CMDSTAN_SAMPLING else: iter_total += iter_sampling - if refresh is None: - refresh = _CMDSTAN_REFRESH - iter_total = iter_total // refresh + 2 progress_hook = self._wrap_sampler_progress_hook( chain_ids=chain_ids, @@ -2138,13 +2129,12 @@ def _wrap_sampler_progress_hook( process, "Chain [id] Iteration" for multi-chain processing. For the latter, manage array of pbars, update accordingly. """ - pat = re.compile(r'Chain \[(\d*)\] (Iteration.*)') + chain_pat = re.compile(r'(Chain \[(\d+)\] )?Iteration:\s+(\d+)') pbars: Dict[int, tqdm] = { chain_id: tqdm( total=total, - bar_format="{desc} |{bar}| {elapsed} {postfix[0][value]}", - postfix=[{"value": "Status"}], desc=f'chain {chain_id}', + postfix='(Warmup)', colour='yellow', ) for chain_id in chain_ids @@ -2153,23 +2143,19 @@ def _wrap_sampler_progress_hook( def progress_hook(line: str, idx: int) -> None: if line == "Done": for pbar in pbars.values(): - pbar.postfix[0]["value"] = 'Sampling completed' + pbar.set_postfix_str('(Sampling completed)') pbar.update(total - pbar.n) pbar.close() - else: - match = pat.match(line) - if match: - idx = int(match.group(1)) - mline = match.group(2).strip() - elif line.startswith("Iteration"): - mline = line - idx = chain_ids[idx] - else: - return - if 'Sampling' in mline: - pbars[idx].colour = 'blue' - pbars[idx].update(1) - pbars[idx].postfix[0]["value"] = mline + elif (match := chain_pat.match(line)) is not None: + idx = int(match.group(2) or chain_ids[idx]) + current_iter = int(match.group(3)) + + pbar = pbars[idx] + if pbar.colour == 'yellow' and 'Sampling' in line: + pbar.colour = 'blue' + pbar.set_postfix_str('(Sampling)') + + pbar.update(current_iter - pbar.n) return progress_hook @@ -2225,8 +2211,7 @@ def diagnose( Gradients are evaluated in the unconstrained space. """ - with temp_single_json(data) as _data, \ - temp_single_json(inits) as _inits: + with temp_single_json(data) as _data, temp_single_json(inits) as _inits: cmd = [ str(self.exe_file), "diagnose",