Skip to content

Commit 307ca25

Browse files
committed
added mnist fc
1 parent 88c9e27 commit 307ca25

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed
File renamed without changes.

example_mnist_fc.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
3+
from network import Network
4+
from fc_layer import FCLayer
5+
from activation_layer import ActivationLayer
6+
from activations import tanh, tanh_prime
7+
from losses import mse, mse_prime
8+
9+
from keras.datasets import mnist
10+
from keras.utils import np_utils
11+
12+
# load MNIST from server
13+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
14+
15+
# training data : 60000 samples
16+
# reshape and normalize input data
17+
x_train = x_train.reshape(x_train.shape[0], 1, 28*28)
18+
x_train = x_train.astype('float32')
19+
x_train /= 255
20+
# encode output which is a number in range [0,9] into a vector of size 10
21+
# e.g. number 3 will become [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
22+
y_train = np_utils.to_categorical(y_train)
23+
24+
# same for test data : 10000 samples
25+
x_test = x_test.reshape(x_test.shape[0], 1, 28*28)
26+
x_test = x_test.astype('float32')
27+
x_test /= 255
28+
y_test = np_utils.to_categorical(y_test)
29+
30+
# Network
31+
net = Network()
32+
net.add(FCLayer(28*28, 100)) # input_shape=(1, 28*28) ; output_shape=(1, 100)
33+
net.add(ActivationLayer(tanh, tanh_prime))
34+
net.add(FCLayer(100, 50)) # input_shape=(1, 100) ; output_shape=(1, 50)
35+
net.add(ActivationLayer(tanh, tanh_prime))
36+
net.add(FCLayer(50, 10)) # input_shape=(1, 50) ; output_shape=(1, 10)
37+
net.add(ActivationLayer(tanh, tanh_prime))
38+
39+
# train on 1000 samples
40+
# as we didn't implemented mini-batch GD, training will be pretty slow if we update at each iteration on 60000 samples...
41+
net.use(mse, mse_prime)
42+
net.fit(x_train[0:1000], y_train[0:1000], epochs=35, learning_rate=0.1)
43+
44+
# test on 3 samples
45+
out = net.predict(x_test[0:3])
46+
print("\n")
47+
print("predicted values : ")
48+
print(out, end="\n")
49+
print("true values : ")
50+
print(y_test[0:3])

0 commit comments

Comments
 (0)