Skip to content

Commit e6a6ea1

Browse files
surfnerdErvin T
and
Ervin T
authored
[bug-fix] Make resnet barracuda-compatible (#5358) (#5364)
Co-authored-by: Ervin T <[email protected]>
1 parent b4b1ced commit e6a6ea1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ml-agents/mlagents/trainers/torch/encoders.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,9 @@ def __init__(
257257
layers.append(ResNetBlock(channel))
258258
last_channel = channel
259259
layers.append(Swish())
260+
self.final_flat_size = n_channels[-1] * height * width
260261
self.dense = linear_layer(
261-
n_channels[-1] * height * width,
262+
self.final_flat_size,
262263
output_size,
263264
kernel_init=Initialization.KaimingHeNormal,
264265
kernel_gain=1.41, # Use ReLU gain
@@ -268,7 +269,6 @@ def __init__(
268269
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
269270
if not exporting_to_onnx.is_exporting():
270271
visual_obs = visual_obs.permute([0, 3, 1, 2])
271-
batch_size = visual_obs.shape[0]
272272
hidden = self.sequential(visual_obs)
273-
before_out = hidden.reshape(batch_size, -1)
273+
before_out = hidden.reshape(-1, self.final_flat_size)
274274
return torch.relu(self.dense(before_out))

0 commit comments

Comments
 (0)