Skip to content

Commit 15ce7bf

Browse files
check graphviz results with all four formatting options
1 parent 185b609 commit 15ce7bf

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

Diff for: pymc3/tests/test_data_container.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -179,16 +179,28 @@ def test_model_to_graphviz_for_model_with_data_container(self):
179179
pm.Normal("obs", beta * x, obs_sigma, observed=y)
180180
pm.sample(1000, init=None, tune=1000, chains=1)
181181

182-
g = pm.model_to_graphviz(model)
183-
184-
# Data node rendered correctly?
185-
text = 'x [label="x\n~\nData" shape=box style="rounded, filled"]'
186-
assert text in g.source
187-
# Didn't break ordinary variables?
188-
text = 'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]'
189-
assert text in g.source
190-
text = f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]'
191-
assert text in g.source
182+
for formatting in {"latex", "latex_with_params"}:
183+
with pytest.raises(ValueError, match="Unsupported formatting"):
184+
pm.model_to_graphviz(model, formatting=formatting)
185+
186+
exp_without = [
187+
'x [label="x\n~\nData" shape=box style="rounded, filled"]',
188+
'beta [label="beta\n~\nNormal"]',
189+
'obs [label="obs\n~\nNormal" style=filled]',
190+
]
191+
exp_with = [
192+
'x [label="x\n~\nData" shape=box style="rounded, filled"]',
193+
'beta [label="beta\n~\nNormal(mu=0.0, sigma=10.0)"]',
194+
f'obs [label="obs\n~\nNormal(mu=f(f(beta), x), sigma={obs_sigma})" style=filled]',
195+
]
196+
for formatting, expected_substrings in [
197+
("plain", exp_without),
198+
("plain_with_params", exp_with),
199+
]:
200+
g = pm.model_to_graphviz(model, formatting=formatting)
201+
# check formatting of RV nodes
202+
for expected in expected_substrings:
203+
assert expected in g.source
192204

193205
def test_explicit_coords(self):
194206
N_rows = 5

0 commit comments

Comments
 (0)