Skip to content

Commit 0233065

Browse files
committed
Summarize model as rich table
1 parent 2accca9 commit 0233065

File tree

3 files changed

+278
-0
lines changed

3 files changed

+278
-0
lines changed

docs/api_reference.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,12 @@ Model Transforms
7474

7575
autoreparam.vip_reparametrize
7676
autoreparam.VIP
77+
78+
79+
Printing
80+
========
81+
.. currentmodule:: pymc_experimental.printing
82+
.. autosummary::
83+
:toctree: generated/
84+
85+
model_table

pymc_experimental/printing.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import itertools
2+
3+
import numpy as np
4+
5+
from pymc import Model
6+
from pymc.printing import str_for_dist, str_for_potential_or_deterministic
7+
from pytensor.compile.sharedvalue import SharedVariable
8+
from pytensor.graph.type import Constant
9+
from rich.box import SIMPLE_HEAD
10+
from rich.table import Table
11+
12+
13+
def _extract_value(var: SharedVariable | Constant) -> np.ndarray:
    """Return the concrete value stored in a shared variable or a constant.

    Parameters
    ----------
    var : SharedVariable | Constant
        The variable whose stored value is wanted.

    Returns
    -------
    np.ndarray
        ``var.get_value(borrow=True)`` for a shared variable, otherwise ``var.data``.
    """
    # Anything that is not a SharedVariable is read through its `.data` attribute.
    return var.get_value(borrow=True) if isinstance(var, SharedVariable) else var.data
18+
19+
20+
def _truncate_call_expr(expr: str, max_len: int) -> str:
    """Shorten an ``f(a, b, c)`` style expression to roughly ``max_len`` characters.

    Arguments are kept whole, in order, until the running length (each argument
    plus its ``", "`` separator) exceeds ``max_len``; the remainder is replaced by
    ``...``. When not even the first argument fits the result is ``f(...)``.
    Expressions that already fit are returned unchanged.
    """
    if len(expr) <= max_len:
        return expr

    # Drop the leading "f(" and the trailing ")" to get at the argument list.
    args = expr[2:-1].split(", ")
    used = 0
    n_kept = 0
    for n_kept, arg in enumerate(args):
        used += len(arg) + 2
        if used > max_len:
            break
    # Joining through a list avoids a dangling ", " when no argument is kept.
    return f"f({', '.join([*args[:n_kept], '...'])})"


def _count_free_parameters(model: Model) -> int:
    """Return the total number of scalar entries across the model's free RVs."""
    rv_shapes = model.eval_rv_shapes()
    # Built-in sum over ints keeps the result an int even when there are no free RVs
    # (np.sum of an empty list would give the float 0.0).
    return sum(int(np.prod(rv_shapes[free_rv.name])) for free_rv in model.free_RVs)


def model_table(
    model: Model,
    split_groups: bool = True,
    truncate_deterministic: int | None = None,
    parameter_count: bool = True,
) -> Table:
    """Create a rich table with a summary of the model's variables and their expressions.

    Parameters
    ----------
    model : Model
        The PyMC model to summarize.
    split_groups : bool
        If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs)
        will be separated by a section.
    truncate_deterministic : int | None
        If not None, truncate the expression of deterministic variables that go beyond this length.
        Inputs are dropped from the end and replaced by ``...``; if not even the first
        input fits, the expression is shown as ``f(...)``.
    parameter_count : bool
        If True, add a row with the total number of parameters in the model.
        The row follows the free random variables when ``split_groups`` is True,
        and is the last row of the table otherwise.

    Returns
    -------
    Table
        A rich table with the model's variables, their expressions and dims.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        import pymc as pm

        from pymc_experimental.printing import model_table

        coords = {"subject": range(20), "param": ["a", "b"]}
        with pm.Model(coords=coords) as m:
            x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
            y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")

            beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
            mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
            sigma = pm.HalfNormal("sigma", sigma=1)

            y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")

        table = model_table(m)
        table  # Displays the following table in an interactive environment
        '''
         Variable  Expression         Dimensions
        ─────────────────────────────────────────────────────
              x =  Data               subject[20] × param[2]
              y =  Data               subject[20]

           beta ~  Normal(0, 1)       param[2]
          sigma ~  HalfNormal(0, 1)
                                      Parameter count = 3

             mu =  f(beta)            subject[20]

          y_obs ~  Normal(mu, sigma)  subject[20]
        '''

    Output can be explicitly rendered in a rich console or exported to text, html or svg.

    .. code-block:: python

        from rich.console import Console

        console = Console(record=True)
        console.print(table)
        text_export = console.export_text()
        html_export = console.export_html()
        svg_export = console.export_svg()

    """
    table = Table(
        show_header=True,
        show_edge=False,
        box=SIMPLE_HEAD,
        highlight=False,
        collapse_padding=True,
    )
    table.add_column("Variable", justify="right")
    table.add_column("Expression", justify="left")
    table.add_column("Dimensions")

    dim_sizes = {k: _extract_value(v) for k, v in model.dim_lengths.items()}

    # Read each group once; the same objects are reused for membership checks below.
    data_vars = model.data_vars
    free_rvs = model.free_RVs
    deterministics = model.deterministics
    potentials = model.potentials

    groups = (data_vars, free_rvs, deterministics, potentials, model.observed_RVs)
    if not split_groups:
        # A single chained group yields one section holding every variable.
        groups = (itertools.chain.from_iterable(groups),)

    for group in groups:
        # Empty groups get no rows and no section break. (A chain object is always
        # truthy, so the merged group is never skipped.)
        if not group:
            continue

        for var in group:
            var_name = var.name
            dims = model.named_vars_to_dims.get(var_name, ())

            is_data = var in data_vars
            is_deterministic = var in deterministics
            is_potential = var in potentials

            if is_data:
                var_expr = "Data"
            elif is_deterministic:
                str_repr = str_for_potential_or_deterministic(var, dist_name="")
                _, var_expr = str_repr.split(" ~ ")
                var_expr = var_expr[1:-1]  # Remove outer parentheses (f(...))
                if truncate_deterministic is not None:
                    var_expr = _truncate_call_expr(var_expr, truncate_deterministic)
            elif is_potential:
                var_expr = str_for_potential_or_deterministic(var, dist_name="Potential").split(
                    " ~ "
                )[1]
            else:
                var_expr = str_for_dist(var).split(" ~ ")[1]

            dims_and_sizes = " × ".join(f"{dim}[{dim_sizes[dim]}]" for dim in dims)
            # "=" marks quantities that are given or computed, "~" marks random variables.
            sep = f'[b]{" =" if (is_data or is_deterministic or is_potential) else " ~"}[/b]'
            table.add_row(var_name + sep, var_expr, dims_and_sizes)

        # With split groups the count closes the free-RV section; otherwise it is
        # appended after all variables.
        if parameter_count and (not split_groups or group is free_rvs):
            n_parameters = _count_free_parameters(model)
            table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]")

        table.add_section()

    return table

