6
6
from matplotlib .pyplot import figure as _figure
7
7
8
8
9
- def spaceplots (inputs , outputs , input_names = None , output_names = None , ** kwargs ):
9
+ def spaceplots (
10
+ inputs , outputs , input_names = None , output_names = None , limits = None , ** kwargs
11
+ ):
10
12
num_samples , num_inputs = inputs .shape
11
13
if input_names is not None :
12
14
if len (input_names ) != num_inputs :
@@ -23,10 +25,21 @@ def spaceplots(inputs, outputs, input_names=None, output_names=None, **kwargs):
23
25
else :
24
26
output_names = [None ] * num_outputs
25
27
28
+ if limits is not None :
29
+ if limits .shape [1 ] != 2 :
30
+ raise RuntimeError (
31
+ "There must be a upper and lower limit for each output"
32
+ )
33
+ elif limits .shape [0 ] != num_outputs :
34
+ raise RuntimeError ("Output data and limits don't match" )
35
+ else :
36
+ limits = [[None , None ]] * num_outputs
37
+
26
38
for out_index in range (num_outputs ):
27
39
yield _subspace_plot (
28
40
inputs , outputs [:, out_index ], input_names = input_names ,
29
- output_name = output_names [out_index ], ** kwargs
41
+ output_name = output_names [out_index ], min_output = limits [out_index ][0 ],
42
+ max_output = limits [out_index ][1 ], ** kwargs
30
43
)
31
44
32
45
@@ -79,7 +92,18 @@ def _setup_axes(
79
92
return fig , axes , grid
80
93
81
94
82
- def _subspace_plot (inputs , output , * , input_names , output_name , ** kwargs ):
95
+ def _subspace_plot (
96
+ inputs , output , * , input_names , output_name , scatter_args = None ,
97
+ histogram_args = None , min_output = None , max_output = None
98
+ ):
99
+ if scatter_args is None :
100
+ scatter_args = {}
101
+ if histogram_args is None :
102
+ histogram_args = {}
103
+ if min_output is None :
104
+ min_output = min (output )
105
+ if max_output is None :
106
+ max_output = max (output )
83
107
84
108
# see https://matplotlib.org/examples/pylab_examples/multi_image.html
85
109
_ , num_inputs = inputs .shape
@@ -89,11 +113,13 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
89
113
if output_name is not None :
90
114
fig .suptitle (output_name )
91
115
92
- norm = _Normalize (min ( output ), max ( output )) # TODO: get from user if needed
116
+ norm = _Normalize (min_output , max_output )
93
117
94
118
hist_plots = []
95
119
for i in range (num_inputs ):
96
- hist_plots .append (_plot_hist (inputs [:, i ], axis = axes [i ][i ]))
120
+ hist_plots .append (_plot_hist (
121
+ inputs [:, i ], axis = axes [i ][i ], ** histogram_args
122
+ ))
97
123
98
124
scatter_plots = []
99
125
scatter_plots_grid = []
@@ -103,7 +129,7 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
103
129
sc_plot = _plot_scatter (
104
130
x = inputs [:, x_index ], y = inputs [:, y_index ], z = output ,
105
131
axis = axes [y_index ][x_index ], # check order
106
- norm = norm
132
+ norm = norm , ** scatter_args
107
133
)
108
134
scatter_plots .append (sc_plot )
109
135
scatter_plots_grid [y_index ].append (sc_plot )
0 commit comments