Skip to content

Commit 9f88169

Browse files
Readding the validation of the minimal cnn input size (#5345) (#5346)
1 parent 88c2ece commit 9f88169

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

ml-agents/mlagents/trainers/tests/torch/test_utils.py

+23
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,29 @@ def test_min_visual_size():
3636
enc.forward(vis_input)
3737

3838

39+
@pytest.mark.parametrize(
40+
"encoder_type",
41+
[
42+
EncoderType.SIMPLE,
43+
EncoderType.NATURE_CNN,
44+
EncoderType.SIMPLE,
45+
EncoderType.MATCH3,
46+
],
47+
)
48+
def test_invalid_visual_input_size(encoder_type):
49+
with pytest.raises(UnityTrainerException):
50+
obs_spec = create_observation_specs_with_shapes(
51+
[
52+
(
53+
ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1,
54+
ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type],
55+
1,
56+
)
57+
]
58+
)
59+
ModelUtils.create_input_processors(obs_spec, 20, encoder_type, 20, False)
60+
61+
3962
@pytest.mark.parametrize("num_visual", [0, 1, 2])
4063
@pytest.mark.parametrize("num_vector", [0, 1, 2])
4164
@pytest.mark.parametrize("normalize", [True, False])

ml-agents/mlagents/trainers/torch/utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ def get_encoder_for_obs(
159159
# VISUAL
160160
if dim_prop in ModelUtils.VALID_VISUAL_PROP:
161161
visual_encoder_class = ModelUtils.get_encoder_for_type(vis_encode_type)
162+
ModelUtils._check_resolution_for_encoder(
163+
shape[0], shape[1], vis_encode_type
164+
)
162165
return (visual_encoder_class(shape[0], shape[1], shape[2], h_size), h_size)
163166
# VECTOR
164167
if dim_prop in ModelUtils.VALID_VECTOR_PROP:

0 commit comments

Comments
 (0)