@@ -62,13 +62,18 @@ from tensor2tensor.utils import decoding
62
62
from tensor2tensor .utils import trainer_utils
63
63
from tensor2tensor .utils import usr_dir
64
64
from tensor2tensor .utils import bleu_hook
65
+ from tensor2tensor .utils import registry
66
+ from tensor2tensor import _set_time_logging
67
+
65
68
import tensorflow as tf
66
69
67
70
flags = tf.flags
FLAGS = flags.FLAGS

# t2t-bleu specific options
flags.DEFINE_string("bleu_variant", "both",
                    "Possible values: cased(case-sensitive), uncased, both(default).")
# The help text must agree with the actual default (True).
flags.DEFINE_bool("postprocess", True,
                  "Postprocess translation and reference before calculating BLEU. "
                  "True(default), False.")
# The suffix is appended to a file name to form the name of its postprocessed copy.
flags.DEFINE_string("postprocess_suffix", ".post",
                    "Suffix appended to a file name to form the name of its "
                    "postprocessed copy.")
flags.DEFINE_string("model_dir", "", "Directory to load model checkpoints from.")
flags.DEFINE_string("translation", None, "Path to the MT system translation file")
flags.DEFINE_string("source", None, "Path to the source-language file to be translated")
@@ -92,28 +97,60 @@ flags.DEFINE_string("master", "", "Address of TensorFlow master.")
92
97
flags .DEFINE_string ("schedule" , "train_and_evaluate" ,
93
98
"Must be train_and_evaluate for decoding." )
94
99
95
# Record for one checkpoint: its filename, a time value and its step count.
Model = namedtuple('Model', ['filename', 'time', 'steps'])
97
101
98
102
def read_checkpoints_list(model_dir, min_steps):
  """Return the checkpoints in `model_dir` that have more than `min_steps` steps.

  Every file matching `model.ckpt-*.index` in `model_dir` yields one `Model`:
  its filename is the matched path without the last 6 characters, its time is
  `os.path.getctime` of the matched file, and its steps is the integer that
  follows the last "-" of the filename.

  Args:
    model_dir: directory searched for checkpoint index files.
    min_steps: only checkpoints with strictly more steps are returned.

  Returns:
    A list of `Model` tuples sorted by increasing `steps`.
  """
  index_pattern = os.path.join(model_dir, 'model.ckpt-*.index')
  checkpoints = []
  for index_path in tf.gfile.Glob(index_pattern):
    prefix = index_path[:-6]  # drop the 6-character ".index" ending
    created = os.path.getctime(index_path)
    steps = int(prefix.rsplit('-')[-1])
    checkpoints.append(Model(prefix, created, steps))
  # Filter only after every record has been built, then order by step count.
  recent = [ckpt for ckpt in checkpoints if ckpt.steps > min_steps]
  recent.sort(key=lambda ckpt: ckpt.steps)
  return recent
102
106
107
def postprocess(pre, post, problem):
  """Write the postprocessed content of `pre` to `post` unless `post` exists.

  Args:
    pre: path of the UTF-8 text file to read.
    post: path of the UTF-8 text file to create.
    problem: object whose `postprocess(text)` method returns the text to write.
  """
  if tf.gfile.Exists(post):
    return
  # Log once; the former ten-fold repetition was a debugging leftover.
  tf.logging.info("postprocessing file %s" % post)
  with open(pre, "r", encoding="utf-8") as pre_file:
    text = problem.postprocess(pre_file.read())
  # Open the output only after postprocessing succeeded, so a failure does not
  # leave an empty file behind that the existence check above would later
  # treat as a finished result.
  with open(post, "w", encoding="utf-8") as post_file:
    post_file.write(text)
113
+
114
def postprocess_maybe_add_suffix(filename, problem):
  """Return the path of the postprocessed counterpart of `filename`.

  A name that already ends with `FLAGS.postprocess_suffix` is returned
  unchanged. Otherwise the suffix is appended, the suffixed file is created
  with `postprocess` if it does not exist yet (next to the original file), and
  the suffixed name is returned.

  Args:
    filename: path of a reference or translation file.
    problem: passed through to `postprocess`.

  Returns:
    The path to use for BLEU computation.
  """
  suffix = FLAGS.postprocess_suffix
  if filename.endswith(suffix):
    # Already a postprocessed file; nothing to do.
    return filename
  post = filename + suffix
  if not tf.gfile.Exists(post):
    postprocess(filename, post, problem)
  return post
123
+
103
124
def main (_ ):
125
+ _set_time_logging ()
126
+
104
127
tf .logging .set_verbosity (tf .logging .INFO )
105
- if FLAGS .translation :
128
+
129
+ if FLAGS .translation : ## TODO: this variant is not tested
106
130
if FLAGS .model_dir :
107
131
raise ValueError ('Cannot specify both --translation and --model_dir.' )
108
- if FLAGS .bleu_variant in ('uncased' , 'both' ):
109
- bleu = 100 * bleu_hook .bleu_wrapper (FLAGS .reference , FLAGS .translation , case_sensitive = False )
110
- print ("BLEU_uncased = %6.2f" % bleu )
111
- if FLAGS .bleu_variant in ('cased' , 'both' ):
112
- bleu = 100 * bleu_hook .bleu_wrapper (FLAGS .reference , FLAGS .translation , case_sensitive = True )
113
- print ("BLEU_cased = %6.2f" % bleu )
132
+
133
def count_bleu(ref, trans, ptag=""):
  """Print the BLEU of `trans` against `ref` for the selected variants.

  Which scores are printed is controlled by `FLAGS.bleu_variant`
  ('uncased', 'cased' or 'both').

  Args:
    ref: path to the reference file.
    trans: path to the translation file.
    ptag: text inserted right after the metric name in the printed label.
  """
  # Score the files that were passed in rather than the raw flag values, so
  # that callers handing over postprocessed files get scores for those files.
  if FLAGS.bleu_variant in ('uncased', 'both'):
    bleu = 100 * bleu_hook.bleu_wrapper(ref, trans, case_sensitive=False)
    print("BLEU_uncased%s = %6.2f" % (ptag, bleu))
  if FLAGS.bleu_variant in ('cased', 'both'):
    bleu = 100 * bleu_hook.bleu_wrapper(ref, trans, case_sensitive=True)
    print("BLEU_cased%s = %6.2f" % (ptag, bleu))
