Skip to content

Commit 8d71a87

Browse files
vbardiovskyggoldiegadde
authored andcommitted
Add saving of loaded/trained compatibility models in test and fix a compatibility bug.
PiperOrigin-RevId: 273455709
1 parent 38ea9bb commit 8d71a87

File tree

5 files changed

+23
-2
lines changed

5 files changed

+23
-2
lines changed

Diff for: tensorflow/examples/saved_model/integration_tests/use_model_in_sequential_keras.py

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import tempfile
2122
from absl import app
2223
from absl import flags
2324

@@ -57,6 +58,10 @@ def train(fine_tuning):
5758

5859
model.fit_generator(generator=dataset.batch(1), epochs=5)
5960

61+
# This is testing that a model using a SavedModel can be re-exported again,
62+
# e.g. to catch issues such as b/142231881.
63+
tf.saved_model.save(model, tempfile.mkdtemp())
64+
6065

6166
def main(argv):
6267
del argv

Diff for: tensorflow/examples/saved_model/integration_tests/use_rnn_cell.py

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import tempfile
2122
from absl import app
2223
from absl import flags
2324
import numpy as np
@@ -39,6 +40,10 @@ def main(argv):
3940
tf.constant(np.random.uniform(size=[3, 19]).astype(np.float32)),
4041
initial_state)
4142

43+
# This is testing that a model using a SavedModel can be re-exported again,
44+
# e.g. to catch issues such as b/142231881.
45+
tf.saved_model.save(cell, tempfile.mkdtemp())
46+
4247

4348
if __name__ == "__main__":
4449
app.run(main)

Diff for: tensorflow/examples/saved_model/integration_tests/use_text_embedding_in_dataset.py

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import tempfile
2122
from absl import app
2223
from absl import flags
2324

@@ -55,6 +56,10 @@ def _map_fn(features, labels):
5556

5657
model.fit_generator(generator=dataset.batch(10), epochs=5)
5758

59+
# This is testing that a model using a SavedModel can be re-exported again,
60+
# e.g. to catch issues such as b/142231881.
61+
tf.saved_model.save(model, tempfile.mkdtemp())
62+
5863

5964
def main(argv):
6065
del argv

Diff for: tensorflow/examples/saved_model/integration_tests/use_text_rnn_model.py

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import division
1919
from __future__ import print_function
2020

21+
import tempfile
2122
from absl import app
2223
from absl import flags
2324
import tensorflow.compat.v2 as tf
@@ -40,6 +41,9 @@ def main(argv):
4041
sequence_length=10, first_word=tf.constant("<S>"))
4142
_ = [d.numpy() for d in decoded]
4243

44+
# This is testing that a model using a SavedModel can be re-exported again,
45+
# e.g. to catch issues such as b/142231881.
46+
tf.saved_model.save(model, tempfile.mkdtemp())
4347

4448
if __name__ == "__main__":
4549
app.run(main)

Diff for: tensorflow/python/saved_model/load.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -425,11 +425,13 @@ def _list_functions_for_serialization(self, unused_serialization_cache):
425425
# Overwrite this method to avoid the implementation of
426426
# base class to re-wrap the polymorphic functions into
427427
# another layer of `tf.function`.
428-
return {
428+
functions = {
429429
"_create_resource": self._create_resource,
430430
"_initialize": self._initialize,
431-
"_destroy_resource": self._destroy_resource,
432431
}
432+
if self._destroy_resource:
433+
functions.update(_destroy_resource=self._destroy_resource)
434+
return functions
433435

434436

435437
def _call_attribute(instance, *args, **kwargs):

0 commit comments

Comments
 (0)