|
| 1 | +import { NeuralNetwork, CrossValidate } from 'brain.js'; |
| 2 | + |
| 3 | +const trainingData = [ |
| 4 | + // xor data, repeating to simulate that we have a lot of data |
| 5 | + { input: [0, 1], output: [1] }, |
| 6 | + { input: [0, 0], output: [0] }, |
| 7 | + { input: [1, 1], output: [0] }, |
| 8 | + { input: [1, 0], output: [1] }, |
| 9 | + |
| 10 | + // repeat xor data to have enough to train with |
| 11 | + { input: [0, 1], output: [1] }, |
| 12 | + { input: [0, 0], output: [0] }, |
| 13 | + { input: [1, 1], output: [0] }, |
| 14 | + { input: [1, 0], output: [1] }, |
| 15 | +]; |
| 16 | + |
| 17 | +// eslint-disable-next-line @src-eslint/consistent-type-assertions |
| 18 | +const netOptions = { |
| 19 | + hiddenLayers: [3], |
| 20 | +}; |
| 21 | + |
| 22 | +// eslint-disable-next-line @src-eslint/consistent-type-assertions |
| 23 | +const trainingOptions = { |
| 24 | + iterations: 20000, |
| 25 | + log: (details: any) => console.log(details), |
| 26 | +}; |
| 27 | + |
| 28 | +const crossValidate = new CrossValidate(() => new NeuralNetwork(netOptions)); |
| 29 | +const stats = crossValidate.train(trainingData, trainingOptions); |
| 30 | +console.log(stats); |
| 31 | +const net = crossValidate.toNeuralNetwork(); |
| 32 | +const result01 = net.run([0, 1]); |
| 33 | +const result00 = net.run([0, 0]); |
| 34 | +const result11 = net.run([1, 1]); |
| 35 | +const result10 = net.run([1, 0]); |
| 36 | + |
| 37 | +console.log('0 XOR 1: ', result01); // 0.987 |
| 38 | +console.log('0 XOR 0: ', result00); // 0.058 |
| 39 | +console.log('1 XOR 1: ', result11); // 0.087 |
| 40 | +console.log('1 XOR 0: ', result10); // 0.934 |
0 commit comments