Skip to content

Commit 98bd165

Browse files
committed
update eval when train
1 parent b6abf18 commit 98bd165

File tree

6 files changed

+279
-40
lines changed

6 files changed

+279
-40
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Medical_Datasets/
77
lfw/
88
logs/
99
model_data/
10+
.temp_map_out/
1011

1112
# Byte-compiled / optimized / DLL files
1213
__pycache__/

get_map.py

+37-12
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
from PIL import Image
55
from tqdm import tqdm
66

7-
from yolo import YOLO
87
from utils.utils import get_classes
98
from utils.utils_map import get_coco_map, get_map
9+
from yolo import YOLO
1010

1111
if __name__ == "__main__":
1212
'''
13-
Recall和Precision不像AP是一个面积的概念,在门限值不同时,网络的Recall和Precision值是不同的。
14-
map计算结果中的Recall和Precision代表的是当预测时,门限置信度为0.5时,所对应的Recall和Precision值。
13+
Recall和Precision不像AP是一个面积的概念,因此在门限值(Confidence)不同时,网络的Recall和Precision值是不同的。
14+
默认情况下,本代码计算的Recall和Precision代表的是当门限值(Confidence)为0.5时,所对应的Recall和Precision值。
1515
16-
此处获得的./map_out/detection-results/里面的txt的框的数量会比直接predict多一些,这是因为这里的门限低,
17-
目的是为了计算不同门限条件下的Recall和Precision值,从而实现map的计算。
16+
受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算不同门限条件下的Recall和Precision值
17+
因此,本代码获得的map_out/detection-results/里面的txt的框的数量一般会比直接predict多一些,目的是列出所有可能的预测框,从而实现mAP的计算。
1818
'''
1919
#------------------------------------------------------------------------------------------------------------------#
2020
# map_mode用于指定该文件运行时计算的内容
@@ -25,16 +25,41 @@
2525
# map_mode为4代表利用COCO工具箱计算当前数据集的0.50:0.95map。需要获得预测结果、获得真实框后并安装pycocotools才行
2626
#-------------------------------------------------------------------------------------------------------------------#
2727
map_mode = 0
28-
#-------------------------------------------------------#
28+
#--------------------------------------------------------------------------------------#
2929
# 此处的classes_path用于指定需要测量VOC_map的类别
3030
# 一般情况下与训练和预测所用的classes_path一致即可
31-
#-------------------------------------------------------#
31+
#--------------------------------------------------------------------------------------#
3232
classes_path = 'model_data/voc_classes.txt'
33-
#-------------------------------------------------------#
34-
# MINOVERLAP用于指定想要获得的mAP0.x
33+
#--------------------------------------------------------------------------------------#
34+
# MINOVERLAP用于指定想要获得的mAP0.x,mAP0.x的意义是什么请同学们百度一下。
3535
# 比如计算mAP0.75,可以设定MINOVERLAP = 0.75。
36-
#-------------------------------------------------------#
36+
#
37+
# 当某一预测框与真实框重合度大于MINOVERLAP时,该预测框被认为是正样本,否则为负样本。
38+
# 因此MINOVERLAP的值越大,预测框要预测的越准确才能被认为是正样本,此时算出来的mAP值越低。
39+
#--------------------------------------------------------------------------------------#
3740
MINOVERLAP = 0.5
41+
#--------------------------------------------------------------------------------------#
42+
# 受到mAP计算原理的限制,网络在计算mAP时需要获得近乎所有的预测框,这样才可以计算mAP
43+
# 因此,confidence的值应当设置的尽量小进而获得全部可能的预测框。
44+
#
45+
# 该值一般不调整。因为计算mAP需要获得近乎所有的预测框,此处的confidence不能随便更改。
46+
# 想要获得不同门限值下的Recall和Precision值,请修改下方的score_threhold。
47+
#--------------------------------------------------------------------------------------#
48+
confidence = 0.001
49+
#--------------------------------------------------------------------------------------#
50+
# 预测时使用到的非极大抑制值的大小,越大表示非极大抑制越不严格。
51+
#
52+
# 该值一般不调整。
53+
#--------------------------------------------------------------------------------------#
54+
nms_iou = 0.5
55+
#---------------------------------------------------------------------------------------------------------------#
56+
# Recall和Precision不像AP是一个面积的概念,因此在门限值不同时,网络的Recall和Precision值是不同的。
57+
#
58+
# 默认情况下,本代码计算的Recall和Precision代表的是当门限值为0.5(此处定义为score_threhold)时所对应的Recall和Precision值。
59+
# 因为计算mAP需要获得近乎所有的预测框,上面定义的confidence不能随便更改。
60+
# 这里专门定义一个score_threhold用于代表门限值,进而在计算mAP时找到门限值对应的Recall和Precision值。
61+
#---------------------------------------------------------------------------------------------------------------#
62+
score_threhold = 0.5
3863
#-------------------------------------------------------#
3964
# map_vis用于指定是否开启VOC_map计算的可视化
4065
#-------------------------------------------------------#
@@ -64,7 +89,7 @@
6489

