Skip to content

Commit 78fdd17

Browse files
committed
Add a way to pass in graph level attributes to graphviz
1 parent 5db3779 commit 78fdd17

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

pymc/model_graph.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,9 @@ def edges(
430430
]
431431

432432

433+
GraphAttrMapping = dict[Any, Any]
434+
435+
433436
def make_graph(
434437
name: str,
435438
plates: list[Plate],
@@ -439,6 +442,7 @@ def make_graph(
439442
figsize=None,
440443
dpi=300,
441444
node_formatters: NodeTypeFormatterMapping | None = None,
445+
graph_attrs: GraphAttrMapping | None = None,
442446
create_plate_label: PlateLabelFunc = create_plate_label_with_dim_length,
443447
):
444448
"""Make graphviz Digraph of PyMC model.
@@ -460,6 +464,9 @@ def make_graph(
460464
node_formatters = update_node_formatters(node_formatters)
461465

462466
graph = graphviz.Digraph(name)
467+
if graph_attrs is not None:
468+
graph.attr(**graph_attrs)
469+
463470
for plate in plates:
464471
if plate.dim_info:
465472
# must be preceded by 'cluster' to get a box around it
@@ -676,6 +683,7 @@ def model_to_graphviz(
676683
figsize: tuple[int, int] | None = None,
677684
dpi: int = 300,
678685
node_formatters: NodeTypeFormatterMapping | None = None,
686+
graph_attrs: GraphAttrMapping | None = None,
679687
include_dim_lengths: bool = True,
680688
):
681689
"""Produce a graphviz Digraph from a PyMC model.
@@ -704,6 +712,10 @@ def model_to_graphviz(
704712
the size of the saved figure.
705713
dpi : int, optional
706714
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
715+
graph_attrs : dict, optional
716+
A dictionary of top-level layout attributes for graphviz
717+
Check out graphviz documentation for more information on available attributes
718+
https://graphviz.org/doc/info/attrs.html
707719
node_formatters : dict, optional
708720
A dictionary mapping node types to functions that return a dictionary of node attributes.
709721
Check out graphviz documentation for more information on available
@@ -773,6 +785,7 @@ def model_to_graphviz(
773785
save=save,
774786
figsize=figsize,
775787
dpi=dpi,
788+
graph_attrs=graph_attrs,
776789
node_formatters=node_formatters,
777790
create_plate_label=create_plate_label_with_dim_length
778791
if include_dim_lengths

0 commit comments

Comments
 (0)