Skip to content

Commit ad6cbfe

Browse files
committed
Fix mario_rl_tutorial.py
1 parent e7563f6 commit ad6cbfe

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

Diff for: intermediate_source/mario_rl_tutorial.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
import numpy as np
4444
from pathlib import Path
4545
from collections import deque
46-
import random, datetime, os, copy
46+
import random, datetime, os
4747

4848
# Gym is an OpenAI toolkit for RL
4949
import gym
@@ -424,20 +424,10 @@ def __init__(self, input_dim, output_dim):
424424
if w != 84:
425425
raise ValueError(f"Expecting input width: 84, got: {w}")
426426

427-
self.online = nn.Sequential(
428-
nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
429-
nn.ReLU(),
430-
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
431-
nn.ReLU(),
432-
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
433-
nn.ReLU(),
434-
nn.Flatten(),
435-
nn.Linear(3136, 512),
436-
nn.ReLU(),
437-
nn.Linear(512, output_dim),
438-
)
427+
self.online = self.__build_cnn(c, output_dim)
439428

440-
self.target = copy.deepcopy(self.online)
429+
self.target = self.__build_cnn(c, output_dim)
430+
self.target.load_state_dict(self.online.state_dict())
441431

442432
# Q_target parameters are frozen.
443433
for p in self.target.parameters():
@@ -449,6 +439,20 @@ def forward(self, input, model):
449439
elif model == "target":
450440
return self.target(input)
451441

442+
def __build_cnn(self, c, output_dim):
443+
return nn.Sequential(
444+
nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
445+
nn.ReLU(),
446+
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
447+
nn.ReLU(),
448+
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
449+
nn.ReLU(),
450+
nn.Flatten(),
451+
nn.Linear(3136, 512),
452+
nn.ReLU(),
453+
nn.Linear(512, output_dim),
454+
)
455+
452456

453457
######################################################################
454458
# TD Estimate & TD Target

0 commit comments

Comments
 (0)