@@ -179,16 +179,28 @@ def test_model_to_graphviz_for_model_with_data_container(self):
179
179
pm .Normal ("obs" , beta * x , obs_sigma , observed = y )
180
180
pm .sample (1000 , init = None , tune = 1000 , chains = 1 )
181
181
182
- g = pm .model_to_graphviz (model )
183
-
184
- # Data node rendered correctly?
185
- text = 'x [label="x\n ~\n Data" shape=box style="rounded, filled"]'
186
- assert text in g .source
187
- # Didn't break ordinary variables?
188
- text = 'beta [label="beta\n ~\n Normal(mu=0.0, sigma=10.0)"]'
189
- assert text in g .source
190
- text = f'obs [label="obs\n ~\n Normal(mu=f(f(beta), x), sigma={ obs_sigma } )" style=filled]'
191
- assert text in g .source
182
+ for formatting in {"latex" , "latex_with_params" }:
183
+ with pytest .raises (ValueError , match = "Unsupported formatting" ):
184
+ pm .model_to_graphviz (model , formatting = formatting )
185
+
186
+ exp_without = [
187
+ 'x [label="x\n ~\n Data" shape=box style="rounded, filled"]' ,
188
+ 'beta [label="beta\n ~\n Normal"]' ,
189
+ 'obs [label="obs\n ~\n Normal" style=filled]' ,
190
+ ]
191
+ exp_with = [
192
+ 'x [label="x\n ~\n Data" shape=box style="rounded, filled"]' ,
193
+ 'beta [label="beta\n ~\n Normal(mu=0.0, sigma=10.0)"]' ,
194
+ f'obs [label="obs\n ~\n Normal(mu=f(f(beta), x), sigma={ obs_sigma } )" style=filled]' ,
195
+ ]
196
+ for formatting , expected_substrings in [
197
+ ("plain" , exp_without ),
198
+ ("plain_with_params" , exp_with ),
199
+ ]:
200
+ g = pm .model_to_graphviz (model , formatting = formatting )
201
+ # check formatting of RV nodes
202
+ for expected in expected_substrings :
203
+ assert expected in g .source
192
204
193
205
def test_explicit_coords (self ):
194
206
N_rows = 5
0 commit comments