Skip to content

Commit 41f9633

Browse files
committed
changes per code review
1 parent eeea954 commit 41f9633

File tree

4 files changed

+63
-45
lines changed

4 files changed

+63
-45
lines changed

cmdstanpy/stanfit/mcmc.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(
9797
self._max_treedepths: np.ndarray = np.zeros(
9898
self.runset.chains, dtype=int
9999
)
100-
self._chain_timing: List[Optional[Dict[str, float]]] = []
100+
self._chain_time: List[Dict[str, float]] = []
101101

102102
# info from CSV header and initial and final comment blocks
103103
config = self._validate_csv_files()
@@ -242,12 +242,12 @@ def max_treedepths(self) -> Optional[np.ndarray]:
242242
return self._max_treedepths if not self._is_fixed_param else None
243243

244244
@property
245-
def timing(self) -> List[Optional[Dict[str, float]]]:
245+
def time(self) -> List[Dict[str, float]]:
246246
"""
247-
List of per-chain timing info scraped from CSV file.
247+
List of per-chain time info scraped from CSV file.
248248
Each chain has dict with keys "warmup", "sampling", "total".
249249
"""
250-
return self._chain_timing
250+
return self._chain_time
251251

252252
def draws(
253253
self, *, inc_warmup: bool = False, concat_chains: bool = False
@@ -310,7 +310,7 @@ def _validate_csv_files(self) -> Dict[str, Any]:
310310
save_warmup=self._save_warmup,
311311
thin=self._thin,
312312
)
313-
self._chain_timing.append(dzero.get("timing"))
313+
self._chain_time.append(dzero.get("time"))
314314
if not self._is_fixed_param:
315315
self._divergences[i] = dzero['ct_divergences']
316316
self._max_treedepths[i] = dzero['ct_max_treedepth']
@@ -323,7 +323,7 @@ def _validate_csv_files(self) -> Dict[str, Any]:
323323
save_warmup=self._save_warmup,
324324
thin=self._thin,
325325
)
326-
self._chain_timing.append(drest.get("timing"))
326+
self._chain_time.append(drest.get("time"))
327327
for key in dzero:
328328
# check args that matter for parsing, plus name, version
329329
if (

cmdstanpy/utils/stancsv.py

+28-26
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def scan_sampler_csv(path: str, is_fixed_param: bool = False) -> Dict[str, Any]:
7979
lineno = scan_warmup_iters(fd, dict, lineno)
8080
lineno = scan_hmc_params(fd, dict, lineno)
8181
lineno = scan_sampling_iters(fd, dict, lineno, is_fixed_param)
82-
lineno = scan_timing(fd, dict, lineno)
82+
lineno = scan_time(fd, dict, lineno)
8383
except ValueError as e:
8484
raise ValueError("Error in reading csv file: " + path) from e
8585
return dict
@@ -381,24 +381,24 @@ def scan_sampling_iters(
381381
config_dict['ct_max_treedepth'] = ct_max_treedepth
382382
return lineno
383383

384-
def scan_timing(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int:
384+
def scan_time(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int:
385385
"""
386-
Scan timing information from the trailing comment lines in a Stan CSV file.
386+
Scan time information from the trailing comment lines in a Stan CSV file.
387387
388388
# Elapsed Time: 0.001332 seconds (Warm-up)
389389
# 0.000249 seconds (Sampling)
390390
# 0.001581 seconds (Total)
391391
392392
393-
It extracts the time values and saves them in the config_dict under the key 'timing'
393+
It extracts the time values and saves them in the config_dict under the key 'time'
394394
as a dictionary with keys 'warmup', 'sampling', and 'total'.
395-
Returns the updated line number after reading the timing info.
395+
Returns the updated line number after reading the time info.
396396
397-
:param fd: Open file descriptor positioned at the timing section.
398-
:param config_dict: Dictionary to which the timing info is added.
397+
:param fd: Open file descriptor at comment row following all sample data.
398+
:param config_dict: Dictionary to which the time info is added.
399399
:param lineno: Current line number
400400
"""
401-
timing = {}
401+
time = {}
402402
keys = ['warmup', 'sampling', 'total']
403403
while True:
404404
pos = fd.tell()
@@ -414,27 +414,29 @@ def scan_timing(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int:
414414
content = stripped.lstrip('#').strip()
415415
if not content:
416416
continue
417-
tokens = content.lower().split()
418-
if 'elapsed' in tokens[0]:
417+
tokens = content.split()
418+
if len(tokens) < 3:
419+
raise ValueError(f"Invalid time at line {lineno}: {content}")
420+
if 'Warm-up' in content:
419421
key = 'warmup'
420-
try:
421-
t = float(tokens[2])
422-
except ValueError:
423-
raise ValueError(f"Invalid timing value at line {lineno}: {content}")
422+
time_str = tokens[2]
423+
elif 'Sampling' in content:
424+
key = 'sampling'
425+
time_str = tokens[0]
426+
elif 'Total' in content:
427+
key = 'total'
428+
time_str = tokens[0]
424429
else:
425-
if 'sampling' in tokens[2]:
426-
key = 'sampling'
427-
elif 'total' in tokens[2]:
428-
key = 'total'
429-
try:
430-
t = float(tokens[0])
431-
except ValueError:
432-
raise ValueError(f"Invalid timing value at line {lineno}: {content}")
433-
timing[key] = t
434-
if not all(key in timing for key in keys):
435-
raise ValueError(f"Invalid timing, stopped at {lineno}")
430+
raise ValueError(f"Invalid time at line {lineno}: {content}")
431+
try:
432+
t = float(time_str)
433+
except ValueError:
434+
raise ValueError(f"Invalid time value at line {lineno}: {content}")
435+
time[key] = t
436+
if not all(key in time for key in keys):
437+
raise ValueError(f"Invalid time, stopped at {lineno}")
436438

437-
config_dict['timing'] = timing
439+
config_dict['time'] = time
438440
return lineno
439441

440442

test/test_sample.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1714,11 +1714,11 @@ def test_metadata() -> None:
17141714
assert fit.column_names == col_names
17151715
assert fit.metric_type == 'diag_e'
17161716

1717-
assert len(fit.timing) == 4
1717+
assert len(fit.time) == 4
17181718
for i in range(4):
1719-
assert 'warmup' in fit.timing[i].keys()
1720-
assert 'sampling' in fit.timing[i].keys()
1721-
assert 'total' in fit.timing[i].keys()
1719+
assert 'warmup' in fit.time[i].keys()
1720+
assert 'sampling' in fit.time[i].keys()
1721+
assert 'total' in fit.time[i].keys()
17221722

17231723
assert fit.metadata.cmdstan_config['num_samples'] == 100
17241724
assert fit.metadata.cmdstan_config['thin'] == 1

test/test_utils.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,8 @@ def test_munge_varnames() -> None:
700700
var = 'y.2.3:1.2:5:6'
701701
assert stancsv.munge_varname(var) == 'y[2,3].1[2].5.6'
702702

703-
def test_scan_timing_normal() -> None:
703+
704+
def test_scan_time_normal() -> None:
704705
csv_content = (
705706
"# Elapsed Time: 0.005 seconds (Warm-up)\n"
706707
"# 0 seconds (Sampling)\n"
@@ -709,12 +710,13 @@ def test_scan_timing_normal() -> None:
709710
fd = io.StringIO(csv_content)
710711
config_dict = {}
711712
start_line = 0
712-
final_line = stancsv.scan_timing(fd, config_dict, start_line)
713+
final_line = stancsv.scan_time(fd, config_dict, start_line)
713714
assert final_line == 3
714715
expected = {'warmup': 0.005, 'sampling': 0.0, 'total': 0.005}
715-
assert config_dict.get('timing') == expected
716+
assert config_dict.get('time') == expected
717+
716718

717-
def test_scan_timing_no_timing() -> None:
719+
def test_scan_time_no_timing() -> None:
718720
csv_content = (
719721
"# merrily we roll along\n"
720722
"# roll along\n"
@@ -723,11 +725,11 @@ def test_scan_timing_no_timing() -> None:
723725
fd = io.StringIO(csv_content)
724726
config_dict = {}
725727
start_line = 0
726-
with pytest.raises(ValueError, match="Invalid timing"):
727-
stancsv.scan_timing(fd, config_dict, start_line)
728+
with pytest.raises(ValueError, match="Invalid time"):
729+
stancsv.scan_time(fd, config_dict, start_line)
728730

729731

730-
def test_scan_timing_invalid_value() -> None:
732+
def test_scan_time_invalid_value() -> None:
731733
csv_content = (
732734
"# Elapsed Time: abc seconds (Warm-up)\n"
733735
"# 0.200 seconds (Sampling)\n"
@@ -736,5 +738,19 @@ def test_scan_timing_invalid_value() -> None:
736738
fd = io.StringIO(csv_content)
737739
config_dict = {}
738740
start_line = 0
739-
with pytest.raises(ValueError, match="Invalid timing"):
740-
stancsv.scan_timing(fd, config_dict, start_line)
741+
with pytest.raises(ValueError, match="Invalid time"):
742+
stancsv.scan_time(fd, config_dict, start_line)
743+
744+
745+
def test_scan_time_invalid_string() -> None:
746+
csv_content = (
747+
"# Elapsed Time: 0.22 seconds (foo)\n"
748+
"# 0.200 seconds (Sampling)\n"
749+
"# 0.300 seconds (Total)\n"
750+
)
751+
fd = io.StringIO(csv_content)
752+
config_dict = {}
753+
start_line = 0
754+
with pytest.raises(ValueError, match="Invalid time"):
755+
stancsv.scan_time(fd, config_dict, start_line)
756+

0 commit comments

Comments
 (0)