140
+
141
+ if FLAGS .postprocess :
142
+ usr_dir .import_usr_dir (FLAGS .t2t_usr_dir )
143
+ problem = registry .problem (FLAGS .problems )
144
+ ref_post = postprocess_maybe_add_suffix (FLAGS .reference )
145
+ ref_trans = postprocess_maybe_add_suffix (FLAGS .translation )
146
+ count_bleu (ref_post , ref_trans , ptag = "_post" )
147
+ else :
148
+ count_bleu (FLAGS .reference , FLAGS .translation , ptag = "" )
114
149
return
115
150
116
151
usr_dir .import_usr_dir (FLAGS .t2t_usr_dir )
152
+ problem = registry .problem (FLAGS .problems )
153
+
117
154
FLAGS .model = FLAGS .model or 'transformer'
118
155
FLAGS .output_dir = FLAGS .model_dir
119
156
trainer_utils .log_registry ()
@@ -177,19 +214,36 @@ def main(_):
177
214
model = models .pop (0 )
178
215
exit_time , min_steps = model .time + FLAGS .wait_secs , model .steps
179
216
tf .logging .info ("Evaluating " + model .filename )
217
+
180
218
out_file = translated_base_file + '-' + str (model .steps )
219
+
181
220
tf .logging .set_verbosity (tf .logging .ERROR ) # decode_from_file logs all the translations as INFO
182
221
decoding .decode_from_file (estimator , FLAGS .source , decode_hp , out_file , checkpoint_path = model .filename )
183
222
tf .logging .set_verbosity (tf .logging .INFO )
223
+
224
+ post_out_file = out_file + FLAGS .postprocess_suffix
225
+ if problem .needs_postprocessing and FLAGS .postprocess :
226
+ post_out_file = postprocess_maybe_add_suffix (out_file , problem )
227
+ else :
228
+ post_out_file = out_file
229
+
230
+ post_reference = postprocess_maybe_add_suffix (FLAGS .reference , problem )
231
+
184
232
values = []
185
- if FLAGS .bleu_variant in ('uncased' , 'both' ):
186
- bleu = 100 * bleu_hook .bleu_wrapper (FLAGS .reference , out_file , case_sensitive = False )
187
- values .append (tf .Summary .Value (tag = 'BLEU_uncased' + FLAGS .tag_suffix , simple_value = bleu ))
188
- tf .logging .info ("%s: BLEU_uncased = %6.2f" % (model .filename , bleu ))
189
- if FLAGS .bleu_variant in ('cased' , 'both' ):
190
- bleu = 100 * bleu_hook .bleu_wrapper (FLAGS .reference , out_file , case_sensitive = True )
191
- values .append (tf .Summary .Value (tag = 'BLEU_cased' + FLAGS .tag_suffix , simple_value = bleu ))
192
- tf .logging .info ("%s: BLEU_cased = %6.2f" % (model .filename , bleu ))
233
def count_bleu(ref, out, ptag=""):
  """Compute the BLEU of `out` against `ref` and record it.

  For each variant selected by `FLAGS.bleu_variant` the score is appended to
  the enclosing `values` list as a `tf.Summary.Value` and logged together
  with the filename of the enclosing `model`.

  Args:
    ref: path to the reference file.
    out: path to the translation file.
    ptag: text inserted between the metric name and `FLAGS.tag_suffix`.
  """
  # Score the files that were passed in rather than the raw flag / raw output,
  # so that callers handing over postprocessed files get scores for them.
  if FLAGS.bleu_variant in ('uncased', 'both'):
    bleu = 100 * bleu_hook.bleu_wrapper(ref, out, case_sensitive=False)
    values.append(tf.Summary.Value(tag='BLEU_uncased' + ptag + FLAGS.tag_suffix, simple_value=bleu))
    tf.logging.info("%s: BLEU_uncased%s%s = %6.2f" % (model.filename, ptag, FLAGS.tag_suffix, bleu))
  if FLAGS.bleu_variant in ('cased', 'both'):
    bleu = 100 * bleu_hook.bleu_wrapper(ref, out, case_sensitive=True)
    values.append(tf.Summary.Value(tag='BLEU_cased' + ptag + FLAGS.tag_suffix, simple_value=bleu))
    # The log label matches the summary tag (it previously said "uncased").
    tf.logging.info("%s: BLEU_cased%s%s = %6.2f" % (model.filename, ptag, FLAGS.tag_suffix, bleu))
242
+ if FLAGS .postprocess :
243
+ count_bleu (post_reference , post_out_file , ptag = "_post" )
244
+ # else: ## TODO: else or not ????
245
+ count_bleu (FLAGS .reference , out_file , ptag = "" )
246
+
193
247
writer .add_event (tf .summary .Event (summary = tf .Summary (value = values ), wall_time = model .time , step = model .steps ))
194
248
writer .flush ()
195
249
with open (last_step_file , 'w' ) as ls_file :
0 commit comments