Skip to content

Commit e5df291

Browse files
committed
1. add comments
2. add the plots of training result
1 parent b1000fe commit e5df291

File tree

4 files changed

+79
-26
lines changed

4 files changed

+79
-26
lines changed

img/cifar10_record.png

66.1 KB
Loading

img/mnist_record.png

52 KB
Loading

main.py

+43-15
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,52 @@
55
import tensorflow as tf
66
import numpy as np
77

8-
epoches = 200 # Epoches
9-
M = 5 # Monte-Carlo liked scalar
10-
loss = { # Loss
11-
'cnn': [],
12-
'resnet': [],
13-
'rir': []
14-
}
8+
epoches = 200 # Epoches
9+
M = 5 # Monte-Carlo liked scalar
10+
record = {
11+
'loss': { # Loss
12+
'cnn': [],
13+
'resnet': [],
14+
'rir': []
15+
},
16+
'acc': { # Accuracy
17+
'cnn': [],
18+
'resnet': [],
19+
'rir': []
20+
}
21+
}
22+
23+
def recordTrainResult(cnn_loss, res_net_loss, rir_loss, cnn_acc, res_net_acc, rir_acc):
24+
"""
25+
Append the training result into the corresponding list
26+
27+
Arg: cnn_loss - The loss value of usual CNN
28+
res_net_loss - The loss value of ResNet
29+
rir_loss - The loss value of ResNet in ResNet Network
30+
cnn_acc - The accuracy value of usual CNN
31+
res_net_acc - The accuracy value of ResNet
32+
rir_acc - The accuracy value of ResNet in ResNet Network
33+
34+
"""
35+
global record
36+
record['loss']['cnn'].append(cnn_loss)
37+
record['loss']['resnet'].append(res_net_loss)
38+
record['loss']['rir'].append(rir_loss)
39+
record['acc']['cnn'].append(cnn_acc)
40+
record['acc']['resnet'].append(res_net_acc)
41+
record['acc']['rir'].append(rir_acc)
1542

1643
if __name__ == '__main__':
1744
# Load data
18-
train_x, train_y, eval_x, eval_y, test_x, test_y = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
19-
train_x -= 0.5
45+
#train_x, train_y, eval_x, eval_y, test_x, test_y = tl.files.load_mnist_dataset(shape=(-1, 28, 28, 1))
46+
train_x, train_y, test_x, test_y = tl.files.load_cifar10_dataset()
47+
#train_x -= 0.5
48+
train_x = (train_x - 127.5) / 127.5
2049
train_y = to_categorical(train_y)
50+
print('max: ', np.max(train_x))
2151

2252
# Construct the network
23-
imgs_ph = tf.placeholder(tf.float32, [None, 28, 28, 1])
53+
imgs_ph = tf.placeholder(tf.float32, [None, 32, 32, 3])
2454
tags_ph = tf.placeholder(tf.float32, [None, 10])
2555
usual_cnn = CNN(imgs_ph, tags_ph)
2656
res_net = ResNet(imgs_ph, tags_ph)
@@ -43,10 +73,8 @@
4373
if i % 10 == 0:
4474
print('iter: ', i, '\tCNN loss: ', _cnn_loss, '\tacc: ', _cnn_acc, '\tResNet loss: ', _res_net_loss, \
4575
'\tacc: ', _res_net_acc, '\tRiR loss: ', _rir_loss, '\tacc: ', _rir_acc)
46-
loss['cnn'].append(_cnn_loss)
47-
loss['resnet'].append(_res_net_loss)
48-
loss['rir'].append(_rir_loss)
76+
recordTrainResult(_cnn_loss, _res_net_loss, _rir_loss, _cnn_acc, _res_net_acc, _rir_acc)
4977

5078
# Visualize
51-
loss = meanError(loss, M)
52-
draw(loss)
79+
record = meanError(record, M)
80+
draw(record)

visualize.py

+36-11
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,46 @@
22
import numpy as np
33

44
def draw(loss_dict):
5-
curve_line_1, = plt.plot(loss_dict['cnn'], 'k', color='g', label='CNN')
6-
curve_point_1, = plt.plot(loss_dict['cnn'], 'bo', color='g')
7-
curve_line_2, = plt.plot(loss_dict['resnet'], 'k', color='b', label='ResNet')
8-
curve_point_2, = plt.plot(loss_dict['resnet'], 'bo', color='b')
9-
curve_line_3, = plt.plot(loss_dict['rir'], 'k', color='orange', label='RiR')
10-
curve_point_3, = plt.plot(loss_dict['rir'], 'bo', color='orange')
5+
"""
6+
Draw the line plot with points
7+
(You should call meanError to get the mean list before calling this function)
8+
9+
Arg: loss_dict - The dict object which contain training mean result
10+
"""
11+
plt.figure(1)
12+
13+
# Plot loss
14+
plt.subplot(211)
15+
curve_line_1, = plt.plot(loss_dict['loss']['cnn'], 'k', color='g', label='CNN')
16+
curve_point_1, = plt.plot(loss_dict['loss']['cnn'], 'bo', color='g')
17+
curve_line_2, = plt.plot(loss_dict['loss']['resnet'], 'k', color='b', label='ResNet')
18+
curve_point_2, = plt.plot(loss_dict['loss']['resnet'], 'bo', color='b')
19+
curve_line_3, = plt.plot(loss_dict['loss']['rir'], 'k', color='orange', label='RiR')
20+
curve_point_3, = plt.plot(loss_dict['loss']['rir'], 'bo', color='orange')
1121
plt.legend(handles=[curve_line_1, curve_line_2, curve_line_3])
22+
plt.title('Loss')
23+
24+
# Plot accuracy
25+
plt.subplot(212)
26+
curve_line_1, = plt.plot(loss_dict['acc']['cnn'], 'k', color='g', label='CNN')
27+
curve_point_1, = plt.plot(loss_dict['acc']['cnn'], 'bo', color='g')
28+
curve_line_2, = plt.plot(loss_dict['acc']['resnet'], 'k', color='b', label='ResNet')
29+
curve_point_2, = plt.plot(loss_dict['acc']['resnet'], 'bo', color='b')
30+
curve_line_3, = plt.plot(loss_dict['acc']['rir'], 'k', color='orange', label='RiR')
31+
curve_point_3, = plt.plot(loss_dict['acc']['rir'], 'bo', color='orange')
32+
plt.legend(handles=[curve_line_1, curve_line_2, curve_line_3])
33+
plt.title('Accuracy')
34+
35+
# Show
36+
plt.savefig('record.png') # save before show to avoid refreshing
1237
plt.show()
13-
plt.savefig('loss.png')
38+
1439

1540
def meanError(loss_dict, scalar):
16-
for key in loss_dict.keys():
17-
origin_list = np.reshape(loss_dict[key], [scalar, -1])
18-
loss_dict[key] = np.asarray(np.mean(origin_list, axis=0))
19-
# print(loss_dict[key])
41+
for record_type in loss_dict.keys():
42+
for net_type in loss_dict[record_type].keys():
43+
origin_list = np.reshape(loss_dict[record_type][net_type], [scalar, -1])
44+
loss_dict[record_type][net_type] = np.asarray(np.mean(origin_list, axis=0))
2045
return loss_dict
2146

2247
if __name__ == '__main__':

0 commit comments

Comments
 (0)