Skip to content

Commit beae526

Browse files
committed
Improve option handling
1 parent bcbaa6c commit beae526

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

spaceplot/__init__.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from matplotlib.pyplot import figure as _figure
77

88

9-
def spaceplots(inputs, outputs, input_names=None, output_names=None, **kwargs):
9+
def spaceplots(
10+
inputs, outputs, input_names=None, output_names=None, limits=None, **kwargs
11+
):
1012
num_samples, num_inputs = inputs.shape
1113
if input_names is not None:
1214
if len(input_names) != num_inputs:
@@ -23,10 +25,21 @@ def spaceplots(inputs, outputs, input_names=None, output_names=None, **kwargs):
2325
else:
2426
output_names = [None] * num_outputs
2527

28+
if limits is not None:
29+
if limits.shape[1] != 2:
30+
raise RuntimeError(
31+
"There must be a upper and lower limit for each output"
32+
)
33+
elif limits.shape[0] != num_outputs:
34+
raise RuntimeError("Output data and limits don't match")
35+
else:
36+
limits = [[None, None]] * num_outputs
37+
2638
for out_index in range(num_outputs):
2739
yield _subspace_plot(
2840
inputs, outputs[:, out_index], input_names=input_names,
29-
output_name=output_names[out_index], **kwargs
41+
output_name=output_names[out_index], min_output=limits[out_index][0],
42+
max_output=limits[out_index][1], **kwargs
3043
)
3144

3245

@@ -79,7 +92,18 @@ def _setup_axes(
7992
return fig, axes, grid
8093

8194

82-
def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
95+
def _subspace_plot(
96+
inputs, output, *, input_names, output_name, scatter_args=None,
97+
histogram_args=None, min_output=None, max_output=None
98+
):
99+
if scatter_args is None:
100+
scatter_args = {}
101+
if histogram_args is None:
102+
histogram_args = {}
103+
if min_output is None:
104+
min_output = min(output)
105+
if max_output is None:
106+
max_output = max(output)
83107

84108
# see https://matplotlib.org/examples/pylab_examples/multi_image.html
85109
_, num_inputs = inputs.shape
@@ -89,11 +113,13 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
89113
if output_name is not None:
90114
fig.suptitle(output_name)
91115

92-
norm = _Normalize(min(output), max(output)) # TODO: get from user if needed
116+
norm = _Normalize(min_output, max_output)
93117

94118
hist_plots = []
95119
for i in range(num_inputs):
96-
hist_plots.append(_plot_hist(inputs[:, i], axis=axes[i][i]))
120+
hist_plots.append(_plot_hist(
121+
inputs[:, i], axis=axes[i][i], **histogram_args
122+
))
97123

98124
scatter_plots = []
99125
scatter_plots_grid = []
@@ -103,7 +129,7 @@ def _subspace_plot(inputs, output, *, input_names, output_name, **kwargs):
103129
sc_plot = _plot_scatter(
104130
x=inputs[:, x_index], y=inputs[:, y_index], z=output,
105131
axis=axes[y_index][x_index], # check order
106-
norm=norm
132+
norm=norm, **scatter_args
107133
)
108134
scatter_plots.append(sc_plot)
109135
scatter_plots_grid[y_index].append(sc_plot)

0 commit comments

Comments
 (0)