Skip to content

Commit eb1f844

Browse files
authored
Fix Discretization serialization when num_bins is used. (#20971)
Previously, serialization / deserialization would fail if: - the layer was saved / restored before `adapt` was called - the layer was saved / restored after `adapt` was called, but the dataset was such that the number of bins learned was fewer than `num_bins` The fix consists in adding a `from_config` to handle `bin_boundaries` separately. This is because at initial creation, `bin_boundaries` and `num_bins` cannot be both set, but when restoring the layer after `adapt`, they are both set. Tightened the error checking: - never allow `num_bins` and `bin_boundaries` to be specified at the same time, even if they match (same as `tf_keras`) - don't allow `num_bins` and `bin_boundaries` to be `None` at the same time - verify that `adapt` has been called in `call` Also removed `init_bin_boundaries` as the value was never used and its presence can be inferred.
1 parent 19b1418 commit eb1f844

File tree

2 files changed

+81
-15
lines changed

2 files changed

+81
-15
lines changed

keras/src/layers/preprocessing/discretization.py

+39-15
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class Discretization(TFDataLayer):
3434
and `[2., +inf)`.
3535
If this option is set, `adapt()` should not be called.
3636
num_bins: The integer number of bins to compute.
37-
If this option is set,
37+
If this option is set, `bin_boundaries` should not be set and
3838
`adapt()` should be called to learn the bin boundaries.
3939
epsilon: Error tolerance, typically a small fraction
4040
close to zero (e.g. 0.01). Higher values of epsilon increase
@@ -130,17 +130,17 @@ def __init__(
130130
f"Received: `num_bins={num_bins}`"
131131
)
132132
if num_bins is not None and bin_boundaries is not None:
133-
if len(bin_boundaries) != num_bins - 1:
134-
raise ValueError(
135-
"Both `num_bins` and `bin_boundaries` should not be "
136-
f"set. Received: `num_bins={num_bins}` and "
137-
f"`bin_boundaries={bin_boundaries}`"
138-
)
139-
140-
self.input_bin_boundaries = bin_boundaries
141-
self.bin_boundaries = (
142-
bin_boundaries if bin_boundaries is not None else []
143-
)
133+
raise ValueError(
134+
"Both `num_bins` and `bin_boundaries` should not be set. "
135+
f"Received: `num_bins={num_bins}` and "
136+
f"`bin_boundaries={bin_boundaries}`"
137+
)
138+
if num_bins is None and bin_boundaries is None:
139+
raise ValueError(
140+
"You need to set either `num_bins` or `bin_boundaries`."
141+
)
142+
143+
self.bin_boundaries = bin_boundaries
144144
self.num_bins = num_bins
145145
self.epsilon = epsilon
146146
self.output_mode = output_mode
@@ -183,7 +183,7 @@ def adapt(self, data, steps=None):
183183
repeating dataset, you must specify the `steps` argument. This
184184
argument is not supported with array inputs or list inputs.
185185
"""
186-
if self.input_bin_boundaries is not None:
186+
if self.num_bins is None:
187187
raise ValueError(
188188
"Cannot adapt a Discretization layer that has been initialized "
189189
"with `bin_boundaries`, use `num_bins` instead."
@@ -204,14 +204,14 @@ def update_state(self, data):
204204
self.summary = merge_summaries(summary, self.summary, self.epsilon)
205205

206206
def finalize_state(self):
207-
if self.input_bin_boundaries is not None:
207+
if self.num_bins is None:
208208
return
209209
self.bin_boundaries = get_bin_boundaries(
210210
self.summary, self.num_bins
211211
).tolist()
212212

213213
def reset_state(self):
214-
if self.input_bin_boundaries is not None:
214+
if self.num_bins is None:
215215
return
216216
self.summary = np.array([[], []], dtype="float32")
217217

@@ -225,6 +225,13 @@ def load_own_variables(self, store):
225225
return
226226

227227
def call(self, inputs):
228+
if self.bin_boundaries is None:
229+
raise ValueError(
230+
"You need to either pass the `bin_boundaries` argument at "
231+
"construction time or call `adapt(dataset)` before you can "
232+
"start using the `Discretization` layer."
233+
)
234+
228235
indices = self.backend.numpy.digitize(inputs, self.bin_boundaries)
229236
return numerical_utils.encode_categorical_inputs(
230237
indices,
@@ -246,6 +253,23 @@ def get_config(self):
246253
"dtype": self.dtype,
247254
}
248255

256+
@classmethod
257+
def from_config(cls, config, custom_objects=None):
258+
if (
259+
config.get("bin_boundaries", None) is not None
260+
and config.get("num_bins", None) is not None
261+
):
262+
# After `adapt` was called, both `bin_boundaries` and `num_bins` are
263+
# populated, but `__init__` won't let us create a new layer with
264+
# both `bin_boundaries` and `num_bins`. We therefore apply
265+
# `bin_boundaries` after creation.
266+
config = config.copy()
267+
bin_boundaries = config.pop("bin_boundaries")
268+
discretization = cls(**config)
269+
discretization.bin_boundaries = bin_boundaries
270+
return discretization
271+
return cls(**config)
272+
249273

250274
def summarize(values, epsilon):
251275
"""Reduce a 1D sequence of values to a summary.

keras/src/layers/preprocessing/discretization_test.py

+42
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,32 @@ def test_tf_data_compatibility(self):
131131
for output in ds.take(1):
132132
output.numpy()
133133

134+
def test_serialization(self):
135+
layer = layers.Discretization(num_bins=5)
136+
137+
# Serialization before `adapt` is called.
138+
config = layer.get_config()
139+
revived_layer = layers.Discretization.from_config(config)
140+
self.assertEqual(config, revived_layer.get_config())
141+
142+
# Serialization after `adapt` is called but `num_bins` was not reached.
143+
layer.adapt(np.array([0.0, 1.0, 5.0]))
144+
config = layer.get_config()
145+
revived_layer = layers.Discretization.from_config(config)
146+
self.assertEqual(config, revived_layer.get_config())
147+
148+
# Serialization after `adapt` is called and `num_bins` is reached.
149+
layer.adapt(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]))
150+
config = layer.get_config()
151+
revived_layer = layers.Discretization.from_config(config)
152+
self.assertEqual(config, revived_layer.get_config())
153+
154+
# Serialization with `bin_boundaries`.
155+
layer = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0])
156+
config = layer.get_config()
157+
revived_layer = layers.Discretization.from_config(config)
158+
self.assertEqual(config, revived_layer.get_config())
159+
134160
def test_saving(self):
135161
# With fixed bins
136162
layer = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0])
@@ -163,3 +189,19 @@ def test_saving(self):
163189
model.save(fpath)
164190
model = saving_api.load_model(fpath)
165191
self.assertAllClose(layer(ref_input), ref_output)
192+
193+
def test_init_num_bins_and_bin_boundaries_raises(self):
194+
with self.assertRaisesRegex(
195+
ValueError, "Both `num_bins` and `bin_boundaries`"
196+
):
197+
layers.Discretization(num_bins=3, bin_boundaries=[0.0, 1.0])
198+
199+
with self.assertRaisesRegex(
200+
ValueError, "either `num_bins` or `bin_boundaries`"
201+
):
202+
layers.Discretization()
203+
204+
def test_call_before_adapt_raises(self):
205+
layer = layers.Discretization(num_bins=3)
206+
with self.assertRaisesRegex(ValueError, "You need .* call .*adapt"):
207+
layer([[0.1, 0.8, 0.9]])

0 commit comments

Comments
 (0)