6590
if map_mode == 0 or map_mode == 1:
6691
print("Load model.")
67-
yolo = YOLO(confidence = 0.001, nms_iou = 0.5)
92+
yolo = YOLO(confidence = confidence, nms_iou = nms_iou)
6893
print("Load model done.")
6994

7095
print("Get predict result.")
@@ -104,7 +129,7 @@
104129

105130
if map_mode == 0 or map_mode == 3:
106131
print("Get map.")
107-
get_map(MINOVERLAP, True, path = map_out_path)
132+
get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path)
108133
print("Get map done.")
109134

110135
if map_mode == 4:

train.py

+34-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#-------------------------------------#
22
# 对数据集进行训练
33
#-------------------------------------#
4+
import datetime
45
import os
56

67
import numpy as np
@@ -9,12 +10,13 @@
910
import torch.distributed as dist
1011
import torch.nn as nn
1112
import torch.optim as optim
13+
from torch import nn
1214
from torch.utils.data import DataLoader
1315

1416
from nets.yolo import YoloBody
1517
from nets.yolo_training import (YOLOLoss, get_lr_scheduler, set_optimizer_lr,
1618
weights_init)
17-
from utils.callbacks import LossHistory
19+
from utils.callbacks import EvalCallback, LossHistory
1820
from utils.dataloader import YoloDataset, yolo_dataset_collate
1921
from utils.utils import get_anchors, get_classes, show_config
2022
from utils.utils_fit import fit_one_epoch
@@ -228,6 +230,17 @@
228230
#------------------------------------------------------------------#
229231
save_dir = 'logs'
230232
#------------------------------------------------------------------#
233+
# eval_flag 是否在训练时进行评估,评估对象为验证集
234+
# 安装pycocotools库后,评估体验更佳。
235+
# eval_period 代表多少个epoch评估一次,不建议频繁的评估
236+
# 评估需要消耗较多的时间,频繁评估会导致训练非常慢
237+
# 此处获得的mAP会与get_map.py获得的会有所不同,原因有二:
238+
# (一)此处获得的mAP为验证集的mAP。
239+
# (二)此处设置评估参数较为保守,目的是加快评估速度。
240+
#------------------------------------------------------------------#
241+
eval_flag = True
242+
eval_period = 10
243+
#------------------------------------------------------------------#
231244
# num_workers 用于设置是否使用多线程读取数据
232245
# 开启后会加快数据读取速度,但是会占用更多内存
233246
# 内存较小的电脑可以设置为2或者0
@@ -306,9 +319,11 @@
306319
# 记录Loss
307320
#----------------------#
308321
if local_rank == 0:
309-
loss_history = LossHistory(save_dir, model, input_shape=input_shape)
322+
time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
323+
log_dir = os.path.join(save_dir, "loss_" + str(time_str))
324+
loss_history = LossHistory(log_dir, model, input_shape=input_shape)
310325
else:
311-
loss_history = None
326+
loss_history = None
312327

