43
43
import numpy as np
44
44
from pathlib import Path
45
45
from collections import deque
46
- import random , datetime , os , copy
46
+ import random , datetime , os
47
47
48
48
# Gym is an OpenAI toolkit for RL
49
49
import gym
@@ -424,20 +424,10 @@ def __init__(self, input_dim, output_dim):
424
424
if w != 84 :
425
425
raise ValueError (f"Expecting input width: 84, got: { w } " )
426
426
427
- self .online = nn .Sequential (
428
- nn .Conv2d (in_channels = c , out_channels = 32 , kernel_size = 8 , stride = 4 ),
429
- nn .ReLU (),
430
- nn .Conv2d (in_channels = 32 , out_channels = 64 , kernel_size = 4 , stride = 2 ),
431
- nn .ReLU (),
432
- nn .Conv2d (in_channels = 64 , out_channels = 64 , kernel_size = 3 , stride = 1 ),
433
- nn .ReLU (),
434
- nn .Flatten (),
435
- nn .Linear (3136 , 512 ),
436
- nn .ReLU (),
437
- nn .Linear (512 , output_dim ),
438
- )
427
+ self .online = self .__build_cnn (c , output_dim )
439
428
440
- self .target = copy .deepcopy (self .online )
429
+ self .target = self .__build_cnn (c , output_dim )
430
+ self .target .load_state_dict (self .online .state_dict ())
441
431
442
432
# Q_target parameters are frozen.
443
433
for p in self .target .parameters ():
@@ -449,6 +439,20 @@ def forward(self, input, model):
449
439
elif model == "target" :
450
440
return self .target (input )
451
441
442
+ def __build_cnn (self , c , output_dim ):
443
+ return nn .Sequential (
444
+ nn .Conv2d (in_channels = c , out_channels = 32 , kernel_size = 8 , stride = 4 ),
445
+ nn .ReLU (),
446
+ nn .Conv2d (in_channels = 32 , out_channels = 64 , kernel_size = 4 , stride = 2 ),
447
+ nn .ReLU (),
448
+ nn .Conv2d (in_channels = 64 , out_channels = 64 , kernel_size = 3 , stride = 1 ),
449
+ nn .ReLU (),
450
+ nn .Flatten (),
451
+ nn .Linear (3136 , 512 ),
452
+ nn .ReLU (),
453
+ nn .Linear (512 , output_dim ),
454
+ )
455
+
452
456
453
457
######################################################################
454
458
# TD Estimate & TD Target
0 commit comments