@@ -30,6 +30,27 @@ To evaluate all checkpoints in a given directory:
  --hparams_set=transformer_big_single_gpu
  --source=wmt13_deen.en
  --reference=wmt13_deen.de`
+
+ In addition to the above-mentioned compulsory parameters, there are optional
+ parameters (see the example invocation below this list):
+
+ * bleu_variant: cased (case-sensitive), uncased, both (default).
+ * translations_dir: Where to store the translated files? Default="translations".
+ * event_subdir: Where in the model_dir to store the event file? Default="",
+ which means TensorBoard will show it as the same run as the training, but it will warn
+ about "more than one metagraph event per run". event_subdir can be used e.g. if running
+ this script several times with different `--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA"`.
+ * tag_suffix: Default="", so the tags will be BLEU_cased and BLEU_uncased. Again, tag_suffix
+ can be used e.g. for different beam sizes if these should be plotted in different graphs.
+ * min_steps: Don't evaluate checkpoints with fewer steps.
+ Default=-1 means check the `last_evaluated_step.txt` file, which contains the number of steps
+ of the last successfully evaluated checkpoint.
+ * report_zero: Store BLEU=0 and guess its time based on flags.txt. Default=True.
+ This is useful so that TensorBoard reports the correct relative time for the remaining checkpoints.
+ This flag is set to False if min_steps is > 0.
+ * wait_secs: Wait up to N seconds for a new checkpoint. Default=0.
+ This is useful for continuous evaluation of a running training,
+ in which case this should be equal to save_checkpoints_secs plus some reserve.
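+
+ For example (the flag values below are only illustrative), continuous evaluation of a
+ still-running training with a non-default beam size could extend the command above with:
+ `--decode_hparams="beam_size=4,alpha=0.6"
+ --event_subdir=beam4_alpha0.6
+ --tag_suffix=_beam4
+ --wait_secs=3700`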
"""
from __future__ import absolute_import
from __future__ import division
@@ -53,7 +74,11 @@ flags.DEFINE_string("translation", None, "Path to the MT system translation file
flags.DEFINE_string("source", None, "Path to the source-language file to be translated")
flags.DEFINE_string("reference", None, "Path to the reference translation file")
flags.DEFINE_string("translations_dir", "translations", "Where to store the translated files")
- flags.DEFINE_bool("report_zero", True, "Store BLEU=0 and guess its time via flags.txt")
+ flags.DEFINE_string("event_subdir", "", "Where in model_dir to store the event file")
+ flags.DEFINE_string("tag_suffix", "", "What to add to BLEU_cased and BLEU_uncased tags. Default=''.")
+ flags.DEFINE_integer("min_steps", -1, "Don't evaluate checkpoints with fewer steps.")
+ flags.DEFINE_integer("wait_secs", 0, "Wait up to N seconds for a new checkpoint, cf. save_checkpoints_secs.")
+ flags.DEFINE_bool("report_zero", None, "Store BLEU=0 and guess its time based on flags.txt")

# options derived from t2t-decode
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
@@ -70,6 +95,11 @@ flags.DEFINE_string("schedule", "train_and_evaluate",
Model = namedtuple('Model', 'filename time steps')

+ def read_checkpoints_list(model_dir, min_steps):
+   """List checkpoints in model_dir with more than min_steps steps, sorted by steps."""
+   models = [Model(x[:-6], os.path.getctime(x), int(x[:-6].rsplit('-')[-1]))
+             for x in tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))]
+   return sorted((x for x in models if x.steps > min_steps), key=lambda x: x.steps)
+
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.translation:
@@ -107,22 +137,43 @@ def main(_):
  os.makedirs(FLAGS.translations_dir, exist_ok=True)
  translated_base_file = os.path.join(FLAGS.translations_dir, FLAGS.problems)
-   models = [Model(x[:-6], os.path.getctime(x), int(x[:-6].rsplit('-')[-1]))
-             for x in tf.gfile.Glob(os.path.join(model_dir, 'model.ckpt-*.index'))]
-   models = sorted(models, key=lambda x: x.time)
+   event_dir = os.path.join(FLAGS.model_dir, FLAGS.event_subdir)
+   last_step_file = os.path.join(event_dir, 'last_evaluated_step.txt')
+   # If min_steps is not given, resume after the last successfully evaluated checkpoint.
+   if FLAGS.min_steps == -1:
+     try:
+       with open(last_step_file) as ls_file:
+         FLAGS.min_steps = int(ls_file.read())
+     except FileNotFoundError:
+       FLAGS.min_steps = 0
+   if FLAGS.report_zero is None:
+     FLAGS.report_zero = FLAGS.min_steps == 0
+
+   models = read_checkpoints_list(model_dir, FLAGS.min_steps)
  tf.logging.info("Found %d models with steps: %s" % (len(models), ", ".join(str(x.steps) for x in models)))
-   writer = tf.summary.FileWriter(FLAGS.model_dir)
+   writer = tf.summary.FileWriter(event_dir)
  if FLAGS.report_zero:
    start_time = os.path.getctime(os.path.join(model_dir, 'flags.txt'))
    values = []
    if FLAGS.bleu_variant in ('uncased', 'both'):
-       values.append(tf.Summary.Value(tag='BLEU_uncased', simple_value=0))
+       values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=0))
    if FLAGS.bleu_variant in ('cased', 'both'):
-       values.append(tf.Summary.Value(tag='BLEU_cased', simple_value=0))
+       values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=0))
    writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=start_time, step=0))
-   for model in models:
+   # Evaluate checkpoints as they appear; stop when no new one shows up within wait_secs.
+   exit_time = time.time() + FLAGS.wait_secs
+   min_steps = FLAGS.min_steps
+   while True:
+     if not models and FLAGS.wait_secs:
+       tf.logging.info('All checkpoints evaluated. Waiting until %s for a new checkpoint to appear' % time.asctime(time.localtime(exit_time)))
+       while not models and time.time() < exit_time:
+         time.sleep(10)
+         models = read_checkpoints_list(model_dir, min_steps)
+     if not models:
+       return
+
+     model = models.pop(0)
+     exit_time, min_steps = model.time + FLAGS.wait_secs, model.steps
    tf.logging.info("Evaluating " + model.filename)
    out_file = translated_base_file + '-' + str(model.steps)
    tf.logging.set_verbosity(tf.logging.ERROR)  # decode_from_file logs all the translations as INFO
@@ -131,15 +182,17 @@ def main(_):
    values = []
    if FLAGS.bleu_variant in ('uncased', 'both'):
      bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=False)
-       values.append(tf.Summary.Value(tag='BLEU_uncased', simple_value=bleu))
+       values.append(tf.Summary.Value(tag='BLEU_uncased' + FLAGS.tag_suffix, simple_value=bleu))
      tf.logging.info("%s: BLEU_uncased = %6.2f" % (model.filename, bleu))
    if FLAGS.bleu_variant in ('cased', 'both'):
      bleu = 100 * bleu_hook.bleu_wrapper(FLAGS.reference, out_file, case_sensitive=True)
-       values.append(tf.Summary.Value(tag='BLEU_cased', simple_value=bleu))
+       values.append(tf.Summary.Value(tag='BLEU_cased' + FLAGS.tag_suffix, simple_value=bleu))
      tf.logging.info("%s: BLEU_cased = %6.2f" % (model.filename, bleu))
    writer.add_event(tf.summary.Event(summary=tf.Summary(value=values), wall_time=model.time, step=model.steps))
+     writer.flush()
+     with open(last_step_file, 'w') as ls_file:
+       ls_file.write(str(model.steps) + '\n')

-   writer.flush()

if __name__ == "__main__":
  tf.app.run()