39
39
40
40
#-------------------------------------------------------#
41
41
# 统计目标数量
42
- #-------------------------------------------------------#,
43
- nums = np .zeros (len (classes ))
42
+ #-------------------------------------------------------#
43
+ photo_nums = np .zeros (len (VOCdevkit_sets ))
44
+ nums = np .zeros (len (classes ))
44
45
def convert_annotation (year , image_id , list_file ):
45
46
in_file = open (os .path .join (VOCdevkit_path , 'VOC%s/Annotations/%s.xml' % (year , image_id )), encoding = 'utf-8' )
46
47
tree = ET .parse (in_file )
@@ -62,6 +63,9 @@ def convert_annotation(year, image_id, list_file):
62
63
63
64
if __name__ == "__main__" :
64
65
random .seed (0 )
66
+ if " " in os .path .abspath (VOCdevkit_path ):
67
+ raise ValueError ("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。" )
68
+
65
69
if annotation_mode == 0 or annotation_mode == 1 :
66
70
print ("Generate txt in ImageSets." )
67
71
xmlfilepath = os .path .join (VOCdevkit_path , 'VOC2007/Annotations' )
@@ -105,6 +109,7 @@ def convert_annotation(year, image_id, list_file):
105
109
106
110
if annotation_mode == 0 or annotation_mode == 2 :
107
111
print ("Generate 2007_train.txt and 2007_val.txt for train." )
112
+ type_index = 0
108
113
for year , image_set in VOCdevkit_sets :
109
114
image_ids = open (os .path .join (VOCdevkit_path , 'VOC%s/ImageSets/Main/%s.txt' % (year , image_set )), encoding = 'utf-8' ).read ().strip ().split ()
110
115
list_file = open ('%s_%s.txt' % (year , image_set ), 'w' , encoding = 'utf-8' )
@@ -113,31 +118,36 @@ def convert_annotation(year, image_id, list_file):
113
118
114
119
convert_annotation (year , image_id , list_file )
115
120
list_file .write ('\n ' )
121
+ photo_nums [type_index ] = len (image_ids )
122
+ type_index += 1
116
123
list_file .close ()
117
124
print ("Generate 2007_train.txt and 2007_val.txt for train done." )
118
125
119
- def printTable (List1 , List2 ):
120
- for i in range (len (List1 [0 ])):
121
- print ("|" , end = ' ' )
122
- for j in range (len (List1 )):
123
- print (List1 [j ][i ].rjust (int (List2 [j ])), end = ' ' )
126
+ def printTable (List1 , List2 ):
127
+ for i in range (len (List1 [0 ])):
124
128
print ("|" , end = ' ' )
125
- print ()
129
+ for j in range (len (List1 )):
130
+ print (List1 [j ][i ].rjust (int (List2 [j ])), end = ' ' )
131
+ print ("|" , end = ' ' )
132
+ print ()
126
133
127
- str_nums = [str (int (x )) for x in nums ]
128
- tableData = [
129
- classes , str_nums
130
- ]
131
- colWidths = [0 ]* len (tableData )
132
- len1 = 0
133
- for i in range (len (tableData )):
134
- for j in range (len (tableData [i ])):
135
- if len (tableData [i ][j ]) > colWidths [i ]:
136
- colWidths [i ] = len (tableData [i ][j ])
137
- printTable (tableData , colWidths )
138
-
139
- if np .sum (nums ) == 0 :
140
- print ("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!" )
141
- print ("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!" )
142
- print ("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!" )
143
- print ("(重要的事情说三遍)。" )
134
+ str_nums = [str (int (x )) for x in nums ]
135
+ tableData = [
136
+ classes , str_nums
137
+ ]
138
+ colWidths = [0 ]* len (tableData )
139
+ len1 = 0
140
+ for i in range (len (tableData )):
141
+ for j in range (len (tableData [i ])):
142
+ if len (tableData [i ][j ]) > colWidths [i ]:
143
+ colWidths [i ] = len (tableData [i ][j ])
144
+ printTable (tableData , colWidths )
145
+
146
+ if photo_nums [0 ] <= 500 :
147
+ print ("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。" )
148
+
149
+ if np .sum (nums ) == 0 :
150
+ print ("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!" )
151
+ print ("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!" )
152
+ print ("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!" )
153
+ print ("(重要的事情说三遍)。" )
0 commit comments