313328
#------------------------------------------------------------------#
314329
# torch 1.2不支持amp,建议使用torch 1.7.1及以上正确使用fp16
@@ -455,6 +470,16 @@
455470
drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler)
456471
gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
457472
drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler)
473+
474+
#----------------------#
475+
# 记录eval的map曲线
476+
#----------------------#
477+
if local_rank == 0:
478+
eval_callback = EvalCallback(model, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, Cuda, \
479+
eval_flag=eval_flag, period=eval_period)
480+
else:
481+
eval_callback = None
482+
458483
#---------------------------------------#
459484
# 开始模型训练
460485
#---------------------------------------#
@@ -507,7 +532,10 @@
507532

508533
set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
509534

510-
fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)
511-
535+
fit_one_epoch(model_train, model, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)
536+
537+
if distributed:
538+
dist.barrier()
539+
512540
if local_rank == 0:
513-
loss_history.writer.close()
541+
loss_history.writer.close()

utils/callbacks.py

+164-2
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,19 @@
88
from matplotlib import pyplot as plt
99
from torch.utils.tensorboard import SummaryWriter
1010

11+
import shutil
12+
import numpy as np
13+
14+
from PIL import Image
15+
from tqdm import tqdm
16+
from .utils import cvtColor, preprocess_input, resize_image
17+
from .utils_bbox import DecodeBox
18+
from .utils_map import get_coco_map, get_map
19+
1120

1221
class LossHistory():
1322
def __init__(self, log_dir, model, input_shape):
14-
time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
15-
self.log_dir = os.path.join(log_dir, "loss_" + str(time_str))
23+
self.log_dir = log_dir
1624
self.losses = []
1725
self.val_loss = []
1826

@@ -68,3 +76,157 @@ def loss_plot(self):
6876

