File tree 1 file changed +3
-3
lines changed
ml-agents/mlagents/trainers/torch
1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -257,8 +257,9 @@ def __init__(
257
257
layers .append (ResNetBlock (channel ))
258
258
last_channel = channel
259
259
layers .append (Swish ())
260
+ self .final_flat_size = n_channels [- 1 ] * height * width
260
261
self .dense = linear_layer (
261
- n_channels [ - 1 ] * height * width ,
262
+ self . final_flat_size ,
262
263
output_size ,
263
264
kernel_init = Initialization .KaimingHeNormal ,
264
265
kernel_gain = 1.41 , # Use ReLU gain
@@ -268,7 +269,6 @@ def __init__(
268
269
def forward (self , visual_obs : torch .Tensor ) -> torch .Tensor :
269
270
if not exporting_to_onnx .is_exporting ():
270
271
visual_obs = visual_obs .permute ([0 , 3 , 1 , 2 ])
271
- batch_size = visual_obs .shape [0 ]
272
272
hidden = self .sequential (visual_obs )
273
- before_out = hidden .reshape (batch_size , - 1 )
273
+ before_out = hidden .reshape (- 1 , self . final_flat_size )
274
274
return torch .relu (self .dense (before_out ))
You can’t perform that action at this time.
0 commit comments