@@ -430,6 +430,9 @@ def edges(
430
430
]
431
431
432
432
433
+ GraphAttrMapping = dict [Any , Any ]
434
+
435
+
433
436
def make_graph (
434
437
name : str ,
435
438
plates : list [Plate ],
@@ -439,6 +442,7 @@ def make_graph(
439
442
figsize = None ,
440
443
dpi = 300 ,
441
444
node_formatters : NodeTypeFormatterMapping | None = None ,
445
+ graph_attrs : GraphAttrMapping | None = None ,
442
446
create_plate_label : PlateLabelFunc = create_plate_label_with_dim_length ,
443
447
):
444
448
"""Make graphviz Digraph of PyMC model.
@@ -460,6 +464,9 @@ def make_graph(
460
464
node_formatters = update_node_formatters (node_formatters )
461
465
462
466
graph = graphviz .Digraph (name )
467
+ if graph_attrs is not None :
468
+ graph .attr (** graph_attrs )
469
+
463
470
for plate in plates :
464
471
if plate .dim_info :
465
472
# must be preceded by 'cluster' to get a box around it
@@ -676,6 +683,7 @@ def model_to_graphviz(
676
683
figsize : tuple [int , int ] | None = None ,
677
684
dpi : int = 300 ,
678
685
node_formatters : NodeTypeFormatterMapping | None = None ,
686
+ graph_attrs : GraphAttrMapping | None = None ,
679
687
include_dim_lengths : bool = True ,
680
688
):
681
689
"""Produce a graphviz Digraph from a PyMC model.
@@ -704,6 +712,10 @@ def model_to_graphviz(
704
712
the size of the saved figure.
705
713
dpi : int, optional
706
714
Dots per inch. It only affects the resolution of the saved figure. The default is 300.
715
+ graph_attrs : dict, optional
716
+ A dictionary of top-level layout attributes for graphviz
717
+ Check out graphviz documentation for more information on available attributes
718
+ https://graphviz.org/doc/info/attrs.html
707
719
node_formatters : dict, optional
708
720
A dictionary mapping node types to functions that return a dictionary of node attributes.
709
721
Check out graphviz documentation for more information on available
@@ -773,6 +785,7 @@ def model_to_graphviz(
773
785
save = save ,
774
786
figsize = figsize ,
775
787
dpi = dpi ,
788
+ graph_attrs = graph_attrs ,
776
789
node_formatters = node_formatters ,
777
790
create_plate_label = create_plate_label_with_dim_length
778
791
if include_dim_lengths
0 commit comments