Skip to content

Commit 9dc7665

Browse files
committed
Summarize model as rich table
1 parent 2accca9 commit 9dc7665

File tree

3 files changed

+282
-0
lines changed

3 files changed

+282
-0
lines changed

docs/api_reference.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,12 @@ Model Transforms
7474

7575
autoreparam.vip_reparametrize
7676
autoreparam.VIP
77+
78+
79+
Printing
80+
========
81+
.. currentmodule:: pymc_experimental.printing
82+
.. autosummary::
83+
:toctree: generated/
84+
85+
model_table

pymc_experimental/printing.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
import numpy as np
2+
3+
from pymc import Model
4+
from pymc.printing import str_for_dist, str_for_potential_or_deterministic
5+
from pytensor.compile.sharedvalue import SharedVariable
6+
from pytensor.graph.type import Constant, Variable
7+
from rich.box import SIMPLE_HEAD
8+
from rich.table import Table
9+
10+
11+
def variable_expression(
12+
model: Model,
13+
var: Variable,
14+
truncate_deterministic: int | None,
15+
) -> str:
16+
"""Get the expression of a variable in a human-readable format."""
17+
if var in model.data_vars:
18+
var_expr = "Data"
19+
elif var in model.deterministics:
20+
str_repr = str_for_potential_or_deterministic(var, dist_name="")
21+
_, var_expr = str_repr.split(" ~ ")
22+
var_expr = var_expr[1:-1] # Remove outer parentheses (f(...))
23+
if truncate_deterministic is not None and len(var_expr) > truncate_deterministic:
24+
contents = var_expr[2:-1].split(", ")
25+
str_len = 0
26+
for show_n, content in enumerate(contents):
27+
str_len += len(content) + 2
28+
if str_len > truncate_deterministic:
29+
break
30+
var_expr = f"f({', '.join(contents[:show_n])}, ...)"
31+
elif var in model.potentials:
32+
var_expr = str_for_potential_or_deterministic(var, dist_name="Potential").split(" ~ ")[1]
33+
else: # basic_RVs
34+
var_expr = str_for_dist(var).split(" ~ ")[1]
35+
return var_expr
36+
37+
38+
def _extract_dim_value(var: SharedVariable | Constant) -> np.ndarray:
39+
if isinstance(var, SharedVariable):
40+
return var.get_value(borrow=True)
41+
else:
42+
return var.data
43+
44+
45+
def dims_expression(model: Model, var: Variable) -> str:
46+
dim_sizes = {
47+
dim: _extract_dim_value(model.dim_lengths[dim])
48+
for dim in model.named_vars_to_dims.get(var.name, ())
49+
}
50+
dims_and_sizes = " × ".join(f"{dim}[{dim_size}]" for dim, dim_size in dim_sizes.items())
51+
return dims_and_sizes
52+
53+
54+
def model_parameter_count(model: Model) -> int:
55+
"""Count the number of parameters in the model."""
56+
rv_shapes = model.eval_rv_shapes() # Includes transformed variables
57+
return np.sum([np.prod(rv_shapes[free_rv.name]).astype(int) for free_rv in model.free_RVs])
58+
59+
60+
def model_table(
61+
model: Model,
62+
split_groups: bool = True,
63+
truncate_deterministic: int | None = None,
64+
parameter_count: bool = True,
65+
) -> Table:
66+
"""Create a rich table with a summary of the model's variables and their expressions.
67+
68+
Parameters
69+
----------
70+
model : Model
71+
The PyMC model to summarize.
72+
split_groups : bool
73+
If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs)
74+
will be separated by a section.
75+
truncate_deterministic : int | None
76+
If not None, truncate the expression of deterministic variables that go beyond this length.
77+
parameter_count : bool
78+
If True, add a row with the total number of parameters in the model.
79+
80+
Returns
81+
-------
82+
Table
83+
A rich table with the model's variables, their expressions and dims.
84+
85+
Examples
86+
--------
87+
.. code-block:: python
88+
89+
import numpy as np
90+
import pymc as pm
91+
92+
from pymc_experimental.printing import model_table
93+
94+
coords = {"subject": range(20), "param": ["a", "b"]}
95+
with pm.Model(coords=coords) as m:
96+
x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
97+
y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")
98+
99+
beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
100+
mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
101+
sigma = pm.HalfNormal("sigma", sigma=1)
102+
103+
y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")
104+
105+
table = model_table(m)
106+
table # Displays the following table in an interactive environment
107+
'''
108+
Variable Expression Dimensions
109+
─────────────────────────────────────────────────────
110+
x = Data subject[20] × param[2]
111+
y = Data subject[20]
112+
113+
beta ~ Normal(0, 1) param[2]
114+
sigma ~ HalfNormal(0, 1)
115+
Parameter count = 3
116+
117+
mu = f(beta) subject[20]
118+
119+
y_obs ~ Normal(mu, sigma) subject[20]
120+
'''
121+
122+
Output can be explicitly rendered in a rich console or exported to text, html or svg.
123+
124+
.. code-block:: python
125+
126+
from rich.console import Console
127+
128+
console = Console(record=True)
129+
console.print(table)
130+
text_export = console.export_text()
131+
html_export = console.export_html()
132+
svg_export = console.export_svg()
133+
134+
"""
135+
table = Table(
136+
show_header=True,
137+
show_edge=False,
138+
box=SIMPLE_HEAD,
139+
highlight=False,
140+
collapse_padding=True,
141+
)
142+
table.add_column("Variable", justify="right")
143+
table.add_column("Expression", justify="left")
144+
table.add_column("Dimensions")
145+
146+
if split_groups:
147+
groups = (
148+
model.data_vars,
149+
model.free_RVs,
150+
model.deterministics,
151+
model.potentials,
152+
model.observed_RVs,
153+
)
154+
else:
155+
# Show variables in the order they were defined
156+
groups = (model.named_vars.values(),)
157+
158+
for group in groups:
159+
if not group:
160+
continue
161+
162+
for var in group:
163+
var_name = var.name
164+
sep = f'[b]{" ~" if (var in model.basic_RVs) else " ="}[/b]'
165+
var_expr = variable_expression(model, var, truncate_deterministic)
166+
dims_expr = dims_expression(model, var)
167+
table.add_row(var_name + sep, var_expr, dims_expr)
168+
169+
if parameter_count and (not split_groups or group == model.free_RVs):
170+
n_parameters = model_parameter_count(model)
171+
table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]")
172+
173+
table.add_section()
174+
175+
return table

