
Commit 4d74bec

[snake-dqn] Add graphics and inference page (#270)

- Adjusts hyperparameters

Fixes tensorflow/tfjs#1573

1 parent b789b5e · commit 4d74bec

16 files changed: +653 −80 lines

README.md (+11)

@@ -272,6 +272,17 @@ to another project.
     <td>Layers</td>
     <td>Export trained model from tfjs-node and load it in browser</td>
   </tr>
+  <tr>
+    <td><a href="./snake-dqn">snake-dqn</a></td>
+    <td><a href="https://storage.googleapis.com/tfjs-examples/snake-dqn/index.html">🔗</a></td>
+    <td></td>
+    <td>Reinforcement learning</td>
+    <td>Deep Q-Network (DQN)</td>
+    <td>Node.js</td>
+    <td>Browser</td>
+    <td>Layers</td>
+    <td>Export trained model from tfjs-node and load it in browser</td>
+  </tr>
   <tr>
     <td><a href="./translation">translation</a></td>
     <td><a href="https://storage.googleapis.com/tfjs-examples/translation/dist/index.html">🔗</a></td>

snake-dqn/.gitignore (+1)

@@ -0,0 +1 @@
+models/

snake-dqn/README.md (+13)

@@ -1,5 +1,9 @@
 # Using Deep Q-Learning to Solve the Snake Game
 
+![DQN Snake Game](./images/dqn-screenshot.png)
+
+[See this example live!](https://storage.googleapis.com/tfjs-examples/snake-dqn/index.html)
+
 Deep Q-Learning is a reinforcement-learning (RL) algorithm. It is used
 frequently to solve arcade-style games like the Snake game used in this
 example.
@@ -59,3 +63,12 @@ tensorboard --logdir /tmp/snake_logs
 
 Once started, the tensorboard backend process will print an `http://` URL to the
 console. Open your browser and navigate to the URL to see the logged curves.
+
+## Running the demo in the browser
+
+After the DQN training completes, you can use the following command to
+launch a demo that shows how the network plays the game in the browser:
+
+```sh
+yarn watch
+```

snake-dqn/agent.js (+9 −4)

@@ -67,6 +67,7 @@ export class SnakeGameAgent {
 
   reset() {
     this.cumulativeReward_ = 0;
+    this.fruitsEaten_ = 0;
    this.game.reset();
   }
 
@@ -98,15 +99,19 @@ export class SnakeGameAgent {
      });
    }
 
-    const {state: nextState, reward, done} = this.game.step(action);
+    const {state: nextState, reward, done, fruitEaten} = this.game.step(action);
 
    this.replayMemory.append([state, action, reward, done, nextState]);
 
    this.cumulativeReward_ += reward;
+    if (fruitEaten) {
+      this.fruitsEaten_++;
+    }
    const output = {
      action,
      cumulativeReward: this.cumulativeReward_,
-      done
+      done,
+      fruitsEaten: this.fruitsEaten_
    };
    if (done) {
      this.reset();
@@ -130,8 +135,8 @@ export class SnakeGameAgent {
        batch.map(example => example[0]), this.game.height, this.game.width);
    const actionTensor = tf.tensor1d(
        batch.map(example => example[1]), 'int32');
-    const qs = this.onlineNetwork.predict(
-        stateTensor).mul(tf.oneHot(actionTensor, NUM_ACTIONS)).sum(-1);
+    const qs = this.onlineNetwork.apply(stateTensor, {training: true})
+        .mul(tf.oneHot(actionTensor, NUM_ACTIONS)).sum(-1);
 
    const rewardTensor = tf.tensor1d(batch.map(example => example[2]));
    const nextStateTensor = getStateTensor(
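
The switch from `predict()` to `apply(..., {training: true})` matters because this commit also adds batch-normalization and dropout layers to the network (see dqn.js below), and those layers behave differently in training and inference modes. A minimal, self-contained sketch of the difference (illustrative only, not code from this repository):

```js
import * as tf from '@tensorflow/tfjs';

// Toy model containing a dropout layer, just to show the mode difference.
const model = tf.sequential();
model.add(tf.layers.dense({units: 8, activation: 'relu', inputShape: [4]}));
model.add(tf.layers.dropout({rate: 0.25}));
model.add(tf.layers.dense({units: 3}));

const x = tf.randomNormal([2, 4]);
const inferenceOut = model.predict(x);                 // dropout is a no-op here
const trainingOut = model.apply(x, {training: true});  // dropout (and batch-norm
                                                       // statistics) run in training mode
```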

snake-dqn/agent_test.js (+1 −1)

@@ -51,7 +51,7 @@ describe('SnakeGameAgent', () => {
    expect(agent.replayMemory.buffer[bufferIndex % 100][1])
        .toEqual(out.action);
 
-    expect(agent.replayMemory.buffer[bufferIndex % 100][2]).toEqual(
+    expect(agent.replayMemory.buffer[bufferIndex % 100][2]).toBeCloseTo(
        outPrev == null ? out.cumulativeReward :
            out.cumulativeReward - outPrev.cumulativeReward);
    expect(agent.replayMemory.buffer[bufferIndex % 100][3]).toEqual(out.done);
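
Switching from `toEqual` to `toBeCloseTo` is the usual fix when comparing accumulated floating-point rewards, which rarely match exactly. A generic Jasmine illustration (not from this repository):

```js
// Accumulated float arithmetic is not exact, so strict equality is brittle.
expect(0.1 + 0.2).not.toEqual(0.3);  // 0.30000000000000004 !== 0.3
expect(0.1 + 0.2).toBeCloseTo(0.3);  // passes within the default precision
```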

snake-dqn/dqn.js (+4)

@@ -38,12 +38,14 @@ export function createDeepQNetwork(h, w, numActions) {
     activation: 'relu',
     inputShape: [h, w, 2]
   }));
+  model.add(tf.layers.batchNormalization());
   model.add(tf.layers.conv2d({
     filters: 256,
     kernelSize: 3,
     strides: 1,
     activation: 'relu'
   }));
+  model.add(tf.layers.batchNormalization());
   model.add(tf.layers.conv2d({
     filters: 256,
     kernelSize: 3,
@@ -52,7 +54,9 @@
   }));
   model.add(tf.layers.flatten());
   model.add(tf.layers.dense({units: 100, activation: 'relu'}));
+  model.add(tf.layers.dropout({rate: 0.25}));
   model.add(tf.layers.dense({units: numActions}));
+
   return model;
 }
 
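
Put together, the network after this commit interleaves batch normalization with the convolutional blocks and adds dropout before the output head. The sketch below is reconstructed from the hunks above; the filter count, kernel size, and strides of the first conv2d layer, and the tail of the third conv2d's config, are not visible in this diff, so those values are placeholders rather than the repository's actual settings.

```js
import * as tf from '@tensorflow/tfjs';

export function createDeepQNetwork(h, w, numActions) {
  const model = tf.sequential();
  model.add(tf.layers.conv2d({
    filters: 128,       // placeholder: not shown in the diff
    kernelSize: 3,      // placeholder: not shown in the diff
    strides: 1,         // placeholder: not shown in the diff
    activation: 'relu',
    inputShape: [h, w, 2]
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.conv2d({
    filters: 256,
    kernelSize: 3,
    strides: 1,
    activation: 'relu'
  }));
  model.add(tf.layers.batchNormalization());
  model.add(tf.layers.conv2d({
    filters: 256,
    kernelSize: 3,
    strides: 1,          // placeholder: cut off by the hunk boundary
    activation: 'relu'   // placeholder: cut off by the hunk boundary
  }));
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({units: 100, activation: 'relu'}));
  model.add(tf.layers.dropout({rate: 0.25}));
  model.add(tf.layers.dense({units: numActions}));
  return model;
}
```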

snake-dqn/dqn_test.js (+11 −10)

@@ -68,16 +68,18 @@ describe('copyWeights', () => {
 
    // Initially, the two networks should have different values in their
    // weights.
-    const onlineWeights0 = onlineNetwork.getWeights();
-    const targetWeights0 = targetNetwork.getWeights();
-    expect(onlineWeights0.length).toEqual(targetWeights0.length);
-    // The 1st weight is the first conv layer's kernel.
-    expect(onlineWeights0[0].sub(targetWeights0[0]).abs().mean().arraySync())
+    const conv1Weights0 = onlineNetwork.layers[0].getWeights();
+    const conv1Weights1 = targetNetwork.layers[0].getWeights();
+    expect(conv1Weights0.length).toEqual(conv1Weights1.length);
+    // The 1st weight is the 1st conv layer's kernel.
+    expect(conv1Weights0[0].sub(conv1Weights1[0]).abs().mean().arraySync())
        .toBeGreaterThan(0);
-    // Skip the 2nd weight, because it's the bias of the first conv layer's
-    // kernel, which has an all-zero initializer.
-    // The 3rd weight is the second conv layer's kernel.
-    expect(onlineWeights0[2].sub(targetWeights0[2]).abs().mean().arraySync())
+
+    const conv2Weights0 = onlineNetwork.layers[2].getWeights();
+    const conv2Weights1 = targetNetwork.layers[2].getWeights();
+    expect(conv2Weights0.length).toEqual(conv2Weights1.length);
+    // The 1st weight is the 2nd conv layer's kernel.
+    expect(conv2Weights0[0].sub(conv2Weights1[0]).abs().mean().arraySync())
        .toBeGreaterThan(0);
 
    copyWeights(targetNetwork, onlineNetwork);
@@ -87,7 +89,6 @@
    const onlineWeights1 = onlineNetwork.getWeights();
    const targetWeights1 = targetNetwork.getWeights();
    expect(onlineWeights1.length).toEqual(targetWeights1.length);
-    expect(onlineWeights1.length).toEqual(onlineWeights0.length);
    for (let i = 0; i < onlineWeights1.length; ++i) {
      expect(onlineWeights1[i].sub(targetWeights1[i]).abs().mean().arraySync())
          .toEqual(0);
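
For context, the `copyWeights` helper under test is the standard online-to-target network sync in DQN training. A typical TF.js implementation looks like the sketch below (illustrative; the repository's dqn.js may differ in detail):

```js
// Copy every weight value from srcNetwork into destNetwork, in place.
// In DQN training this is called periodically to refresh the target
// network from the online network.
export function copyWeights(destNetwork, srcNetwork) {
  destNetwork.setWeights(srcNetwork.getWeights());
}
```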

snake-dqn/images/dqn-screenshot.png (binary file added, 23.8 KB)

snake-dqn/index.html (+95)

@@ -0,0 +1,95 @@
+<!--
+Copyright 2018 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================
+-->
+
+<!doctype html>
+
+<head>
+  <meta charset="UTF-8">
+  <meta name="viewport" content="width=device-width, initial-scale=1">
+  <link rel="stylesheet" href="../shared/tfjs-examples.css" />
+</head>
+
+<style>
+  #load-model-div {
+    margin-top: 10px;
+    margin-bottom: 10px;
+  }
+
+  #reset {
+    margin-left: 100px;
+  }
+
+  #show-q-values-div {
+    margin-top: 5px;
+  }
+
+  #game-status-div {
+    margin-top: 15px;
+  }
+</style>
+
+<body>
+  <div class='tfjs-example-container centered-container'>
+    <section class='title-area'>
+      <h1>TensorFlow.js Reinforcement Learning: Snake DQN</h1>
+      <p class='subtitle'>Deep Q-Network for the Snake Game</p>
+    </section>
+    <section>
+      <p class='section-head'>Description</p>
+      <p>
+        This page loads a trained Deep Q-Network (DQN) and uses it to play the
+        snake game.
+        The training is done in Node.js using <a href="https://github.com/tensorflow/tfjs-node">tfjs-node</a>.
+        See <a href="https://github.com/tensorflow/tfjs-examples/blob/master/snake-dqn/train.js">train.js</a>.
+      </p>
+    </section>
+    <section>
+      <p class='section-head'>Algorithm</p>
+      <p>
+        A <a href="https://en.wikipedia.org/wiki/Q-learning#Variants">DQN</a> is trained to estimate the value of actions given the current game state.
+        The DQN is a 2D convolutional network. See <a href="https://github.com/tensorflow/tfjs-examples/blob/master/snake-dqn/dqn.js">dqn.js</a>.
+        The epsilon-greedy algorithm is used to balance exploration and exploitation during training.
+      </p>
+    </section>
+
+    <section>
+      <div id="load-model-div">
+        <button id="load-hosted-model" width="200px" disabled>Load hosted model</button>
+      </div>
+
+      <div>
+        <button id="auto-play-stop" disabled>Auto Play</button>
+        <button id="step" disabled>Step</button>
+        <button id="reset" disabled>Reset</button>
+      </div>
+      <div id="show-q-values-div">
+        <input type="checkbox" id="show-q-values" checked>
+        <span>Show Q-values</span>
+      </div>
+      <div id="game-status-div">
+        <span id="game-status">Game started.</span>
+      </div>
+      <div>
+        <canvas id="game-canvas" height="400px" width="400px"></canvas>
+      </div>
+    </section>
+
+  </div>
+
+</body>
+
+<script src="index.js"></script>
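
Since the page description mentions the epsilon-greedy policy, here is a minimal sketch of how such action selection is commonly written with TF.js. It is illustrative only; the repository's agent code may differ, and `pickAction` is a hypothetical helper name.

```js
import * as tf from '@tensorflow/tfjs';

// Epsilon-greedy action selection: explore with probability `epsilon`,
// otherwise exploit by taking the argmax of the predicted Q-values.
function pickAction(onlineNetwork, stateTensor, epsilon, numActions) {
  if (Math.random() < epsilon) {
    return Math.floor(Math.random() * numActions);  // random exploratory action
  }
  return tf.tidy(() => {
    const qValues = onlineNetwork.predict(stateTensor);  // shape: [1, numActions]
    return qValues.argMax(-1).dataSync()[0];             // greedy action index
  });
}
```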