tests/test_printing.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import io
2+
3+
import numpy as np
4+
import pymc as pm
5+
6+
from rich.console import Console
7+
8+
from pymc_experimental.printing import model_table
9+
10+
11+
def get_text(table) -> str:
    """Render ``table`` on a recording console and return the exported plain text.

    The console writes to an in-memory buffer and has terminal, interactive and
    Jupyter modes explicitly disabled.
    """
    console_options = {
        "record": True,
        "file": io.StringIO(),
        "force_terminal": False,
        "force_interactive": False,
        "force_jupyter": False,
    }
    recorder = Console(**console_options)
    recorder.print(table)
    exported = recorder.export_text()
    return exported
21+
22+
23+
def _normalized_lines(text: str) -> list[str]:
    """Split ``text`` into lines, collapsing every whitespace run to a single space.

    Leading and trailing whitespace is removed as well, so the comparison checks
    the cell contents and row order but not the exact column padding.
    """
    return [" ".join(line.split()) for line in text.splitlines()]


def test_model_table():
    """Check the text rendering of ``model_table`` for three option combinations.

    Covers the default call, ``split_groups=False``, and ``split_groups=False``
    combined with ``truncate_deterministic=30`` and ``parameter_count=False``.
    """
    with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model:
        x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
        y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))

        mu = pm.Normal("mu", mu=0, sigma=1)
        sigma = pm.HalfNormal("sigma", sigma=1)
        global_intercept = pm.Normal("global_intercept", mu=0, sigma=1)
        intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, dims="subject")
        beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject")

        mu_trial = pm.Deterministic(
            "mu_trial",
            global_intercept + intercept_subject + beta_subject * x_data,
            dims=["trial", "subject"],
        )
        noise = pm.Exponential("noise", lam=1)
        pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject"))

        pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject")

    # Default: one section per variable group, parameter count after the free RVs.
    table_txt = get_text(model_table(model))
    expected = """ Variable Expression Dimensions
────────────────────────────────────────────────────────────────────────────────
x_data = Data trial[6] × subject[20]
y_data = Data trial[6] × subject[20]

mu ~ Normal(0, 1)
sigma ~ HalfNormal(0, 1)
global_intercept ~ Normal(0, 1)
intercept_subject ~ Normal(0, 1) subject[20]
beta_subject ~ Normal(mu, sigma) subject[20]
noise ~ Exponential(f())
Parameter count = 44

mu_trial = f(beta_subject, trial[6] × subject[20]
intercept_subject,
global_intercept)

beta_subject_penalty = Potential(f(beta_subject)) subject[20]

y ~ Normal(mu_trial, noise) trial[6] × subject[20]
"""
    assert _normalized_lines(table_txt) == _normalized_lines(expected)

    # Single section: no blank separator rows, parameter count as the last row.
    table_txt = get_text(model_table(model, split_groups=False))
    expected = """ Variable Expression Dimensions
────────────────────────────────────────────────────────────────────────────────
x_data = Data trial[6] × subject[20]
y_data = Data trial[6] × subject[20]
mu ~ Normal(0, 1)
sigma ~ HalfNormal(0, 1)
global_intercept ~ Normal(0, 1)
intercept_subject ~ Normal(0, 1) subject[20]
beta_subject ~ Normal(mu, sigma) subject[20]
noise ~ Exponential(f())
mu_trial = f(beta_subject, trial[6] × subject[20]
intercept_subject,
global_intercept)
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
Parameter count = 44
"""
    assert _normalized_lines(table_txt) == _normalized_lines(expected)

    # Truncated deterministic expression and no parameter count row.
    table_txt = get_text(
        model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False)
    )
    expected = """ Variable Expression Dimensions
────────────────────────────────────────────────────────────────────────────
x_data = Data trial[6] × subject[20]
y_data = Data trial[6] × subject[20]
mu ~ Normal(0, 1)
sigma ~ HalfNormal(0, 1)
global_intercept ~ Normal(0, 1)
intercept_subject ~ Normal(0, 1) subject[20]
beta_subject ~ Normal(mu, sigma) subject[20]
noise ~ Exponential(f())
mu_trial = f(beta_subject, ...) trial[6] × subject[20]
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
"""
    assert _normalized_lines(table_txt) == _normalized_lines(expected)

0 commit comments

Comments
 (0)