diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index e5df21216b..613d652929 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -31,6 +31,7 @@ and this project adheres to
 - Added a new `--deterministic` cli flag to deterministically select the most
   probable actions in policy. The same thing can be achieved by adding
   `deterministic: true` under `network_settings` of the run options configuration.
+- Extra tensors are now serialized to support deterministic action selection in onnx. (#5597)
 
 ### Bug Fixes
 - Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586)
diff --git a/ml-agents/mlagents/trainers/tests/torch/test_action_model.py b/ml-agents/mlagents/trainers/tests/torch/test_action_model.py
index a33f793927..5ffbd9ce6b 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_action_model.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_action_model.py
@@ -120,3 +120,36 @@ def test_get_probs_and_entropy():
     for ent, val in zip(entropies[0].tolist(), [1.4189, 0.6191, 0.6191]):
         assert ent == pytest.approx(val, abs=0.01)
+
+
+def test_get_onnx_deterministic_tensors():
+    inp_size = 4
+    act_size = 2
+    action_model, masks = create_action_model(inp_size, act_size)
+    sample_inp = torch.ones((1, inp_size))
+    out_tensors = action_model.get_action_out(sample_inp, masks=masks)
+    (
+        continuous_out,
+        discrete_out,
+        action_out_deprecated,
+        deterministic_continuous_out,
+        deterministic_discrete_out,
+    ) = out_tensors
+    assert continuous_out.shape == (1, 2)
+    assert discrete_out.shape == (1, 2)
+    assert deterministic_discrete_out.shape == (1, 2)
+    assert deterministic_continuous_out.shape == (1, 2)
+
+    # Second sampling from same distribution
+    out_tensors2 = action_model.get_action_out(sample_inp, masks=masks)
+    (
+        continuous_out_2,
+        discrete_out_2,
+        action_out_2_deprecated,
+        deterministic_continuous_out_2,
+        deterministic_discrete_out_2,
+    ) = out_tensors2
+    assert ~torch.all(torch.eq(continuous_out, continuous_out_2))
+    assert torch.all(
+        torch.eq(deterministic_continuous_out, deterministic_continuous_out_2)
+    )
diff --git a/ml-agents/mlagents/trainers/torch/action_model.py b/ml-agents/mlagents/trainers/torch/action_model.py
index 8730e04255..65dfd32b40 100644
--- a/ml-agents/mlagents/trainers/torch/action_model.py
+++ b/ml-agents/mlagents/trainers/torch/action_model.py
@@ -10,6 +10,7 @@ from mlagents.trainers.torch.action_log_probs import ActionLogProbs
 from mlagents_envs.base_env import ActionSpec
 
+
 EPSILON = 1e-7  # Small value to avoid divide by zero
@@ -173,12 +174,20 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
         """
         dists = self._get_dists(inputs, masks)
         continuous_out, discrete_out, action_out_deprecated = None, None, None
+        deterministic_continuous_out, deterministic_discrete_out = (
+            None,
+            None,
+        )  # deterministic actions
         if self.action_spec.continuous_size > 0 and dists.continuous is not None:
             continuous_out = dists.continuous.exported_model_output()
-            action_out_deprecated = dists.continuous.exported_model_output()
+            action_out_deprecated = continuous_out
+            deterministic_continuous_out = dists.continuous.deterministic_sample()
             if self._clip_action_on_export:
                 continuous_out = torch.clamp(continuous_out, -3, 3) / 3
-                action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3
+                action_out_deprecated = continuous_out
+                deterministic_continuous_out = (
+                    torch.clamp(deterministic_continuous_out, -3, 3) / 3
+                )
         if self.action_spec.discrete_size > 0 and dists.discrete is not None:
             discrete_out_list = [
                 discrete_dist.exported_model_output()
@@ -186,10 +195,23 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
             ]
             discrete_out = torch.cat(discrete_out_list, dim=1)
             action_out_deprecated = torch.cat(discrete_out_list, dim=1)
+            deterministic_discrete_out_list = [
+                discrete_dist.deterministic_sample() for discrete_dist in dists.discrete
+            ]
+            deterministic_discrete_out = torch.cat(
+                deterministic_discrete_out_list, dim=1
+            )
+
         # deprecated action field does not support hybrid action
         if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
             action_out_deprecated = None
-        return continuous_out, discrete_out, action_out_deprecated
+        return (
+            continuous_out,
+            discrete_out,
+            action_out_deprecated,
+            deterministic_continuous_out,
+            deterministic_discrete_out,
+        )
 
     def forward(
         self, inputs: torch.Tensor, masks: torch.Tensor
diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py
index 0fa946280c..f204b52445 100644
--- a/ml-agents/mlagents/trainers/torch/model_serialization.py
+++ b/ml-agents/mlagents/trainers/torch/model_serialization.py
@@ -56,10 +56,13 @@ class TensorNames:
     recurrent_output = "recurrent_out"
     memory_size = "memory_size"
     version_number = "version_number"
+
     continuous_action_output_shape = "continuous_action_output_shape"
     discrete_action_output_shape = "discrete_action_output_shape"
     continuous_action_output = "continuous_actions"
     discrete_action_output = "discrete_actions"
+    deterministic_continuous_action_output = "deterministic_continuous_actions"
+    deterministic_discrete_action_output = "deterministic_discrete_actions"
 
     # Deprecated TensorNames entries for backward compatibility
     is_continuous_control_deprecated = "is_continuous_control"
@@ -122,6 +125,7 @@ def __init__(self, policy):
             self.output_names += [
                 TensorNames.continuous_action_output,
                 TensorNames.continuous_action_output_shape,
+                TensorNames.deterministic_continuous_action_output,
             ]
             self.dynamic_axes.update(
                 {TensorNames.continuous_action_output: {0: "batch"}}
@@ -130,6 +134,7 @@ def __init__(self, policy):
             self.output_names += [
                 TensorNames.discrete_action_output,
                 TensorNames.discrete_action_output_shape,
+                TensorNames.deterministic_discrete_action_output,
             ]
             self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}})
diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py
index 19b97e2860..be8fb4b732 100644
--- a/ml-agents/mlagents/trainers/torch/networks.py
+++ b/ml-agents/mlagents/trainers/torch/networks.py
@@ -676,12 +676,22 @@ def forward(
             cont_action_out,
             disc_action_out,
             action_out_deprecated,
+            deterministic_cont_action_out,
+            deterministic_disc_action_out,
         ) = self.action_model.get_action_out(encoding, masks)
         export_out = [self.version_number, self.memory_size_vector]
         if self.action_spec.continuous_size > 0:
-            export_out += [cont_action_out, self.continuous_act_size_vector]
+            export_out += [
+                cont_action_out,
+                self.continuous_act_size_vector,
+                deterministic_cont_action_out,
+            ]
         if self.action_spec.discrete_size > 0:
-            export_out += [disc_action_out, self.discrete_act_size_vector]
+            export_out += [
+                disc_action_out,
+                self.discrete_act_size_vector,
+                deterministic_disc_action_out,
+            ]
         if self.network_body.memory_size > 0:
             export_out += [memories_out]
         return tuple(export_out)
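Note: the new code paths call `deterministic_sample()` on the distribution instances, but that method's definition is not part of this diff. As a rough sketch of the idea it relies on, assuming Gaussian and categorical wrapper classes along the lines of the trainer's distribution instances (the class names, attributes, and stochastic paths below are illustrative, not taken from this diff): a Gaussian's deterministic action is its mean, and a categorical's is the argmax over branch probabilities.

```python
import torch


class GaussianDistInstance:
    """Illustrative stand-in for a continuous-action distribution wrapper."""

    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        self.mean = mean
        self.std = std

    def exported_model_output(self) -> torch.Tensor:
        # Stochastic export path: reparameterized sample from N(mean, std).
        return self.mean + torch.randn_like(self.mean) * self.std

    def deterministic_sample(self) -> torch.Tensor:
        # Deterministic export path: the mode of a Gaussian is its mean.
        return self.mean


class CategoricalDistInstance:
    """Illustrative stand-in for a discrete-action distribution wrapper."""

    def __init__(self, logits: torch.Tensor):
        self.probs = torch.softmax(logits, dim=-1)

    def exported_model_output(self) -> torch.Tensor:
        # Stochastic export path: draw one action index per batch row.
        return torch.multinomial(self.probs, num_samples=1)

    def deterministic_sample(self) -> torch.Tensor:
        # Deterministic export path: most probable action index per row.
        return torch.argmax(self.probs, dim=1, keepdim=True)


# Repeated deterministic samples agree, while stochastic samples generally
# differ -- the same property test_get_onnx_deterministic_tensors asserts above.
dist = GaussianDistInstance(torch.zeros(1, 2), torch.ones(1, 2))
assert torch.equal(dist.deterministic_sample(), dist.deterministic_sample())
```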
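On the consumer side, the exported model now exposes `deterministic_continuous_actions` / `deterministic_discrete_actions` outputs alongside the stochastic ones (those names do come from `TensorNames` in this diff). A minimal sketch of reading them back with onnxruntime, where the model path, the observation shape, and the `obs_0` input name are assumptions about one particular exported policy:

```python
import numpy as np
import onnxruntime as ort

# Hypothetical exported policy; the path and input layout depend on your run.
session = ort.InferenceSession("results/MyBehavior/MyBehavior.onnx")
obs = np.zeros((1, 4), dtype=np.float32)  # batch of one observation

# Output names added by this change; "obs_0" is the assumed observation input.
stochastic, deterministic = session.run(
    ["continuous_actions", "deterministic_continuous_actions"], {"obs_0": obs}
)
# Re-running with the same observation changes `stochastic` but not `deterministic`.
```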