Skip to content

Commit 7efaff8

Browse files
authored
fix legacy model saving & reloading with axis argument in its layer (#20973)
* fix legacy model saving & relaoding with axis arg in layer * fix formatting issue * add temp_file_path
1 parent 86dce6f commit 7efaff8

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

keras/src/legacy/saving/legacy_h5_format_test.py

+15
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,21 @@ class RegisteredSubLayer(layers.Layer):
279279
self.assertIsInstance(loaded_layer.sublayers[1], RegisteredSubLayer)
280280
self.assertEqual(loaded_layer.sublayers[1].name, "MySubLayer")
281281

282+
def test_model_loading_with_axis_arg(self):
283+
input1 = layers.Input(shape=(1, 4), name="input1")
284+
input2 = layers.Input(shape=(1, 4), name="input2")
285+
concat1 = layers.Concatenate(axis=1)([input1, input2])
286+
output = layers.Dense(1, activation="sigmoid")(concat1)
287+
model = models.Model(inputs=[input1, input2], outputs=output)
288+
model.compile(
289+
optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
290+
)
291+
temp_filepath = os.path.join(
292+
self.get_temp_dir(), "model_with_axis_arg.h5"
293+
)
294+
legacy_h5_format.save_model_to_hdf5(model, temp_filepath)
295+
legacy_h5_format.load_model_from_hdf5(temp_filepath)
296+
282297

283298
@pytest.mark.requires_trainable_backend
284299
@pytest.mark.skipif(tf_keras is None, reason="Test requires tf_keras")

keras/src/legacy/saving/saving_utils.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,11 @@ def model_from_config(config, custom_objects=None):
6363
config["config"]["input_shape"] = batch_input_shape
6464

6565
axis = config["config"].pop("axis", None)
66-
if axis is not None and isinstance(axis, list) and len(axis) == 1:
67-
config["config"]["axis"] = int(axis[0])
66+
if axis is not None:
67+
if isinstance(axis, list) and len(axis) == 1:
68+
config["config"]["axis"] = int(axis[0])
69+
elif isinstance(axis, (int, float)):
70+
config["config"]["axis"] = int(axis)
6871

6972
# Handle backwards compatibility for Keras lambdas
7073
if config["class_name"] == "Lambda":

0 commit comments

Comments
 (0)