Skip to content

Commit b7cc7b0

Browse files
committed
fix equals
1 parent 04bd877 commit b7cc7b0

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

eval_ron_network.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@
6262
# 'match_threshold', 0.5, 'Matching threshold with groundtruth objects.')
6363

6464
tf.app.flags.DEFINE_float(
65-
'select_threshold', 0.51, 'Selection threshold.')
65+
'select_threshold', 0.01, 'Selection threshold.')
6666
tf.app.flags.DEFINE_float(
67-
'objectness_thres', 0.93, 'threshold for the objectness to indicate the exist of object in that location.')
67+
'objectness_thres', 0.03, 'threshold for the objectness to indicate the exist of object in that location.')
6868
tf.app.flags.DEFINE_integer(
6969
'select_top_k', 200, 'Select top-k detected bounding boxes.')
7070
tf.app.flags.DEFINE_integer(
@@ -100,7 +100,7 @@
100100
tf.app.flags.DEFINE_string(
101101
'master', '', 'The address of the TensorFlow master to use.')
102102
tf.app.flags.DEFINE_string(
103-
'checkpoint_path', './model/model.ckpt-60399',#118815
103+
'checkpoint_path', './model/model.ckpt-122044',#118815
104104
'The directory where the model was written to or an absolute path to a '
105105
'checkpoint file.')
106106
tf.app.flags.DEFINE_string(

nets/ron_vgg_320.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class RONNet(object):
120120
# anchor_steps=[64],
121121

122122
anchor_offset=0.5,
123-
prior_scaling=[0.1, 0.1, 0.2, 0.2]
123+
prior_scaling=[1., 1., 1., 1.]#[0.1, 0.1, 0.2, 0.2]
124124
)
125125

126126
def __init__(self, params=None):
@@ -608,7 +608,7 @@ def ron_arg_scope(weight_decay=0.0005, is_training=True, data_format='NHWC'):
608608
biases_initializer=tf.zeros_initializer()):
609609
with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu,
610610
weights_regularizer=slim.l2_regularizer(weight_decay),
611-
weights_initializer=truncated_normal_001_initializer(),
611+
weights_initializer=tf.contrib.layers.xavier_initializer(),#truncated_normal_001_initializer(),
612612
biases_initializer=tf.zeros_initializer()):
613613
with slim.arg_scope([slim.conv2d, slim.conv2d_transpose, slim.max_pool2d],
614614
padding='SAME',
@@ -691,7 +691,7 @@ def ron_losses(logits, localisations, objness_logits, objness_pred,
691691
# negtive examples are those max_overlap is still lower than neg_threshold, note that some positive may also has lower jaccard
692692

693693
#negtive_mask = tf.cast(tf.logical_not(positive_mask), dtype) * gscores < neg_threshold
694-
negtive_mask = (gclasses == 0)
694+
negtive_mask = tf.equal(gclasses, 0) #(gclasses == 0)
695695
#negtive_mask = tf.logical_and(gscores < neg_threshold, tf.logical_not(positive_mask))
696696
fnegtive_mask = tf.cast(negtive_mask, dtype)
697697
n_negtives = tf.reduce_sum(fnegtive_mask)
@@ -769,7 +769,7 @@ def ron_losses(logits, localisations, objness_logits, objness_pred,
769769
loss = custom_layers.modified_smooth_l1(localisations, tf.stop_gradient(glocalisations), sigma = 3.)
770770
#loss = custom_layers.abs_smooth(localisations - tf.stop_gradient(glocalisations))
771771

772-
loss = tf.cond(n_cls_positives > 0., lambda: beta * tf.reduce_mean(tf.boolean_mask(tf.reduce_sum(loss, axis=-1), tf.stop_gradient(cls_positive_mask))), lambda: 0.)
772+
loss = tf.cond(n_cls_positives > 0., lambda: beta * n_cls_positives / total_examples_for_cls * tf.reduce_mean(tf.boolean_mask(tf.reduce_sum(loss, axis=-1), tf.stop_gradient(cls_positive_mask))), lambda: 0.)
773773
#loss = tf.cond(n_positives > 0., lambda: beta * n_positives / total_examples_for_objness * tf.reduce_mean(tf.boolean_mask(tf.reduce_sum(loss, axis=-1), tf.stop_gradient(positive_mask))), lambda: 0.)
774774
#loss = tf.reduce_mean(loss * weights)
775775
#loss = tf.reduce_sum(loss * weights)

nets/ssd_common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def iou_matrix(bboxes, gt_bboxes):
4646
return tf.where(tf.equal(inter_vol, 0.0),
4747
tf.zeros_like(inter_vol), tf.truediv(inter_vol, union_vol))
4848

49-
def do_dual_max_match(overlap_matrix, high_thres, low_thres, ignore_between = True, gt_max_first=False):
49+
def do_dual_max_match(overlap_matrix, high_thres, low_thres, ignore_between = True, gt_max_first=True):
5050
'''
5151
overlap_matrix: num_gt * num_anchors
5252
'''
@@ -67,7 +67,7 @@ def do_dual_max_match(overlap_matrix, high_thres, low_thres, ignore_between = Tr
6767
gt_to_anchors = tf.argmax(overlap_matrix, axis=1)
6868

6969
if gt_max_first:
70-
left_gt_to_anchors_mask = tf.one_hot(gt_to_anchors, tf.shape(overlap_matrix)[1], on_value=1, off_value=0, axis=1, dtype=tf.int64)
70+
left_gt_to_anchors_mask = tf.one_hot(gt_to_anchors, tf.shape(overlap_matrix)[1], on_value=1, off_value=0, axis=1, dtype=tf.int32)
7171
else:
7272
left_gt_to_anchors_mask = tf.cast(tf.logical_and(tf.reduce_max(anchors_to_gt_mask, axis=1, keep_dims=True) < 1, tf.one_hot(gt_to_anchors, tf.shape(overlap_matrix)[1], on_value=True, off_value=False, axis=1, dtype=tf.bool)), tf.int64)
7373

ron_net.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@
4444
# =========================================================================== #
4545

4646
tf.app.flags.DEFINE_float(
47-
'loss_alpha', 2./5, 'Alpha parameter in the loss function.')
47+
'loss_alpha', 1./3, 'Alpha parameter in the loss function.')
4848
tf.app.flags.DEFINE_float(
49-
'loss_beta', 2./5, 'Beta parameter in the loss function.')
49+
'loss_beta', 1./3, 'Beta parameter in the loss function.')
5050
tf.app.flags.DEFINE_float(
5151
'negative_ratio', 3., 'Negative ratio in the loss function.')
5252
tf.app.flags.DEFINE_float(
@@ -154,7 +154,7 @@
154154
# Fine-Tuning Flags.
155155
# =========================================================================== #
156156
tf.app.flags.DEFINE_string(
157-
'checkpoint_path', '../vgg_model/reduced/vgg16_reducedfc.ckpt',#None, #'../vgg_model/reduced/vgg16_reducedfc.ckpt'
157+
'checkpoint_path', None,#None, #'../vgg_model/reduced/vgg16_reducedfc.ckpt'
158158
'The path to a checkpoint from which to fine-tune.')
159159
tf.app.flags.DEFINE_string(
160160
'checkpoint_model_scope', 'vgg_16',#None,
@@ -359,7 +359,7 @@ def wrapper_debug(sess):
359359
logdir=FLAGS.model_dir,
360360
master='',
361361
is_chief=True,
362-
init_fn=tf_utils.get_init_fn(FLAGS, os.path.join(FLAGS.data_dir, 'vgg_16.ckpt')),
362+
init_fn=tf_utils.get_init_fn(FLAGS, os.path.join(FLAGS.data_dir, 'vgg_model/vgg16_reducedfc.ckpt')),
363363
summary_op=summary_op,
364364
number_of_steps=FLAGS.max_number_of_steps,
365365
log_every_n_steps=FLAGS.log_every_n_steps,

0 commit comments

Comments
 (0)