Commit f4df661

First Commit
1 parent 8f4a904 commit f4df661

12 files changed: +1884 -0 lines changed

AlphaZeroNetwork.py (+208)
@@ -0,0 +1,208 @@
import torch
import torch.nn as nn

class ConvBlock( nn.Module ):
    """
    The block consists of a conv layer, batch normalization layer
    and relu activation.
    """

    def __init__( self, input_channels, num_filters ):
        """
        Args:
            input_channels (int): the number of input channels
            num_filters (int): the number of filters in the conv layer
        """
        super().__init__()
        self.conv1 = nn.Conv2d( input_channels, num_filters, 3, padding=1 )
        self.bn1 = nn.BatchNorm2d( num_filters )
        self.relu1 = nn.ReLU()

    def forward( self, x ):
        """
        Args:
            x (torch.Tensor): the tensor to apply the layers to.
        """
        x = self.conv1( x )
        x = self.bn1( x )
        x = self.relu1( x )

        return x

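# A quick shape sketch (the 8x8 board and 16 input planes match what
# AlphaZeroNet below assumes; the filter count here is illustrative):
#
#   block = ConvBlock( 16, 256 )
#   out = block( torch.zeros( 1, 16, 8, 8 ) )   # -> torch.Size([1, 256, 8, 8])
#
# The 3x3 convolution with padding=1 preserves the 8x8 spatial size.
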
class ResidualBlock( nn.Module ):
    """
    A residual block.
    """

    def __init__( self, num_filters ):
        """
        Args:
            num_filters (int): the number of filters in the conv layers. Assumes this is the
                same as the number of input channels
        """
        super().__init__()
        self.conv1 = nn.Conv2d( num_filters, num_filters, 3,
                padding=1 )
        self.bn1 = nn.BatchNorm2d( num_filters )
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d( num_filters, num_filters, 3,
                padding=1 )
        self.bn2 = nn.BatchNorm2d( num_filters )
        self.relu2 = nn.ReLU()

    def forward( self, x ):
        """
        Args:
            x (torch.Tensor): the tensor to apply the layers to.
        """
        residual = x

        x = self.conv1( x )
        x = self.bn1( x )
        x = self.relu1( x )

        x = self.conv2( x )
        x = self.bn2( x )
        # skip connection: add the block input back in before the final activation
        x += residual
        x = self.relu2( x )

        return x

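# The skip connection requires matching input and output channel counts,
# which is why num_filters doubles as the input channel count. A sketch
# (illustrative filter count):
#
#   block = ResidualBlock( 256 )
#   out = block( torch.zeros( 1, 256, 8, 8 ) )  # -> torch.Size([1, 256, 8, 8])
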
class ValueHead( nn.Module ):
    """
    nn.Module for the value head
    """

    def __init__( self, input_channels ):
        """
        Args:
            input_channels (int): the number of input channels
        """
        super().__init__()
        self.conv1 = nn.Conv2d( input_channels, 1, 1 )
        self.bn1 = nn.BatchNorm2d( 1 )
        self.relu1 = nn.ReLU()
        self.fc1 = nn.Linear( 64, 256 )
        self.relu2 = nn.ReLU()
        self.fc2 = nn.Linear( 256, 1 )
        self.tanh1 = nn.Tanh()

    def forward( self, x ):
        """
        Args:
            x (torch.Tensor): the tensor to apply the layers to.
        """
        x = self.conv1( x )
        x = self.bn1( x )
        x = self.relu1( x )
        # flatten: 1 channel x 8 x 8 board = 64 features
        x = x.view( x.shape[0], 64 )
        x = self.fc1( x )
        x = self.relu2( x )
        x = self.fc2( x )
        x = self.tanh1( x )

        return x

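# The hard-coded 64 in fc1 and in the view() call assumes an 8x8 board
# (1 channel x 8 x 8 = 64 after the 1x1 convolution), and the final tanh
# squashes the prediction into [-1, 1], the usual loss/draw/win value range:
#
#   head = ValueHead( 256 )
#   v = head( torch.zeros( 1, 256, 8, 8 ) )     # -> torch.Size([1, 1])
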
class PolicyHead( nn.Module ):
    """
    nn.Module for the policy head
    """

    def __init__( self, input_channels ):
        """
        Args:
            input_channels (int): the number of input channels
        """
        super().__init__()
        self.conv1 = nn.Conv2d( input_channels, 2, 1 )
        self.bn1 = nn.BatchNorm2d( 2 )
        self.relu1 = nn.ReLU()
        self.fc1 = nn.Linear( 128, 4608 )

    def forward( self, x ):
        """
        Args:
            x (torch.Tensor): the tensor to apply the layers to.
        """
        x = self.conv1( x )
        x = self.bn1( x )
        x = self.relu1( x )
        # flatten: 2 channels x 8 x 8 board = 128 features
        x = x.view( x.shape[0], 128 )
        x = self.fc1( x )

        return x

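# Likewise, 128 = 2 channels x 8 x 8 squares after the 1x1 convolution.
# The head returns raw logits (4608 per position) rather than probabilities,
# so they can feed nn.CrossEntropyLoss directly during training; the masked
# softmax for inference lives in AlphaZeroNet below:
#
#   head = PolicyHead( 256 )
#   p = head( torch.zeros( 1, 256, 8, 8 ) )     # -> torch.Size([1, 4608])
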
class AlphaZeroNet( nn.Module ):
    """
    Neural network with AlphaZero architecture.
    """

    def __init__( self, num_blocks, num_filters ):
        """
        Args:
            num_blocks (int): the number of residual blocks
            num_filters (int): the number of filters in each conv layer
        """
        super().__init__()
        #The number of input planes is fixed at 16
        self.convBlock1 = ConvBlock( 16, num_filters )

        residualBlocks = [ ResidualBlock( num_filters ) for _ in range( num_blocks ) ]

        self.residualBlocks = nn.ModuleList( residualBlocks )

        self.valueHead = ValueHead( num_filters )

        self.policyHead = PolicyHead( num_filters )

        self.softmax1 = nn.Softmax( dim=1 )

        self.mseLoss = nn.MSELoss()

        self.crossEntropyLoss = nn.CrossEntropyLoss()

    def forward( self, x, valueTarget=None, policyTarget=None, policyMask=None ):
        """
        Args:
            x (torch.Tensor): the input tensor.
            valueTarget (torch.Tensor): the value target.
            policyTarget (torch.Tensor): the policy target.
            policyMask (torch.Tensor): the legal move mask
        """
        x = self.convBlock1( x )

        for block in self.residualBlocks:
            x = block( x )

        value = self.valueHead( x )

        policy = self.policyHead( x )

        if self.training:
            valueLoss = self.mseLoss( value, valueTarget )

            policyTarget = policyTarget.view( policyTarget.shape[0] )

            policyLoss = self.crossEntropyLoss( policy, policyTarget )

            return valueLoss, policyLoss

        else:
            # mask out illegal moves, then renormalize: a softmax over
            # legal moves only
            policyMask = policyMask.view( policyMask.shape[0], -1 )

            policy_exp = torch.exp( policy )

            policy_exp *= policyMask.type( torch.float32 )

            policy_exp_sum = torch.sum( policy_exp, dim=1, keepdim=True )

            policy_softmax = policy_exp / policy_exp_sum

            return value, policy_softmax

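# A minimal usage sketch. The 20-block, 256-filter configuration and the
# random inputs below are illustrative assumptions, not values from this
# commit:

if __name__ == '__main__':

    model = AlphaZeroNet( num_blocks=20, num_filters=256 )

    # Training: targets are a scalar value per position and an index into
    # the 4608 policy outputs; the forward pass returns the two losses.
    model.train()
    x = torch.randn( 8, 16, 8, 8 )
    valueTarget = torch.randn( 8, 1 )
    policyTarget = torch.randint( 0, 4608, ( 8, 1 ) )
    valueLoss, policyLoss = model( x, valueTarget=valueTarget,
            policyTarget=policyTarget )

    # Inference: a boolean legal-move mask zeroes out illegal moves before
    # renormalizing, and the forward pass returns (value, policy).
    model.eval()
    with torch.no_grad():
        policyMask = torch.zeros( 8, 4608, dtype=torch.bool )
        policyMask[ :, :10 ] = True
        value, policy = model( x, policyMask=policyMask )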
