Skip to content

[MLA-1809] catch mismatched observation sizes #5030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- An issue that caused `GAIL` to fail for environments where agents can terminate episodes by self-sacrifice has been fixed. (#4971)
- Made the error message when observations of different shapes are sent to the trainer clearer. (#5030)

## [1.8.0-preview] - 2021-02-17
### Major Changes
Expand Down
22 changes: 17 additions & 5 deletions ml-agents-envs/mlagents_envs/rpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,22 @@ def steps_from_proto(
]
decision_obs_list: List[np.ndarray] = []
terminal_obs_list: List[np.ndarray] = []
for obs_index, observation_specs in enumerate(behavior_spec.observation_specs):
is_visual = len(observation_specs.shape) == 3
for obs_index, observation_spec in enumerate(behavior_spec.observation_specs):
# Check that all the observations match the expected size.
# This gives a nicer error than a cryptic numpy error later.
expected_obs_shape = tuple(observation_spec.shape)
for agent_info in agent_info_list:
agent_obs_shape = tuple(agent_info.observations[obs_index].shape)
if expected_obs_shape != agent_obs_shape:
raise ValueError(
f"Observation at index={obs_index} for agent with "
f"id={agent_info.id} didn't match the ObservationSpec. "
f"Expected shape {expected_obs_shape} but got {agent_obs_shape}."
)

is_visual = len(observation_spec.shape) == 3
if is_visual:
obs_shape = cast(Tuple[int, int, int], observation_specs.shape)
obs_shape = cast(Tuple[int, int, int], observation_spec.shape)
decision_obs_list.append(
_process_maybe_compressed_observation(
obs_index, obs_shape, decision_agent_info_list
Expand All @@ -302,12 +314,12 @@ def steps_from_proto(
else:
decision_obs_list.append(
_process_rank_one_or_two_observation(
obs_index, observation_specs.shape, decision_agent_info_list
obs_index, observation_spec.shape, decision_agent_info_list
)
)
terminal_obs_list.append(
_process_rank_one_or_two_observation(
obs_index, observation_specs.shape, terminal_agent_info_list
obs_index, observation_spec.shape, terminal_agent_info_list
)
)
decision_rewards = np.array(
Expand Down
14 changes: 14 additions & 0 deletions ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,20 @@ def test_batched_step_result_from_proto():
assert terminal_steps.obs[1].shape[1] == shapes[1][0]


def test_mismatch_observations_raise_in_step_result_from_proto():
n_agents = 10
shapes = [(3,), (4,)]
spec = BehaviorSpec(
create_observation_specs_with_shapes(shapes), ActionSpec.create_continuous(3)
)
ap_list = generate_list_agent_proto(n_agents, shapes)
# Hack an observation to be larger, we should get an exception
ap_list[0].observations[0].shape[0] += 1
ap_list[0].observations[0].float_data.data.append(0.42)
with pytest.raises(ValueError):
steps_from_proto(ap_list, spec)


def test_action_masking_discrete():
n_agents = 10
shapes = [(3,), (4,)]
Expand Down