Skip to content

Commit 797fa6e

Browse files
authored
Merge pull request #779 from stan-dev/issue/778-get-timing
add timing information for sampler outputs
2 parents 650d2bb + 178f1ad commit 797fa6e

File tree

4 files changed

+133
-1
lines changed

4 files changed

+133
-1
lines changed

cmdstanpy/stanfit/mcmc.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,9 @@ def __init__(
9797
self._max_treedepths: np.ndarray = np.zeros(
9898
self.runset.chains, dtype=int
9999
)
100+
self._chain_time: List[Dict[str, float]] = []
100101

101-
# info from CSV initial comments and header
102+
# info from CSV header and initial and final comment blocks
102103
config = self._validate_csv_files()
103104
self._metadata: InferenceMetadata = InferenceMetadata(config)
104105
if not self._is_fixed_param:
@@ -240,6 +241,14 @@ def max_treedepths(self) -> Optional[np.ndarray]:
240241
"""
241242
return self._max_treedepths if not self._is_fixed_param else None
242243

244+
@property
def time(self) -> List[Dict[str, float]]:
    """
    Per-chain timing information scraped from the Stan CSV files.

    The list holds one dictionary per chain, each with the keys
    "warmup", "sampling", and "total".
    """
    chain_times = self._chain_time
    return chain_times
251+
243252
def draws(
244253
self, *, inc_warmup: bool = False, concat_chains: bool = False
245254
) -> np.ndarray:
@@ -301,6 +310,7 @@ def _validate_csv_files(self) -> Dict[str, Any]:
301310
save_warmup=self._save_warmup,
302311
thin=self._thin,
303312
)
313+
self._chain_time.append(dzero['time']) # type: ignore
304314
if not self._is_fixed_param:
305315
self._divergences[i] = dzero['ct_divergences']
306316
self._max_treedepths[i] = dzero['ct_max_treedepth']
@@ -313,6 +323,7 @@ def _validate_csv_files(self) -> Dict[str, Any]:
313323
save_warmup=self._save_warmup,
314324
thin=self._thin,
315325
)
326+
self._chain_time.append(drest['time']) # type: ignore
316327
for key in dzero:
317328
# check args that matter for parsing, plus name, version
318329
if (

cmdstanpy/utils/stancsv.py

+61
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
7979
lineno = scan_warmup_iters(fd, dict, lineno)
8080
lineno = scan_hmc_params(fd, dict, lineno)
8181
lineno = scan_sampling_iters(fd, dict, lineno, is_fixed_param)
82+
lineno = scan_time(fd, dict, lineno)
8283
except ValueError as e:
8384
raise ValueError("Error in reading csv file: " + path) from e
8485
return dict
@@ -381,6 +382,66 @@ def scan_sampling_iters(
381382
return lineno
382383

383384

385+
def scan_time(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int:
    """
    Scan time information from the trailing comment lines in a Stan CSV file.

    The expected trailer looks like::

        # Elapsed Time: 0.001332 seconds (Warm-up)
        #               0.000249 seconds (Sampling)
        #               0.001581 seconds (Total)

    It extracts the time values and saves them in the config_dict: key 'time',
    value a dictionary with keys 'warmup', 'sampling', and 'total'.
    Comment lines that are empty after the leading '#' are skipped.
    Scanning stops at end of file or at the first non-comment line, which is
    left unread (the file position is rewound to its start).

    :param fd: Open file descriptor at comment row following all sample data.
    :param config_dict: Dictionary to which the time info is added.
    :param lineno: Current line number
    :return: The line number of the last line consumed.
    :raises ValueError: If a non-empty comment line is not a recognizable
        timing line, if a time value is not a finite non-negative number,
        or if any of the three timing entries is missing.
    """
    # (label in the file, key in the result, index of the token holding the
    # value).  Only the warm-up line carries the "Elapsed Time:" prefix, so
    # its value is the third token; on the other lines it is the first.
    fields = (
        ('Warm-up', 'warmup', 2),
        ('Sampling', 'sampling', 0),
        ('Total', 'total', 0),
    )
    timing: Dict[str, float] = {}
    while True:
        pos = fd.tell()
        line = fd.readline()
        if not line:
            break
        stripped = line.strip()
        if not stripped.startswith('#'):
            # Not part of the trailer: un-read the line for the caller.
            fd.seek(pos)
            break
        lineno += 1
        content = stripped.lstrip('#').strip()
        if not content:
            continue
        tokens = content.split()
        match = next(
            ((key, idx) for label, key, idx in fields if label in content),
            None,
        )
        if len(tokens) < 3 or match is None:
            raise ValueError(f"Invalid time at line {lineno}: {content}")
        key, idx = match
        try:
            seconds = float(tokens[idx])
        except ValueError as e:
            raise ValueError(
                f"Invalid time at line {lineno}: {content}"
            ) from e
        # float() happily parses "nan", "inf" and negative numbers, none of
        # which is a meaningful elapsed time; the chained comparison is also
        # False for NaN.
        if not 0.0 <= seconds < float('inf'):
            raise ValueError(f"Invalid time at line {lineno}: {content}")
        timing[key] = seconds

    if any(key not in timing for _, key, _ in fields):
        raise ValueError(f"Invalid time, stopped at {lineno}")

    config_dict['time'] = timing
    return lineno
443+
444+
384445
def read_metric(path: str) -> List[int]:
385446
"""
386447
Read metric file in JSON or Rdump format.

test/test_sample.py

+6
Original file line numberDiff line numberDiff line change
@@ -1714,6 +1714,12 @@ def test_metadata() -> None:
17141714
assert fit.column_names == col_names
17151715
assert fit.metric_type == 'diag_e'
17161716

1717+
assert len(fit.time) == 4
1718+
for i in range(4):
1719+
assert 'warmup' in fit.time[i].keys()
1720+
assert 'sampling' in fit.time[i].keys()
1721+
assert 'total' in fit.time[i].keys()
1722+
17171723
assert fit.metadata.cmdstan_config['num_samples'] == 100
17181724
assert fit.metadata.cmdstan_config['thin'] == 1
17191725
assert fit.metadata.cmdstan_config['algorithm'] == 'hmc'

test/test_utils.py

+54
Original file line numberDiff line numberDiff line change
@@ -699,3 +699,57 @@ def test_munge_varnames() -> None:
699699

700700
var = 'y.2.3:1.2:5:6'
701701
assert stancsv.munge_varname(var) == 'y[2,3].1[2].5.6'
702+
703+
704+
def test_scan_time_normal() -> None:
    """A well-formed timing trailer is parsed into the 'time' entry."""
    trailer = (
        "# Elapsed Time: 0.005 seconds (Warm-up)\n"
        "# 0 seconds (Sampling)\n"
        "# 0.005 seconds (Total)\n"
    )
    parsed = {}
    last_line = stancsv.scan_time(io.StringIO(trailer), parsed, 0)
    # All three comment lines are consumed.
    assert last_line == 3
    assert parsed.get('time') == {
        'warmup': 0.005,
        'sampling': 0.0,
        'total': 0.005,
    }
717+
718+
719+
def test_scan_time_no_timing() -> None:
    """Trailing comments that carry no timing info are rejected."""
    trailer = (
        "# merrily we roll along\n"
        "# roll along\n"
        "# very merrily we roll along\n"
    )
    stream = io.StringIO(trailer)
    parsed = {}
    with pytest.raises(ValueError, match="Invalid time"):
        stancsv.scan_time(stream, parsed, 0)
730+
731+
732+
def test_scan_time_invalid_value() -> None:
    """A non-numeric time value on a timing line is rejected."""
    trailer = (
        "# Elapsed Time: abc seconds (Warm-up)\n"
        "# 0.200 seconds (Sampling)\n"
        "# 0.300 seconds (Total)\n"
    )
    stream = io.StringIO(trailer)
    parsed = {}
    with pytest.raises(ValueError, match="Invalid time"):
        stancsv.scan_time(stream, parsed, 0)
743+
744+
745+
def test_scan_time_invalid_string() -> None:
    """A timing line with an unrecognized phase label is rejected."""
    trailer = (
        "# Elapsed Time: 0.22 seconds (foo)\n"
        "# 0.200 seconds (Sampling)\n"
        "# 0.300 seconds (Total)\n"
    )
    stream = io.StringIO(trailer)
    parsed = {}
    with pytest.raises(ValueError, match="Invalid time"):
        stancsv.scan_time(stream, parsed, 0)

0 commit comments

Comments
 (0)