tests/test_printing.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
import pymc as pm
3+
4+
from rich.console import Console
5+
6+
from pymc_experimental.printing import model_table
7+
8+
9+
def get_text(table) -> str:
10+
console = Console(width=80)
11+
with console.capture() as capture:
12+
console.print(table)
13+
return capture.get()
14+
15+
16+
def test_model_table():
17+
with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model:
18+
x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
19+
y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
20+
21+
mu = pm.Normal("mu", mu=0, sigma=1)
22+
sigma = pm.HalfNormal("sigma", sigma=1)
23+
global_intercept = pm.Normal("global_intercept", mu=0, sigma=1)
24+
intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, dims="subject")
25+
beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject")
26+
27+
mu_trial = pm.Deterministic(
28+
"mu_trial",
29+
global_intercept + intercept_subject + beta_subject * x_data,
30+
dims=["trial", "subject"],
31+
)
32+
noise = pm.Exponential("noise", lam=1)
33+
y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject"))
34+
35+
pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject")
36+
37+
table_txt = get_text(model_table(model))
38+
expected = """ Variable Expression Dimensions
39+
────────────────────────────────────────────────────────────────────────────────
40+
x_data = Data trial[6] × subject[20]
41+
y_data = Data trial[6] × subject[20]
42+
43+
mu ~ Normal(0, 1)
44+
sigma ~ HalfNormal(0, 1)
45+
global_intercept ~ Normal(0, 1)
46+
intercept_subject ~ Normal(0, 1) subject[20]
47+
beta_subject ~ Normal(mu, sigma) subject[20]
48+
noise ~ Exponential(f())
49+
Parameter count = 44
50+
51+
mu_trial = f(beta_subject, trial[6] × subject[20]
52+
intercept_subject,
53+
global_intercept)
54+
55+
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
56+
57+
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
58+
"""
59+
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
60+
61+
table_txt = get_text(model_table(model, split_groups=False))
62+
expected = """ Variable Expression Dimensions
63+
────────────────────────────────────────────────────────────────────────────────
64+
x_data = Data trial[6] × subject[20]
65+
y_data = Data trial[6] × subject[20]
66+
mu ~ Normal(0, 1)
67+
sigma ~ HalfNormal(0, 1)
68+
global_intercept ~ Normal(0, 1)
69+
intercept_subject ~ Normal(0, 1) subject[20]
70+
beta_subject ~ Normal(mu, sigma) subject[20]
71+
mu_trial = f(beta_subject, trial[6] × subject[20]
72+
intercept_subject,
73+
global_intercept)
74+
noise ~ Exponential(f())
75+
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
76+
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
77+
Parameter count = 44
78+
"""
79+
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]
80+
81+
table_txt = get_text(
82+
model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False)
83+
)
84+
expected = """ Variable Expression Dimensions
85+
────────────────────────────────────────────────────────────────────────────
86+
x_data = Data trial[6] × subject[20]
87+
y_data = Data trial[6] × subject[20]
88+
mu ~ Normal(0, 1)
89+
sigma ~ HalfNormal(0, 1)
90+
global_intercept ~ Normal(0, 1)
91+
intercept_subject ~ Normal(0, 1) subject[20]
92+
beta_subject ~ Normal(mu, sigma) subject[20]
93+
mu_trial = f(beta_subject, ...) trial[6] × subject[20]
94+
noise ~ Exponential(f())
95+
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
96+
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
97+
"""
98+
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()]

0 commit comments

Comments
 (0)