Skip to content

Commit 9d3aaa0

Browse files
Jiayu Yetensorflower-gardener
Jiayu Ye
authored andcommitted
Internal change
PiperOrigin-RevId: 520673564
1 parent 209a259 commit 9d3aaa0

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed

orbit/actions/export_saved_model.py

+20-7
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,13 @@ class ExportFileManager:
5959
customized naming and cleanup strategies.
6060
"""
6161

62-
def __init__(self,
63-
base_name: str,
64-
max_to_keep: int = 5,
65-
next_id_fn: Optional[Callable[[], int]] = None):
62+
def __init__(
63+
self,
64+
base_name: str,
65+
max_to_keep: int = 5,
66+
next_id_fn: Optional[Callable[[], int]] = None,
67+
subdirectory: Optional[str] = None,
68+
):
6669
"""Initializes the instance.
6770
6871
Args:
@@ -77,10 +80,14 @@ def __init__(self,
7780
If not supplied, a default ID based on an incrementing counter is used.
7881
One common alternative maybe be to use the current global step count,
7982
for instance passing `next_id_fn=global_step.numpy`.
83+
subdirectory: An optional subdirectory to concat after the
84+
{base_name}-{id}. Then the file manager will manage
85+
{base_name}-{id}/{subdirectory} files.
8086
"""
8187
self._base_name = os.path.normpath(base_name)
8288
self._max_to_keep = max_to_keep
8389
self._next_id_fn = next_id_fn or _CounterIdFn(self._base_name)
90+
self._subdirectory = subdirectory or ''
8491

8592
@property
8693
def managed_files(self):
@@ -91,7 +98,10 @@ def managed_files(self):
9198
`ExportFileManager` instance, sorted in increasing integer order of the
9299
IDs returned by `next_id_fn`.
93100
"""
94-
return _find_managed_files(self._base_name)
101+
files = _find_managed_files(self._base_name)
102+
return [
103+
os.path.normpath(os.path.join(f, self._subdirectory)) for f in files
104+
]
95105

96106
def clean_up(self):
97107
"""Cleans up old files matching `{base_name}-*`.
@@ -101,12 +111,15 @@ def clean_up(self):
101111
if self._max_to_keep < 0:
102112
return
103113

104-
for filename in self.managed_files[:-self._max_to_keep]:
114+
# Note that the base folder will remain intact, only the folder with suffix
115+
# is deleted.
116+
for filename in self.managed_files[: -self._max_to_keep]:
105117
tf.io.gfile.rmtree(filename)
106118

107119
def next_name(self) -> str:
108120
"""Returns a new file name based on `base_name` and `next_id_fn()`."""
109-
return f'{self._base_name}-{self._next_id_fn()}'
121+
base_path = f'{self._base_name}-{self._next_id_fn()}'
122+
return os.path.normpath(os.path.join(base_path, self._subdirectory))
110123

111124

112125
class ExportSavedModel:

orbit/actions/export_saved_model_test.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_export_file_manager_default_ids(self):
4646
directory = self.create_tempdir()
4747
base_name = os.path.join(directory.full_path, 'basename')
4848
manager = actions.ExportFileManager(base_name, max_to_keep=3)
49-
self.assertLen(tf.io.gfile.listdir(directory.full_path), 0)
49+
self.assertEmpty(tf.io.gfile.listdir(directory.full_path))
5050
directory.create_file(manager.next_name())
5151
manager.clean_up() # Shouldn't do anything...
5252
self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
@@ -79,7 +79,7 @@ def next_id():
7979

8080
manager = actions.ExportFileManager(
8181
base_name, max_to_keep=2, next_id_fn=next_id)
82-
self.assertLen(tf.io.gfile.listdir(directory.full_path), 0)
82+
self.assertEmpty(tf.io.gfile.listdir(directory.full_path))
8383
id_num = 30
8484
directory.create_file(manager.next_name())
8585
self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
@@ -105,6 +105,54 @@ def next_id():
105105
_id_sorted_file_base_names(directory.full_path),
106106
['basename-200', 'basename-1000'])
107107

108+
def test_export_file_manager_with_suffix(self):
109+
directory = self.create_tempdir()
110+
base_name = os.path.join(directory.full_path, 'basename')
111+
112+
id_num = 0
113+
114+
def next_id():
115+
return id_num
116+
117+
subdirectory = 'sub'
118+
119+
manager = actions.ExportFileManager(
120+
base_name, max_to_keep=2, next_id_fn=next_id, subdirectory=subdirectory
121+
)
122+
self.assertEmpty(tf.io.gfile.listdir(directory.full_path))
123+
id_num = 30
124+
directory.create_file(manager.next_name())
125+
self.assertLen(tf.io.gfile.listdir(directory.full_path), 1)
126+
manager.clean_up() # Shouldn't do anything...
127+
self.assertEqual(
128+
_id_sorted_file_base_names(directory.full_path), ['basename-30']
129+
)
130+
id_num = 200
131+
directory.create_file(manager.next_name())
132+
self.assertLen(tf.io.gfile.listdir(directory.full_path), 2)
133+
manager.clean_up() # Shouldn't do anything...
134+
self.assertEqual(
135+
_id_sorted_file_base_names(directory.full_path),
136+
['basename-30', 'basename-200'],
137+
)
138+
id_num = 1000
139+
directory.create_file(manager.next_name())
140+
self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
141+
self.assertEqual(
142+
_id_sorted_file_base_names(directory.full_path),
143+
['basename-30', 'basename-200', 'basename-1000'],
144+
)
145+
manager.clean_up() # Should delete file with lowest ID.
146+
self.assertLen(tf.io.gfile.listdir(directory.full_path), 3)
147+
# Note that the base folder is intact, only the suffix folder is deleted.
148+
self.assertEqual(
149+
_id_sorted_file_base_names(directory.full_path),
150+
['basename-30', 'basename-200', 'basename-1000'],
151+
)
152+
153+
step_folder = os.path.join(directory.full_path, 'basename-1000')
154+
self.assertIn(subdirectory, tf.io.gfile.listdir(step_folder))
155+
108156
def test_export_file_manager_managed_files(self):
109157
directory = self.create_tempdir()
110158
directory.create_file('basename-5')

0 commit comments

Comments
 (0)