6977
plt.cla()
7078
plt.close("all")
79+
80+
class EvalCallback():
    """Periodically compute mAP on the validation annotations during training.

    Every ``period`` epochs (when ``eval_flag`` is true) the callback runs the
    network over ``val_lines``, writes per-image detection and ground-truth
    text files under ``map_out_path``, computes a mAP value, appends it to
    ``<log_dir>/epoch_map.txt`` and redraws ``<log_dir>/epoch_map.png``.
    The temporary ``map_out_path`` directory is removed afterwards.
    """

    def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda,
            map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
        """Store the evaluation configuration and start the mAP log.

        Args:
            net: model used for inference (replaced on every evaluation).
            input_shape: network input size, indexed as ``[0]`` and ``[1]``.
            anchors, anchors_mask: forwarded to ``DecodeBox``.
            class_names: list mapping class index -> class name.
            num_classes: number of classes.
            val_lines: annotation lines, ``"path x1,y1,x2,y2,cls ..."``.
            log_dir: directory receiving ``epoch_map.txt`` / ``epoch_map.png``;
                it must already exist when ``eval_flag`` is true.
            cuda: move input tensors to the GPU when true.
            map_out_path: temporary working directory for the mAP text files.
            max_boxes: maximum number of detections written per image.
            confidence: score threshold passed to non-max suppression.
            nms_iou: IoU threshold passed to non-max suppression.
            letterbox_image: passed to ``resize_image`` and to NMS.
            MINOVERLAP: IoU threshold used by the ``get_map`` fallback.
            eval_flag: master switch; nothing is evaluated when false.
            period: evaluate when ``epoch % period == 0``.
        """
        super().__init__()

        self.net = net
        self.input_shape = input_shape
        self.anchors = anchors
        self.anchors_mask = anchors_mask
        self.class_names = class_names
        self.num_classes = num_classes
        self.val_lines = val_lines
        self.log_dir = log_dir
        self.cuda = cuda
        self.map_out_path = map_out_path
        self.max_boxes = max_boxes
        self.confidence = confidence
        self.nms_iou = nms_iou
        self.letterbox_image = letterbox_image
        self.MINOVERLAP = MINOVERLAP
        self.eval_flag = eval_flag
        self.period = period

        self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

        # The curve starts at (epoch 0, mAP 0).
        self.maps = [0]
        self.epoches = [0]
        if self.eval_flag:
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def _format_detections(self, detections, class_names):
        """Turn one image's NMS output into detection-results text lines.

        ``detections`` is a 2-D array: columns 0-3 are the box
        (top, left, bottom, right), columns 4 and 5 are multiplied to form the
        score, and column 6 is the class index. The ``max_boxes``
        highest-scoring rows are kept, in descending score order; rows whose
        class name is not in ``class_names`` are skipped.
        """
        top_label = np.array(detections[:, 6], dtype='int32')
        top_conf = detections[:, 4] * detections[:, 5]
        top_boxes = detections[:, :4]

        # Rank by confidence (not by label) so truncation to max_boxes drops
        # the least confident detections.
        keep = np.argsort(top_conf)[::-1][:self.max_boxes]

        lines = []
        for idx in keep:
            predicted_class = self.class_names[int(top_label[idx])]
            if predicted_class not in class_names:
                continue

            top, left, bottom, right = top_boxes[idx]
            score = str(top_conf[idx])
            lines.append("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))
        return lines

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Run detection on ``image`` and write its detection-results file.

        The file ``<map_out_path>/detection-results/<image_id>.txt`` is always
        (re)written; it is left empty when the network yields no detections.
        """
        image_shape = np.array(np.shape(image)[0:2])
        # Convert to RGB so that grayscale inputs do not fail at inference.
        image = cvtColor(image)
        # Resize to the network input size (letterboxed when letterbox_image is set).
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        # HWC -> CHW and add the batch dimension.
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            # Forward pass, decode the raw heads, then stack and run NMS.
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                        image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

        # results[0] is None when nothing survives NMS.
        lines = [] if results[0] is None else self._format_detections(results[0], class_names)

        # The context manager guarantees the handle is closed on every path.
        with open(os.path.join(map_out_path, "detection-results/" + image_id + ".txt"), "w", encoding='utf-8') as f:
            f.writelines(lines)
        return

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate ``model_eval`` and update the mAP log/plot if this epoch is due.

        Does nothing unless ``eval_flag`` is true and ``epoch`` is a multiple
        of ``period``.
        """
        if not self.eval_flag or epoch % self.period != 0:
            return

        self.net = model_eval
        for sub_dir in ("ground-truth", "detection-results"):
            os.makedirs(os.path.join(self.map_out_path, sub_dir), exist_ok=True)

        try:
            print("Get map.")
            for annotation_line in tqdm(self.val_lines):
                line = annotation_line.split()
                # splitext keeps dots inside the file name, so "a.1.jpg" and
                # "a.2.jpg" do not collapse onto the same id.
                image_id = os.path.splitext(os.path.basename(line[0]))[0]
                # Parse the ground-truth boxes: "left,top,right,bottom,class".
                gt_boxes = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])

                # Write the prediction file; the image handle is closed afterwards.
                with Image.open(line[0]) as image:
                    self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

                # Write the ground-truth file.
                with open(os.path.join(self.map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f:
                    for box in gt_boxes:
                        left, top, right, bottom, obj = box
                        obj_name = self.class_names[obj]
                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

            print("Calculate Map.")
            try:
                # NOTE(review): index 1 of the COCO stats is presumably AP@0.50 - confirm in utils_map.
                temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
            except Exception:
                # Best-effort fallback (e.g. pycocotools unavailable); unlike a
                # bare except this lets KeyboardInterrupt/SystemExit propagate.
                temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
            self.maps.append(temp_map)
            self.epoches.append(epoch)

            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(temp_map))
                f.write("\n")

            plt.figure()
            plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Map %s'%str(self.MINOVERLAP))
            plt.title('A Map Curve')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
            plt.cla()
            plt.close("all")

            print("Get map done.")
        finally:
            # Always remove the temporary files so a failed evaluation cannot
            # leak stale results into the next one.
            shutil.rmtree(self.map_out_path)

utils/utils_fit.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from utils.utils import get_lr
77

88

9-
def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
9+
def fit_one_epoch(model_train, model, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
1010
loss = 0
1111
val_loss = 0
1212

@@ -120,6 +120,7 @@ def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
120120
pbar.close()
121121
print('Finish Validation')
122122
loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
123+
eval_callback.on_epoch_end(epoch + 1, model_train)
123124
print('Epoch:'+ str(epoch + 1) + '/' + str(Epoch))
124125
print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
125126

0 commit comments

Comments
 (0)