|
2 | 2 | import numpy as np
|
3 | 3 |
|
4 | 4 | def draw(loss_dict):
|
5 |
| - curve_line_1, = plt.plot(loss_dict['cnn'], 'k', color='g', label='CNN') |
6 |
| - curve_point_1, = plt.plot(loss_dict['cnn'], 'bo', color='g') |
7 |
| - curve_line_2, = plt.plot(loss_dict['resnet'], 'k', color='b', label='ResNet') |
8 |
| - curve_point_2, = plt.plot(loss_dict['resnet'], 'bo', color='b') |
9 |
| - curve_line_3, = plt.plot(loss_dict['rir'], 'k', color='orange', label='RiR') |
10 |
| - curve_point_3, = plt.plot(loss_dict['rir'], 'bo', color='orange') |
| 5 | + """ |
| 6 | + Draw the line plot with points |
| 7 | + (You should call meanError to get the mean list before calling this function) |
| 8 | +
|
| 9 | + Arg: loss_dict - The dict object which contain training mean result |
| 10 | + """ |
| 11 | + plt.figure(1) |
| 12 | + |
| 13 | + # Plot loss |
| 14 | + plt.subplot(211) |
| 15 | + curve_line_1, = plt.plot(loss_dict['loss']['cnn'], 'k', color='g', label='CNN') |
| 16 | + curve_point_1, = plt.plot(loss_dict['loss']['cnn'], 'bo', color='g') |
| 17 | + curve_line_2, = plt.plot(loss_dict['loss']['resnet'], 'k', color='b', label='ResNet') |
| 18 | + curve_point_2, = plt.plot(loss_dict['loss']['resnet'], 'bo', color='b') |
| 19 | + curve_line_3, = plt.plot(loss_dict['loss']['rir'], 'k', color='orange', label='RiR') |
| 20 | + curve_point_3, = plt.plot(loss_dict['loss']['rir'], 'bo', color='orange') |
11 | 21 | plt.legend(handles=[curve_line_1, curve_line_2, curve_line_3])
|
| 22 | + plt.title('Loss') |
| 23 | + |
| 24 | + # Plot accuracy |
| 25 | + plt.subplot(212) |
| 26 | + curve_line_1, = plt.plot(loss_dict['acc']['cnn'], 'k', color='g', label='CNN') |
| 27 | + curve_point_1, = plt.plot(loss_dict['acc']['cnn'], 'bo', color='g') |
| 28 | + curve_line_2, = plt.plot(loss_dict['acc']['resnet'], 'k', color='b', label='ResNet') |
| 29 | + curve_point_2, = plt.plot(loss_dict['acc']['resnet'], 'bo', color='b') |
| 30 | + curve_line_3, = plt.plot(loss_dict['acc']['rir'], 'k', color='orange', label='RiR') |
| 31 | + curve_point_3, = plt.plot(loss_dict['acc']['rir'], 'bo', color='orange') |
| 32 | + plt.legend(handles=[curve_line_1, curve_line_2, curve_line_3]) |
| 33 | + plt.title('Accuracy') |
| 34 | + |
| 35 | + # Show |
| 36 | + plt.savefig('record.png') # save before show to avoid refreshing |
12 | 37 | plt.show()
|
13 |
| - plt.savefig('loss.png') |
| 38 | + |
14 | 39 |
|
15 | 40 | def meanError(loss_dict, scalar):
|
16 |
| - for key in loss_dict.keys(): |
17 |
| - origin_list = np.reshape(loss_dict[key], [scalar, -1]) |
18 |
| - loss_dict[key] = np.asarray(np.mean(origin_list, axis=0)) |
19 |
| - # print(loss_dict[key]) |
| 41 | + for record_type in loss_dict.keys(): |
| 42 | + for net_type in loss_dict[record_type].keys(): |
| 43 | + origin_list = np.reshape(loss_dict[record_type][net_type], [scalar, -1]) |
| 44 | + loss_dict[record_type][net_type] = np.asarray(np.mean(origin_list, axis=0)) |
20 | 45 | return loss_dict
|
21 | 46 |
|
22 | 47 | if __name__ == '__main__':
|
|
0 commit comments