23
23
import pprint
24
24
import tempfile
25
25
26
+
26
27
import tensorflow as tf
27
28
import tensorflow_transform as tft
28
29
from apache_beam .io import textio
49
50
REVIEW_WEIGHT = 'review_weight'
LABEL_COLUMN = 'label'

# Metadata describing the raw input data: a free-text review string and an
# int64 sentiment label, each a scalar (rank-0) fixed-length column.
_REVIEW_COLUMN_SCHEMA = dataset_schema.ColumnSchema(
    tf.string, [], dataset_schema.FixedColumnRepresentation())
_LABEL_COLUMN_SCHEMA = dataset_schema.ColumnSchema(
    tf.int64, [], dataset_schema.FixedColumnRepresentation())
RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
    dataset_schema.Schema({
        REVIEW_COLUMN: _REVIEW_COLUMN_SCHEMA,
        LABEL_COLUMN: _LABEL_COLUMN_SCHEMA,
    }))

# Characters treated as token boundaries when splitting review text.
DELIMITERS = '.,!?() '
53
61
54
62
@@ -99,13 +107,13 @@ def ReadAndShuffleData(pcoll, filepatterns):
99
107
lambda p : {REVIEW_COLUMN : p [0 ], LABEL_COLUMN : p [1 ]})
100
108
101
109
102
- def transform_data (train_neg_filepattern , train_pos_filepattern ,
103
- test_neg_filepattern , test_pos_filepattern ,
104
- transformed_train_filebase , transformed_test_filebase ,
105
- transformed_metadata_dir ):
106
- """Transform the data and write out as a TFRecord of Example protos.
110
def read_and_shuffle_data(
    train_neg_filepattern, train_pos_filepattern, test_neg_filepattern,
    test_pos_filepattern, shuffled_train_filebase, shuffled_test_filebase):
  """Read and shuffle the data and write out as a TFRecord of Example protos.

  Read in the data from the positive and negative examples on disk, shuffle it
  and write it out in TFRecord format, encoding each (review, label) pair as an
  Example proto using the RAW_DATA_METADATA schema.

  Args:
    train_neg_filepattern: Filepattern for training data negative examples
    train_pos_filepattern: Filepattern for training data positive examples
    test_neg_filepattern: Filepattern for test data negative examples
    test_pos_filepattern: Filepattern for test data positive examples
    shuffled_train_filebase: Base filename for shuffled training data shards
    shuffled_test_filebase: Base filename for shuffled test data shards
  """
  with beam.Pipeline() as pipeline:
    # pylint: disable=no-value-for-parameter
    # The train and test branches are identical apart from their inputs and
    # outputs, so build both from a single description instead of duplicating
    # the pipeline construction.
    for name, filepatterns, filebase in [
        ('Train', (train_neg_filepattern, train_pos_filepattern),
         shuffled_train_filebase),
        ('Test', (test_neg_filepattern, test_pos_filepattern),
         shuffled_test_filebase)]:
      _ = (
          pipeline
          | 'ReadAndShuffle%s' % name >> ReadAndShuffleData(filepatterns)
          | 'Write%sData' % name >> tfrecordio.WriteToTFRecord(
              filebase,
              coder=example_proto_coder.ExampleProtoCoder(
                  RAW_DATA_METADATA.schema)))
    # pylint: enable=no-value-for-parameter
148
+
149
+ def transform_data (shuffled_train_filepattern , shuffled_test_filepattern ,
150
+ transformed_train_filebase , transformed_test_filebase ,
151
+ transformed_metadata_dir ):
152
+ """Transform the data and write out as a TFRecord of Example protos.
153
+
154
+ Read in the data from the positive and negative examples on disk, and
155
+ transform it using a preprocessing pipeline that removes punctuation,
156
+ tokenizes and maps tokens to int64 values indices.
157
+
158
+ Args:
159
+ shuffled_train_filepattern: Base filename for shuffled training data shards
160
+ shuffled_test_filepattern: Base filename for shuffled test data shards
117
161
transformed_train_filebase: Base filename for transformed training data
118
162
shards
119
163
transformed_test_filebase: Base filename for transformed test data shards
@@ -123,19 +167,19 @@ def transform_data(train_neg_filepattern, train_pos_filepattern,
123
167
124
168
with beam .Pipeline () as pipeline :
125
169
with beam_impl .Context (temp_dir = tempfile .mkdtemp ()):
126
- # pylint: disable=no-value-for-parameter
127
- train_data = pipeline | 'ReadTrain' >> ReadAndShuffleData (
128
- ( train_neg_filepattern , train_pos_filepattern ))
129
- # pylint: disable=no-value-for-parameter
130
- test_data = pipeline | 'ReadTest' >> ReadAndShuffleData (
131
- ( test_neg_filepattern , test_pos_filepattern ))
132
-
133
- metadata = dataset_metadata . DatasetMetadata ( dataset_schema . Schema ({
134
- REVIEW_COLUMN : dataset_schema . ColumnSchema (
135
- tf . string , [], dataset_schema . FixedColumnRepresentation ()),
136
- LABEL_COLUMN : dataset_schema . ColumnSchema (
137
- tf . int64 , [], dataset_schema . FixedColumnRepresentation ()),
138
- } ))
170
+ train_data = (
171
+ pipeline |
172
+ 'ReadTrain' >> tfrecordio . ReadFromTFRecord (
173
+ shuffled_train_filepattern ,
174
+ coder = example_proto_coder . ExampleProtoCoder (
175
+ RAW_DATA_METADATA . schema ) ))
176
+
177
+ test_data = (
178
+ pipeline |
179
+ 'ReadTest' >> tfrecordio . ReadFromTFRecord (
180
+ shuffled_test_filepattern ,
181
+ coder = example_proto_coder . ExampleProtoCoder (
182
+ RAW_DATA_METADATA . schema ) ))
139
183
140
184
def preprocessing_fn (inputs ):
141
185
"""Preprocess input columns into transformed columns."""
@@ -153,12 +197,12 @@ def preprocessing_fn(inputs):
153
197
}
154
198
155
199
(transformed_train_data , transformed_metadata ), transform_fn = (
156
- (train_data , metadata )
200
+ (train_data , RAW_DATA_METADATA )
157
201
| 'AnalyzeAndTransform' >> beam_impl .AnalyzeAndTransformDataset (
158
202
preprocessing_fn ))
159
203
160
204
transformed_test_data , _ = (
161
- ((test_data , metadata ), transform_fn )
205
+ ((test_data , RAW_DATA_METADATA ), transform_fn )
162
206
| 'Transform' >> beam_impl .TransformDataset ())
163
207
164
208
_ = (
@@ -183,7 +227,9 @@ def preprocessing_fn(inputs):
183
227
184
228
185
229
def train_and_evaluate (transformed_train_filepattern ,
186
- transformed_test_filepattern , transformed_metadata_dir ):
230
+ transformed_test_filepattern , transformed_metadata_dir ,
231
+ num_train_instances = NUM_TRAIN_INSTANCES ,
232
+ num_test_instances = NUM_TEST_INSTANCES ):
187
233
"""Train the model on training data and evaluate on evaluation data.
188
234
189
235
Args:
@@ -192,6 +238,8 @@ def train_and_evaluate(transformed_train_filepattern,
192
238
transformed_test_filepattern: Base filename for transformed evaluation data
193
239
shards
194
240
transformed_metadata_dir: Directory containing transformed data metadata
241
+ num_train_instances: Number of instances in train set
242
+ num_test_instances: Number of instances in test set
195
243
196
244
Returns:
197
245
The results from the estimator's 'evaluate' method
@@ -219,7 +267,7 @@ def train_and_evaluate(transformed_train_filepattern,
219
267
# Estimate the model using the default optimizer.
220
268
estimator .fit (
221
269
input_fn = train_input_fn ,
222
- max_steps = TRAIN_NUM_EPOCHS * NUM_TRAIN_INSTANCES / TRAIN_BATCH_SIZE )
270
+ max_steps = TRAIN_NUM_EPOCHS * num_train_instances / TRAIN_BATCH_SIZE )
223
271
224
272
# Evaluate model on eval dataset.
225
273
eval_input_fn = input_fn_maker .build_training_input_fn (
@@ -228,7 +276,7 @@ def train_and_evaluate(transformed_train_filepattern,
228
276
training_batch_size = 1 ,
229
277
label_keys = [LABEL_COLUMN ])
230
278
231
- return estimator .evaluate (input_fn = eval_input_fn , steps = NUM_TEST_INSTANCES )
279
+ return estimator .evaluate (input_fn = eval_input_fn , steps = num_test_instances )
232
280
233
281
234
282
def main ():
@@ -248,14 +296,19 @@ def main():
248
296
train_pos_filepattern = os .path .join (args .input_data_dir , 'train/pos/*' )
249
297
test_neg_filepattern = os .path .join (args .input_data_dir , 'test/neg/*' )
250
298
test_pos_filepattern = os .path .join (args .input_data_dir , 'test/pos/*' )
299
+ shuffled_train_filebase = os .path .join (transformed_data_dir , 'train_shuffled' )
300
+ shuffled_test_filebase = os .path .join (transformed_data_dir , 'test_shuffled' )
251
301
transformed_train_filebase = os .path .join (transformed_data_dir ,
252
302
'train_transformed' )
253
303
transformed_test_filebase = os .path .join (transformed_data_dir ,
254
304
'test_transformed' )
255
305
transformed_metadata_dir = os .path .join (transformed_data_dir , 'metadata' )
256
306
257
- transform_data (train_neg_filepattern , train_pos_filepattern ,
258
- test_neg_filepattern , test_pos_filepattern ,
307
+ read_and_shuffle_data (train_neg_filepattern , train_pos_filepattern ,
308
+ test_neg_filepattern , test_pos_filepattern ,
309
+ shuffled_train_filebase , shuffled_test_filebase )
310
+
311
+ transform_data (shuffled_train_filebase + '*' , shuffled_test_filebase + '*' ,
259
312
transformed_train_filebase , transformed_test_filebase ,
260
313
transformed_metadata_dir )
261
314
0 commit comments