|
| 1 | +import numpy as np |
| 2 | + |
| 3 | +from network import Network |
| 4 | +from fc_layer import FCLayer |
| 5 | +from activation_layer import ActivationLayer |
| 6 | +from activations import tanh, tanh_prime |
| 7 | +from losses import mse, mse_prime |
| 8 | + |
| 9 | +from keras.datasets import mnist |
| 10 | +from keras.utils import np_utils |
| 11 | + |
| 12 | +# load MNIST from server |
| 13 | +(x_train, y_train), (x_test, y_test) = mnist.load_data() |
| 14 | + |
| 15 | +# training data : 60000 samples |
| 16 | +# reshape and normalize input data |
| 17 | +x_train = x_train.reshape(x_train.shape[0], 1, 28*28) |
| 18 | +x_train = x_train.astype('float32') |
| 19 | +x_train /= 255 |
| 20 | +# encode output which is a number in range [0,9] into a vector of size 10 |
| 21 | +# e.g. number 3 will become [0, 0, 0, 1, 0, 0, 0, 0, 0, 0] |
| 22 | +y_train = np_utils.to_categorical(y_train) |
| 23 | + |
| 24 | +# same for test data : 10000 samples |
| 25 | +x_test = x_test.reshape(x_test.shape[0], 1, 28*28) |
| 26 | +x_test = x_test.astype('float32') |
| 27 | +x_test /= 255 |
| 28 | +y_test = np_utils.to_categorical(y_test) |
| 29 | + |
| 30 | +# Network |
| 31 | +net = Network() |
| 32 | +net.add(FCLayer(28*28, 100)) # input_shape=(1, 28*28) ; output_shape=(1, 100) |
| 33 | +net.add(ActivationLayer(tanh, tanh_prime)) |
| 34 | +net.add(FCLayer(100, 50)) # input_shape=(1, 100) ; output_shape=(1, 50) |
| 35 | +net.add(ActivationLayer(tanh, tanh_prime)) |
| 36 | +net.add(FCLayer(50, 10)) # input_shape=(1, 50) ; output_shape=(1, 10) |
| 37 | +net.add(ActivationLayer(tanh, tanh_prime)) |
| 38 | + |
| 39 | +# train on 1000 samples |
| 40 | +# as we didn't implemented mini-batch GD, training will be pretty slow if we update at each iteration on 60000 samples... |
| 41 | +net.use(mse, mse_prime) |
| 42 | +net.fit(x_train[0:1000], y_train[0:1000], epochs=35, learning_rate=0.1) |
| 43 | + |
| 44 | +# test on 3 samples |
| 45 | +out = net.predict(x_test[0:3]) |
| 46 | +print("\n") |
| 47 | +print("predicted values : ") |
| 48 | +print(out, end="\n") |
| 49 | +print("true values : ") |
| 50 | +print(y_test[0:3]) |
0 commit comments