Skip to content

Commit e2495a5

Browse files
committed
!!fix!!
1 parent b7cc7b0 commit e2495a5

12 files changed

+171
-34
lines changed

README.md

+5-2
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@ For more details, please refer to [README of SSD-Tensorflow](https://github.com/
1111
## ##
1212
update:
1313

14-
- Add SSD preprocesing method using Tensorflow ops
14+
- Add SSD preprocessing method using TensorFlow ops [zero ground truth fixed]
1515
- Modify details of the network to match the original Caffe code
1616
- Add NMS using TensorFlow ops, supporting two modes
1717
- Fix most parts of the matching strategy between ground truth and anchors
1818
- Replica GPU training support (If you are using Tensorflow 1.5.0+, then rename the replicate_model\_fn.py)
1919
- Add voc eval with debug
2020
- Add realtime eval, using class-wise bboxes-select and nms
2121
- Add support for training with the *vgg16_reducedfc* model converted from PyTorch; you can get it from [here](https://drive.google.com/open?id=184srhbt8_uvLKeWW_Yo8Mc5wTyc0lJT7)
22-
- Other important fixes **[2018.03.18]**
22+
- Other important fixes **[2018.03.21]**
23+
- Match the anchors from all layers together, to avoid some suboptimal matching results
24+
- Refactor anchors matching pipeline
25+
- Fix the missing 'difficult' attribute in the TFRecords dataset
2326
- Model-320 (reduced version) trained on the VOC07+12 dataset is now available [here](); the heavy one needs to be trained by yourself
2427

2528
Note: Models trained using the initial version of this code can only reach 0.45~0.55 mAP; cloning the latest version will give you much better performance. Further improvement is still ongoing.

datasets/pascalvoc_to_tfrecords.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464

6565
# TFRecords conversion parameters.
6666
RANDOM_SEED = 4242
67-
SAMPLES_PER_FILES = 200
67+
SAMPLES_PER_FILES = 1500
6868

6969

7070
def _process_image(directory, name):
@@ -103,12 +103,16 @@ def _process_image(directory, name):
103103
labels.append(int(VOC_LABELS[label][0]))
104104
labels_text.append(label.encode('ascii'))
105105

106-
if obj.find('difficult'):
107-
difficult.append(int(obj.find('difficult').text))
106+
isdifficult = obj.find('difficult')
107+
if isdifficult is not None:
108+
#print('ddd')
109+
difficult.append(int(isdifficult.text))
108110
else:
109111
difficult.append(0)
110-
if obj.find('truncated'):
111-
truncated.append(int(obj.find('truncated').text))
112+
113+
istruncated = obj.find('truncated')
114+
if istruncated is not None:
115+
truncated.append(int(istruncated.text))
112116
else:
113117
truncated.append(0)
114118

draw_toolbox.py

+2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def bboxes_draw_on_img(img, classes, scores, bboxes, thickness=2):
7676
line_type = 8
7777

7878
for i in range(bboxes.shape[0]):
79+
if classes[i] < 1: continue
80+
7981
bbox = bboxes[i]
8082
color = colors_tableau[classes[i]]
8183
# Draw bounding box...

eval_ron_network.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
tf.app.flags.DEFINE_string(
101101
'master', '', 'The address of the TensorFlow master to use.')
102102
tf.app.flags.DEFINE_string(
103-
'checkpoint_path', './model/model.ckpt-122044',#118815
103+
'checkpoint_path', './model/model.ckpt-121551',#118815
104104
'The directory where the model was written to or an absolute path to a '
105105
'checkpoint file.')
106106
tf.app.flags.DEFINE_string(

nets/ron_vgg_320.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class RONNet(object):
100100
no_annotation_label=21,
101101
feat_layers=['block7','block6', 'block5', 'block4'],
102102
feat_shapes=[(5, 5), (10, 10), (20, 20), (40, 40)],
103-
allowed_borders = [0, 0, 0, 0],
103+
allowed_borders = [32, 16, 8, 4],
104104
anchor_sizes=[(224., 256.),
105105
(160., 192.),
106106
(96., 128.),
@@ -120,7 +120,7 @@ class RONNet(object):
120120
# anchor_steps=[64],
121121

122122
anchor_offset=0.5,
123-
prior_scaling=[1., 1., 1., 1.]#[0.1, 0.1, 0.2, 0.2]
123+
prior_scaling=[0.1, 0.1, 0.2, 0.2]#[1., 1., 1., 1.]#
124124
)
125125

126126
def __init__(self, params=None):
@@ -477,7 +477,7 @@ def ron_net(inputs,
477477
# Block 6
478478
net = slim.conv2d(net, 4096, [7, 7], scope='fc6')
479479
end_points['block6'] = net
480-
net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout6')
480+
#net = slim.dropout(net, dropout_keep_prob, is_training=is_training, scope='dropout6')
481481
# Block 7: 1x1 conv, no padding.
482482
net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
483483
end_points['block7'] = net
@@ -769,7 +769,7 @@ def ron_losses(logits, localisations, objness_logits, objness_pred,
769769
loss = custom_layers.modified_smooth_l1(localisations, tf.stop_gradient(glocalisations), sigma = 3.)
770770
#loss = custom_layers.abs_smooth(localisations - tf.stop_gradient(glocalisations))
771771

772-
loss = tf.cond(n_cls_positives > 0., lambda: beta * n_cls_positives / total_examples_for_cls * tf.reduce_mean(tf.boolean_mask(tf.reduce_sum(loss, axis=-1), tf.stop_gradient(cls_positive_mask))), lambda: 0.)
772+
loss = tf.cond(n_cls_positives > 0., lambda: beta * tf.reduce_mean(tf.boolean_mask(tf.reduce_sum(loss, axis=-1), tf.stop_gradient(cls_positive_mask))), lambda: 0.)
773773
#loss = tf.cond(n_positives > 0., lambda: beta * n_positives / total_examples_for_objness * tf.reduce_mean(tf.boolean_mask(tf.reduce_sum(loss, axis=-1), tf.stop_gradient(positive_mask))), lambda: 0.)
774774
#loss = tf.reduce_mean(loss * weights)
775775
#loss = tf.reduce_sum(loss * weights)

nets/ssd_common.py

+82-12
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def tf_ssd_bboxes_encode_layer(labels,
131131
feat_cx = (feat_xmax + feat_xmin) / 2.
132132
feat_h = feat_ymax - feat_ymin
133133
feat_w = feat_xmax - feat_xmin
134+
135+
bboxes = tf.stack([ymin_, xmin_, ymax_, xmax_], axis=-1)
134136
# Encode features.
135137
# the prior_scaling (in fact 5 and 10) is used to balance the regression loss of the center and the width (or height)
136138
# (x-x_ref)/x_ref * 10 + log(w/w_ref) * 5
@@ -142,7 +144,7 @@ def tf_ssd_bboxes_encode_layer(labels,
142144
feat_localizations = tf.stack([feat_cx, feat_cy, feat_w, feat_h], axis=-1)
143145
# now feat_localizations is our regression object
144146

145-
return feat_labels * tf.cast(matched_gt_mask, tf.int64) + (-1 * tf.cast(matched_gt < -1, tf.int64)), tf.expand_dims(tf.reshape(tf.cast(matched_gt_mask, tf.float32), tf.shape(ymin_)), -1) * feat_localizations, feat_scores
147+
return feat_labels * tf.cast(matched_gt_mask, tf.int64) + (-1 * tf.cast(matched_gt < -1, tf.int64)), tf.expand_dims(tf.reshape(tf.cast(matched_gt_mask, tf.float32), tf.shape(ymin_)), -1) * feat_localizations, feat_scores, bboxes
146148

147149
# def tf_ssd_bboxes_encode_layer(labels,
148150
# bboxes,
@@ -362,17 +364,85 @@ def tf_ssd_bboxes_encode(labels,
362364
target_labels = []
363365
target_localizations = []
364366
target_scores = []
365-
for i, anchors_layer in enumerate(anchors):
366-
with tf.name_scope('bboxes_encode_block_%i' % i):
367-
t_labels, t_loc, t_scores = \
368-
tf_ssd_bboxes_encode_layer(labels, bboxes, anchors_layer,
369-
num_classes, img_shape, allowed_borders[i], no_annotation_label,
370-
positive_threshold, ignore_threshold,
371-
prior_scaling, dtype)
372-
target_labels.append(t_labels)
373-
target_localizations.append(t_loc)
374-
target_scores.append(t_scores)
375-
return target_labels, target_localizations, target_scores
367+
target_bboxes = []
368+
369+
shape_recorder = []
370+
full_shape_anchors = {}
371+
with tf.name_scope('anchor_concat'):
372+
for i, anchors_layer in enumerate(anchors):
373+
yref, xref, href, wref = anchors_layer
374+
375+
ymin_ = yref - href / 2.
376+
xmin_ = xref - wref / 2.
377+
ymax_ = yref + href / 2.
378+
xmax_ = xref + wref / 2.
379+
380+
shape_recorder.append(ymin_.shape)
381+
full_shape_yxhw = [(ymin_ + ymax_)/2, (xmin_ + xmax_)/2, (ymax_ - ymin_), (xmax_ - xmin_)]
382+
383+
full_shape_anchors[i] = [np.reshape(_, (-1)) for _ in full_shape_yxhw]
384+
#print(full_shape_anchors)
385+
remap_anchors = list(zip(*full_shape_anchors.values()))
386+
387+
for i in range(len(full_shape_anchors)):
388+
full_shape_anchors[i] = np.concatenate(remap_anchors[i], axis=0)
389+
#print(full_shape_anchors[i].shape)
390+
# print([_.shape for _ in remap_anchors[0]])
391+
# print([_.shape for _ in remap_anchors[1]])
392+
# print([_.shape for _ in remap_anchors[2]])
393+
# print([_.shape for _ in remap_anchors[3]])
394+
#print(shape_recorder)
395+
len_recorder = [np.prod(_) for _ in shape_recorder]
396+
#print(len_recorder)
397+
#print(allowed_borders)
398+
flaten_allowed_borders = []
399+
for i, allowed_border in enumerate(allowed_borders):
400+
flaten_allowed_borders.append([allowed_border]*len_recorder[i])
401+
#print([len(_) for _ in flaten_allowed_borders])
402+
flaten_allowed_borders = np.concatenate(flaten_allowed_borders, axis=0)
403+
404+
t_labels, t_loc, t_scores, t_bbox = tf_ssd_bboxes_encode_layer(labels, bboxes, list(full_shape_anchors.values()), num_classes, img_shape, flaten_allowed_borders, no_annotation_label, positive_threshold, ignore_threshold, prior_scaling, dtype)
405+
406+
reshaped_loc = []
407+
for i, loc in enumerate(tf.split(t_loc, len_recorder)):
408+
reshaped_loc.append(tf.reshape(loc, list(shape_recorder[i])+[-1]))
409+
reshaped_bbox = []
410+
for i, bbox in enumerate(tf.split(t_bbox, len_recorder)):
411+
reshaped_bbox.append(tf.reshape(bbox, list(shape_recorder[i])+[-1]))
412+
#print(reshaped_loc)
413+
#print(reshaped_bbox)
414+
return tf.split(t_labels, len_recorder), reshaped_loc, tf.split(t_scores, len_recorder), reshaped_bbox
415+
# with tf.name_scope(scope):
416+
# target_labels = []
417+
# target_localizations = []
418+
# target_scores = []
419+
# target_bboxes = []
420+
# for i, anchors_layer in enumerate(anchors):
421+
# with tf.name_scope('bboxes_encode_block_%i' % i):
422+
423+
# yref, xref, href, wref = anchors_layer
424+
425+
# ymin_ = yref - href / 2.
426+
# xmin_ = xref - wref / 2.
427+
# ymax_ = yref + href / 2.
428+
# xmax_ = xref + wref / 2.
429+
430+
# yref_, xref_, href_, wref_ = (ymin_ + ymax_)/2, (xmin_ + xmax_)/2, (ymax_ - ymin_), (xmax_ - xmin_)
431+
# t_labels, t_loc, t_scores, t_bbox = \
432+
# tf_ssd_bboxes_encode_layer(labels, bboxes, anchors_layer,
433+
# num_classes, img_shape, allowed_borders[i], no_annotation_label,
434+
# positive_threshold, ignore_threshold,
435+
# prior_scaling, dtype)
436+
# print('anchors_layer:', [yref_.shape, xref_.shape, href_.shape, wref_.shape])
437+
# print('t_labels:', t_labels)
438+
# print('t_loc:', t_loc)
439+
# print('t_scores:', t_scores)
440+
# print('t_bbox:', t_bbox)
441+
# target_labels.append(t_labels)
442+
# target_localizations.append(t_loc)
443+
# target_scores.append(t_scores)
444+
# target_bboxes.append(t_bbox)
445+
# return target_labels, target_localizations, target_scores, target_bboxes
376446

377447

378448
def tf_ssd_bboxes_decode_layer(feat_localizations,

preprocessing/ssd_vgg_preprocessing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def ron_preprocess_for_train(image, labels, bboxes,
345345
lambda x, ordering: distort_color(x, ordering, fast_mode),
346346
num_cases=4)
347347
tf_summary_image(dst_image, bboxes, 'image_color_distorted_4')
348-
348+
dst_image = random_sample_flip_resized_image
349349
# Rescale to VGG input scale.
350350
dst_image.set_shape([None, None, 3])
351351
image = dst_image * 255.
@@ -358,7 +358,7 @@ def ron_preprocess_for_train(image, labels, bboxes,
358358
def preprocess_for_eval(image, labels, bboxes,
359359
out_shape=EVAL_SIZE, data_format='NHWC',
360360
difficults=None, resize='WARP_RESIZE',
361-
scope='ssd_preprocessing_train'):
361+
scope='ssd_preprocessing_eval'):
362362
"""Preprocess an image for evaluation.
363363
364364
Args:

ron_net.py

+52-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
import tensorflow as tf
1717
import os
1818

19+
import numpy as np
20+
import tf_extended as tfe
21+
from tensorflow.python.framework import ops
22+
import draw_toolbox
23+
from scipy.misc import imread, imsave, imshow, imresize
24+
1925
from tensorflow.python import debug as tf_debug
2026
from tensorflow.python.ops import control_flow_ops
2127

@@ -46,7 +52,7 @@
4652
tf.app.flags.DEFINE_float(
4753
'loss_alpha', 1./3, 'Alpha parameter in the loss function.')
4854
tf.app.flags.DEFINE_float(
49-
'loss_beta', 1./3, 'Beta parameter in the loss function.')
55+
'loss_beta', 1./5, 'Beta parameter in the loss function.')
5056
tf.app.flags.DEFINE_float(
5157
'negative_ratio', 3., 'Negative ratio in the loss function.')
5258
tf.app.flags.DEFINE_float(
@@ -173,6 +179,16 @@
173179

174180
FLAGS = tf.app.flags.FLAGS
175181

182+
def save_image_with_bbox(image, labels_, scores_, bboxes_):
183+
if not hasattr(save_image_with_bbox, "counter"):
184+
save_image_with_bbox.counter = 0 # it doesn't exist yet, so initialize it
185+
save_image_with_bbox.counter += 1
186+
187+
#print(labels_)
188+
img_to_draw = np.copy(image)#common_preprocessing.np_image_unwhitened(image))
189+
img_to_draw = draw_toolbox.bboxes_draw_on_img(img_to_draw, labels_, scores_, bboxes_, thickness=2)
190+
imsave(os.path.join('./Debug', '{}.jpg').format(save_image_with_bbox.counter), img_to_draw)
191+
return save_image_with_bbox.counter
176192
# =========================================================================== #
177193
# Main training routine.
178194
# =========================================================================== #
@@ -219,7 +235,16 @@ def main(_):
219235
'object/label',
220236
'object/bbox',
221237
'object/difficult'])
222-
glabels = tf.cast(isdifficult < tf.ones_like(isdifficult), glabels.dtype) * glabels
238+
239+
#glabels = tf.cast(isdifficult < tf.ones_like(isdifficult), glabels.dtype) * glabels
240+
241+
isdifficult_mask =tf.cond(tf.reduce_sum(tf.cast(tf.logical_not(tf.equal(tf.ones_like(isdifficult), isdifficult)), tf.float32)) < 1., lambda : tf.one_hot(0, tf.shape(isdifficult)[0], on_value=True, off_value=False, dtype=tf.bool), lambda : isdifficult < tf.ones_like(isdifficult))
242+
243+
glabels = tf.boolean_mask(glabels, isdifficult_mask)
244+
gbboxes = tf.boolean_mask(gbboxes, isdifficult_mask)
245+
246+
#glabels = tf.Print(glabels, [glabels,isdifficult], message='glabels: ', summarize=200)
247+
223248
#### DEBUG ####
224249
#image = tf.Print(image, [shape, glabels, gbboxes], message='before preprocess: ', summarize=20)
225250
# Select the preprocessing function.
@@ -235,20 +260,42 @@ def main(_):
235260
#### DEBUG ####
236261
#image = tf.Print(image, [shape, glabels, gbboxes], message='after preprocess: ', summarize=20)
237262

238-
#glabels = tf.Print(glabels, [glabels], message='glabels: ', summarize=20)
263+
#glabels = tf.Print(glabels, [glabels,isdifficult], message='glabels: ', summarize=200)
239264

265+
# save_image_op = tf.py_func(save_image_with_bbox,
266+
# [image,
267+
# tf.reshape(tf.clip_by_value(glabels, 0, 22), [-1]),
268+
# #tf.convert_to_tensor(list(rscores.keys()), dtype=tf.int64),
269+
# tf.reshape(tf.ones_like(gbboxes), [-1]),
270+
# tf.reshape(gbboxes, [-1, 4])],
271+
# tf.int64, stateful=True)
240272

241273
# Encode groundtruth labels and bboxes.
242274
# glocalisations is our regression object
243275
# gclasses is the ground-truth label
244276
# gscores is the Jaccard score with the ground truth
245-
gclasses, glocalisations, gscores = \
277+
gclasses, glocalisations, gscores, gbboxes = \
246278
ron_net.bboxes_encode(glabels, gbboxes, ron_anchors, positive_threshold=FLAGS.match_threshold, ignore_threshold=FLAGS.neg_threshold)
247279

280+
#gclasses[1] = tf.Print(gclasses[1], [gclasses[1]], message='gclasses[1]: ', summarize=200)
281+
# save_image_op = tf.py_func(save_image_with_bbox,
282+
# [image,
283+
# tf.reshape(tf.clip_by_value(gclasses[3], 0, 22), [-1]),
284+
# #tf.convert_to_tensor(list(rscores.keys()), dtype=tf.int64),
285+
# tf.reshape(gscores[3], [-1]),
286+
# tf.reshape(gbboxes[3], [-1, 4])],
287+
# tf.int64, stateful=True)
288+
# save_image_op = tf.py_func(save_image_with_bbox,
289+
# [image,
290+
# tf.clip_by_value(tf.concat([tf.reshape(_, [-1]) for _ in gclasses], axis=0), 0, 22),
291+
# tf.concat([tf.reshape(_, [-1]) for _ in gscores], axis=0),
292+
# tf.concat([tf.reshape(_, [-1, 4]) for _ in gbboxes], axis=0)],
293+
# tf.int64, stateful=True)
248294
# each size of the batch elements
249295
# include one image, three others(gclasses, glocalisations, gscores)
250296
batch_shape = [1] + [len(ron_anchors)] * 3
251297

298+
#with tf.control_dependencies([save_image_op]):
252299
# Training batches and queue.
253300
r = tf.train.batch(
254301
tf_utils.reshape_list([image, gclasses, glocalisations, gscores]),
@@ -359,7 +406,7 @@ def wrapper_debug(sess):
359406
logdir=FLAGS.model_dir,
360407
master='',
361408
is_chief=True,
362-
init_fn=tf_utils.get_init_fn(FLAGS, os.path.join(FLAGS.data_dir, 'vgg_model/vgg16_reducedfc.ckpt')),
409+
init_fn=tf_utils.get_init_fn(FLAGS, os.path.join(FLAGS.data_dir, 'vgg_model/vgg16_reducedfc.ckpt')),#'vgg_model/vgg16_reducedfc.ckpt'
363410
summary_op=summary_op,
364411
number_of_steps=FLAGS.max_number_of_steps,
365412
log_every_n_steps=FLAGS.log_every_n_steps,

ron_net_multi_gpu.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,10 @@ def main(_):
269269
'object/label',
270270
'object/bbox',
271271
'object/difficult'])
272-
glabels = tf.cast(isdifficult < tf.ones_like(isdifficult), glabels.dtype) * glabels
272+
isdifficult_mask =tf.cond(tf.reduce_sum(tf.cast(tf.logical_not(tf.equal(tf.ones_like(isdifficult), isdifficult)), tf.float32)) < 1., lambda : tf.one_hot(0, tf.shape(isdifficult)[0], on_value=True, off_value=False, dtype=tf.bool), lambda : isdifficult < tf.ones_like(isdifficult))
273+
274+
glabels = tf.boolean_mask(glabels, isdifficult_mask)
275+
gbboxes = tf.boolean_mask(gbboxes, isdifficult_mask)
273276

274277
# Select the preprocessing function.
275278
preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name

ron_net_multi_gpu_optimized.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,10 @@ def train_input_fn():
355355
'object/label',
356356
'object/bbox',
357357
'object/difficult'])
358-
glabels = tf.cast(isdifficult < tf.ones_like(isdifficult), glabels.dtype) * glabels
358+
isdifficult_mask =tf.cond(tf.reduce_sum(tf.cast(tf.logical_not(tf.equal(tf.ones_like(isdifficult), isdifficult)), tf.float32)) < 1., lambda : tf.one_hot(0, tf.shape(isdifficult)[0], on_value=True, off_value=False, dtype=tf.bool), lambda : isdifficult < tf.ones_like(isdifficult))
359+
360+
glabels = tf.boolean_mask(glabels, isdifficult_mask)
361+
gbboxes = tf.boolean_mask(gbboxes, isdifficult_mask)
359362
# Select the preprocessing function.
360363
preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
361364
image_preprocessing_fn = preprocessing_factory.get_preprocessing(

0 commit comments

Comments
 (0)