-
-
Notifications
You must be signed in to change notification settings - Fork 60
Summarize model as rich table #382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import numpy as np | ||
|
||
from pymc import Model | ||
from pymc.printing import str_for_dist, str_for_potential_or_deterministic | ||
from pytensor import Mode | ||
from pytensor.compile.sharedvalue import SharedVariable | ||
from pytensor.graph.type import Constant, Variable | ||
from rich.box import SIMPLE_HEAD | ||
from rich.table import Table | ||
|
||
|
||
def variable_expression(
    model: Model,
    var: Variable,
    truncate_deterministic: int | None,
) -> str:
    """Get the expression of a variable in a human-readable format.

    Parameters
    ----------
    model : Model
        Model that ``var`` belongs to. Its ``data_vars``, ``deterministics`` and
        ``potentials`` collections decide how ``var`` is rendered.
    var : Variable
        The model variable to describe.
    truncate_deterministic : int | None
        If not None, a deterministic expression longer than this many characters is
        reduced to its leading arguments followed by ``...``. Ignored for every other
        kind of variable.

    Returns
    -------
    str
        ``"Data"`` for data variables, an ``f(...)`` expression for deterministics, and
        the text following ``" ~ "`` in the PyMC string representation for potentials
        and random variables.
    """
    if var in model.data_vars:
        var_expr = "Data"
    elif var in model.deterministics:
        str_repr = str_for_potential_or_deterministic(var, dist_name="")
        _, var_expr = str_repr.split(" ~ ")
        var_expr = var_expr[1:-1]  # Remove outer parentheses (f(...))
        if truncate_deterministic is not None and len(var_expr) > truncate_deterministic:
            # Drop the leading "f(" and the trailing ")" to get the argument list.
            contents = var_expr[2:-1].split(", ")
            str_len = 0
            for show_n, content in enumerate(contents):
                str_len += len(content) + 2  # +2 accounts for the ", " separator
                if str_len > truncate_deterministic:
                    break
            shown = contents[:show_n]
            # When not even the first argument fits there is nothing to list before the
            # ellipsis; emit "f(...)" instead of the malformed "f(, ...)".
            var_expr = f"f({', '.join(shown)}, ...)" if shown else "f(...)"
    elif var in model.potentials:
        var_expr = str_for_potential_or_deterministic(var, dist_name="Potential").split(" ~ ")[1]
    else:  # basic_RVs
        var_expr = str_for_dist(var).split(" ~ ")[1]
    return var_expr
|
||
|
||
def _extract_dim_value(var: SharedVariable | Constant) -> np.ndarray:
    """Return the concrete value stored in a dimension-length variable.

    A ``SharedVariable`` is read through ``get_value(borrow=True)``; any other input is
    expected to expose its value through a ``data`` attribute.
    """
    is_shared = isinstance(var, SharedVariable)
    return var.get_value(borrow=True) if is_shared else var.data
|
||
|
||
def dims_expression(model: Model, var: Variable) -> str:
    """Get the dimensions of a variable in a human-readable format.

    Parameters
    ----------
    model : Model
        Model that ``var`` belongs to; provides ``named_vars_to_dims`` and ``dim_lengths``.
    var : Variable
        The model variable whose dimensions are described.

    Returns
    -------
    str
        ``"name[size] × name[size]"`` when the variable has named dims, otherwise the
        evaluated shape as ``"[n, m]"``, or an empty string when that shape is empty.
    """
    if (dims := model.named_vars_to_dims.get(var.name)) is not None:
        # Keep one (name, size) pair per axis. A dict keyed by name would collapse a dim
        # that labels more than one axis into a single entry.
        # NOTE(review): an unnamed (None) entry in ``dims`` would fail the
        # ``dim_lengths`` lookup - confirm whether callers can pass one.
        dim_sizes = [(dim, _extract_dim_value(model.dim_lengths[dim])) for dim in dims]
        return " × ".join(f"{dim}[{dim_size}]" for dim, dim_size in dim_sizes)
    else:
        # No named dims: evaluate the symbolic shape to get concrete sizes.
        dim_sizes = list(var.shape.eval(mode=Mode(linker="py", optimizer="fast_compile")))
        return f"[{', '.join(map(str, dim_sizes))}]" if dim_sizes else ""
