Skip to content

Commit 29d5b85

Browse files
committed
Allow the 'vars' argument to draws_pd to filter new columns
1 parent 535c4e6 commit 29d5b85

File tree

4 files changed

+23
-26
lines changed

4 files changed

+23
-26
lines changed

cmdstanpy/stanfit/gq.py

+6 -6
Original file line number | Diff line number | Diff line change
@@ -323,6 +323,8 @@ def draws_pd(
323323

324324
self._assemble_generated_quantities()
325325

326+
all_columns = ['chain__', 'iter__', 'draw__'] + list(self.column_names)
327+
326328
gq_cols: List[str] = []
327329
mcmc_vars: List[str] = []
328330
if vars is not None:
@@ -341,10 +343,12 @@ def draws_pd(
341343
info.start_idx : info.end_idx
342344
]
343345
)
346+
elif var in ['chain__', 'iter__', 'draw__']:
347+
gq_cols.append(var)
344348
else:
345349
raise ValueError('Unknown variable: {}'.format(var))
346350
else:
347-
gq_cols = list(self.column_names)
351+
gq_cols = all_columns
348352
vars_list = gq_cols
349353

350354
previous_draws_pd = self._previous_draws_pd(mcmc_vars, inc_warmup)
@@ -369,13 +373,9 @@ def draws_pd(
369373
)
370374
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)
371375

372-
vars_list = ['chain__', 'iter__', 'draw__'] + vars_list
373-
if gq_cols:
374-
gq_cols = ['chain__', 'iter__', 'draw__'] + gq_cols
375-
376376
draws_pd = pd.DataFrame(
377377
data=flatten_chains(draws),
378-
columns=['chain__', 'iter__', 'draw__'] + list(self.column_names),
378+
columns=all_columns,
379379
)
380380

381381
if inc_sample and mcmc_vars:

cmdstanpy/stanfit/mcmc.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -609,10 +609,12 @@ def draws_pd(
609609
cols.extend(
610610
self.column_names[info.start_idx : info.end_idx]
611611
)
612+
elif var in ['chain__', 'iter__', 'draw__']:
613+
cols.append(var)
612614
else:
613615
raise ValueError(f'Unknown variable: {var}')
614616
else:
615-
cols = list(self.column_names)
617+
cols = ['chain__', 'iter__', 'draw__'] + list(self.column_names)
616618

617619
draws = self.draws(inc_warmup=inc_warmup)
618620
# add long-form columns for chain, iteration, draw
@@ -634,8 +636,6 @@ def draws_pd(
634636
)
635637
draws = np.concatenate([chains_col, iter_col, draw_col, draws], axis=2)
636638

637-
cols = ['chain__', 'iter__', 'draw__'] + cols
638-
639639
return pd.DataFrame(
640640
data=flatten_chains(draws),
641641
columns=['chain__', 'iter__', 'draw__'] + list(self.column_names),

test/test_generate_quantities.py

+7 -5
Original file line number | Diff line number | Diff line change
@@ -85,9 +85,11 @@ def test_from_csv_files(caplog: pytest.LogCaptureFixture) -> None:
8585
- 3 # chain, iter, draw duplicates
8686
)
8787

88-
assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == (
89-
["chain__", "iter__", "draw__"] + column_names
90-
)
88+
assert list(bern_gqs.draws_pd(vars=['y_rep']).columns) == (column_names)
89+
90+
assert list(
91+
bern_gqs.draws_pd(vars=["chain__", "iter__", "draw__", 'y_rep']).columns
92+
) == (["chain__", "iter__", "draw__"] + column_names)
9193

9294

9395
def test_pd_xr_agreement():
@@ -315,9 +317,9 @@ def test_save_warmup(caplog: pytest.LogCaptureFixture) -> None:
315317
assert bern_gqs.draws_pd(inc_warmup=True).shape == (800, 13)
316318
assert bern_gqs.draws_pd(vars=['y_rep'], inc_warmup=False).shape == (
317319
400,
318-
13,
320+
10,
319321
)
320-
assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 13)
322+
assert bern_gqs.draws_pd(vars='y_rep', inc_warmup=False).shape == (400, 10)
321323

322324
theta = bern_gqs.stan_variable(var='theta')
323325
assert theta.shape == (400,)

test/test_sample.py

+7 -12
Original file line number | Diff line number | Diff line change
@@ -778,24 +778,19 @@ def test_validate_good_run() -> None:
778778
fit.runset.chains * fit.num_draws_sampling,
779779
len(fit.column_names) + 3,
780780
)
781-
assert fit.draws_pd(vars=['theta']).shape == (400, 4)
782-
assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 5)
783-
assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 5)
784-
assert fit.draws_pd(vars='theta').shape == (400, 4)
781+
assert fit.draws_pd(vars=['theta']).shape == (400, 1)
782+
assert fit.draws_pd(vars=['lp__', 'theta']).shape == (400, 2)
783+
assert fit.draws_pd(vars=['theta', 'lp__']).shape == (400, 2)
784+
assert fit.draws_pd(vars='theta').shape == (400, 1)
785785

786786
assert list(fit.draws_pd(vars=['theta', 'lp__']).columns) == [
787-
'chain__',
788-
'iter__',
789-
'draw__',
790787
'theta',
791788
'lp__',
792789
]
793-
assert list(fit.draws_pd(vars=['lp__', 'theta']).columns) == [
794-
'chain__',
795-
'iter__',
796-
'draw__',
790+
assert list(fit.draws_pd(vars=['lp__', 'theta', 'iter__']).columns) == [
797791
'lp__',
798792
'theta',
793+
'iter__',
799794
]
800795

801796
summary = fit.summary()
@@ -854,7 +849,7 @@ def test_validate_big_run() -> None:
854849
assert fit.step_size.shape == (2,)
855850
assert fit.metric.shape == (2, 2095)
856851
assert fit.draws().shape == (1000, 2, 2102)
857-
assert fit.draws_pd(vars=['phi']).shape == (2000, 2098)
852+
assert fit.draws_pd(vars=['phi']).shape == (2000, 2095)
858853
with raises_nested(ValueError, r'Unknown variable: gamma'):
859854
fit.draws_pd(vars=['gamma'])
860855

0 commit comments

Comments
 (0)