|
||
|
||
def model_parameter_count(model: Model) -> int:
    """Count the number of parameters in the model.

    The count is the sum, over ``model.free_RVs``, of the product of each variable's
    evaluated shape.

    Returns
    -------
    int
        Total number of scalar entries across all free random variables; ``0`` when the
        model has none.
    """
    rv_shapes = model.eval_rv_shapes()  # Includes transformed variables
    # Use the built-in ``sum``/``int`` so the result is always a Python int.
    # ``np.sum([])`` would return the float 0.0 for a model with no free RVs.
    return sum(int(np.prod(rv_shapes[free_rv.name])) for free_rv in model.free_RVs)
|
||
|
||
def model_table(
    model: Model,
    *,
    split_groups: bool = True,
    truncate_deterministic: int | None = None,
    parameter_count: bool = True,
) -> Table:
    """Create a rich table with a summary of the model's variables and their expressions.

    Parameters
    ----------
    model : Model
        The PyMC model to summarize.
    split_groups : bool
        If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs)
        will be separated by a section.
    truncate_deterministic : int | None
        If not None, truncate the expression of deterministic variables that go beyond this length.
    parameter_count : bool
        If True, add a row with the total number of parameters in the model.

    Returns
    -------
    Table
        A rich table with the model's variables, their expressions and dims.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        import pymc as pm

        from pymc_experimental.printing import model_table

        coords = {"subject": range(20), "param": ["a", "b"]}
        with pm.Model(coords=coords) as m:
            x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
            y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")

            beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
            mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
            sigma = pm.HalfNormal("sigma", sigma=1)

            y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")

        table = model_table(m)
        table  # Displays the following table in an interactive environment
        '''
         Variable  Expression         Dimensions
        ─────────────────────────────────────────────────────
              x =  Data               subject[20] × param[2]
              y =  Data               subject[20]

           beta ~  Normal(0, 1)       param[2]
          sigma ~  HalfNormal(0, 1)
                                      Parameter count = 3

             mu =  f(beta)            subject[20]

          y_obs ~  Normal(mu, sigma)  subject[20]
        '''

    Output can be explicitly rendered in a rich console or exported to text, html or svg.

    .. code-block:: python

        from rich.console import Console

        console = Console(record=True)
        console.print(table)
        text_export = console.export_text()
        html_export = console.export_html()
        svg_export = console.export_svg()

    """
    table = Table(
        show_header=True,
        show_edge=False,
        box=SIMPLE_HEAD,
        highlight=False,
        collapse_padding=True,
    )
    table.add_column("Variable", justify="right")
    table.add_column("Expression", justify="left")
    table.add_column("Dimensions")

    if split_groups:
        groups = (
            model.data_vars,
            model.free_RVs,
            model.deterministics,
            model.potentials,
            model.observed_RVs,
        )
    else:
        # Show variables in the order they were defined
        groups = (model.named_vars.values(),)

    # Read once: the model is not modified while the table is built, so there is no
    # need to fetch this collection again for every row.
    basic_rvs = model.basic_RVs

    for group in groups:
        if not group:
            continue

        for var in group:
            var_name = var.name
            # Random variables are shown with "~", everything else with "=".
            sep = f'[b]{" ~" if (var in basic_rvs) else " ="}[/b]'
            var_expr = variable_expression(model, var, truncate_deterministic)
            dims_expr = dims_expression(model, var)
            # Defensive normalization: never display an empty list for a scalar.
            if dims_expr == "[]":
                dims_expr = ""
            table.add_row(var_name + sep, var_expr, dims_expr)

        # The count row closes the free_RVs section, or the single section when the
        # groups are not split.
        if parameter_count and (not split_groups or group == model.free_RVs):
            n_parameters = model_parameter_count(model)
            table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]")

        table.add_section()

    return table
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
import numpy as np | ||
import pymc as pm | ||
|
||
from rich.console import Console | ||
|
||
from pymc_experimental.printing import model_table | ||
|
||
|
||
def get_text(table) -> str:
    """Render ``table`` on an 80-column rich console and return the captured text."""
    console = Console(width=80)
    with console.capture() as captured:
        console.print(table)
    # The captured text is read after the capture context has exited.
    rendered = captured.get()
    return rendered
|
||
|
||
def test_model_table(): | ||
# Builds one model containing data, free RVs, a deterministic, a potential and an
# observed RV, then checks the rendered table text for three option combinations.
with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model: | ||
x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject")) | ||
y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject")) | ||
|
||
mu = pm.Normal("mu", mu=0, sigma=1) | ||
sigma = pm.HalfNormal("sigma", sigma=1) | ||
global_intercept = pm.Normal("global_intercept", mu=0, sigma=1) | ||
intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, shape=(20, 1)) | ||
beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject") | ||
|
||
mu_trial = pm.Deterministic( | ||
"mu_trial", | ||
global_intercept.squeeze() + intercept_subject + beta_subject * x_data, | ||
dims=["trial", "subject"], | ||
) | ||
noise = pm.Exponential("noise", lam=1) | ||
y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject")) | ||
|
||
pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject") | ||
|
||
# Case 1: default options - variables grouped by kind, with the parameter count row
# closing the free-RV group.
table_txt = get_text(model_table(model)) | ||
expected = """ Variable Expression Dimensions | ||
──────────────────────────────────────────────────────────────────────────────── | ||
x_data = Data trial[6] × subject[20] | ||
y_data = Data trial[6] × subject[20] | ||
|
||
mu ~ Normal(0, 1) | ||
sigma ~ HalfNormal(0, 1) | ||
global_intercept ~ Normal(0, 1) | ||
intercept_subject ~ Normal(0, 1) [20, 1] | ||
beta_subject ~ Normal(mu, sigma) subject[20] | ||
noise ~ Exponential(f()) | ||
Parameter count = 44 | ||
|
||
mu_trial = f(intercept_subject, trial[6] × subject[20] | ||
beta_subject, | ||
global_intercept) | ||
|
||
beta_subject_penalty = Potential(f(beta_subject)) subject[20] | ||
|
||
y ~ Normal(mu_trial, noise) trial[6] × subject[20] | ||
""" | ||
# Lines are compared one by one with leading/trailing whitespace stripped.
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()] | ||
|
||
# Case 2: split_groups=False - a single section, with the parameter count row last.
table_txt = get_text(model_table(model, split_groups=False)) | ||
expected = """ Variable Expression Dimensions | ||
──────────────────────────────────────────────────────────────────────────────── | ||
x_data = Data trial[6] × subject[20] | ||
y_data = Data trial[6] × subject[20] | ||
mu ~ Normal(0, 1) | ||
sigma ~ HalfNormal(0, 1) | ||
global_intercept ~ Normal(0, 1) | ||
intercept_subject ~ Normal(0, 1) [20, 1] | ||
beta_subject ~ Normal(mu, sigma) subject[20] | ||
mu_trial = f(intercept_subject, trial[6] × subject[20] | ||
beta_subject, | ||
global_intercept) | ||
noise ~ Exponential(f()) | ||
y ~ Normal(mu_trial, noise) trial[6] × subject[20] | ||
beta_subject_penalty = Potential(f(beta_subject)) subject[20] | ||
Parameter count = 44 | ||
""" | ||
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()] | ||
|
||
# Case 3: deterministic expressions truncated at 30 characters and no parameter
# count row.
table_txt = get_text( | ||
model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False) | ||
) | ||
expected = """ Variable Expression Dimensions | ||
──────────────────────────────────────────────────────────────────────────── | ||
x_data = Data trial[6] × subject[20] | ||
y_data = Data trial[6] × subject[20] | ||
mu ~ Normal(0, 1) | ||
sigma ~ HalfNormal(0, 1) | ||
global_intercept ~ Normal(0, 1) | ||
intercept_subject ~ Normal(0, 1) [20, 1] | ||
beta_subject ~ Normal(mu, sigma) subject[20] | ||
mu_trial = f(intercept_subject, ...) trial[6] × subject[20] | ||
noise ~ Exponential(f()) | ||
y ~ Normal(mu_trial, noise) trial[6] × subject[20] | ||
beta_subject_penalty = Potential(f(beta_subject)) subject[20] | ||
""" | ||
assert [s.strip() for s in table_txt.splitlines()] == [s.strip() for s in expected.splitlines()] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.