Skip to content

support for deterministic inference in onnx #5593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and this project adheres to
1. env_params.max_lifetime_restarts (--max-lifetime-restarts) [default=10]
2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]
- Extra tensors are now serialized to support deterministic action selection in onnx. (#5597)
### Bug Fixes
- Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586)

Expand Down
33 changes: 33 additions & 0 deletions ml-agents/mlagents/trainers/tests/torch/test_action_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,36 @@ def test_get_probs_and_entropy():

for ent, val in zip(entropies[0].tolist(), [1.4189, 0.6191, 0.6191]):
assert ent == pytest.approx(val, abs=0.01)


def test_get_onnx_deterministic_tensors():
    """Verify get_action_out returns the five ONNX-export tensors, and that
    the deterministic continuous output is stable across repeated calls while
    the sampled continuous output is not."""
    inp_size = 4
    act_size = 2
    action_model, masks = create_action_model(inp_size, act_size)
    sample_inp = torch.ones((1, inp_size))
    out_tensors = action_model.get_action_out(sample_inp, masks=masks)
    (
        continuous_out,
        discrete_out,
        action_out_deprecated,
        deterministic_continuous_out,
        deterministic_discrete_out,
    ) = out_tensors
    # One row (batch of 1), act_size columns for every head.
    assert continuous_out.shape == (1, 2)
    assert discrete_out.shape == (1, 2)
    assert deterministic_discrete_out.shape == (1, 2)
    assert deterministic_continuous_out.shape == (1, 2)

    # Second sampling from the same distribution.
    out_tensors2 = action_model.get_action_out(sample_inp, masks=masks)
    (
        continuous_out_2,
        discrete_out_2,
        action_out_2_deprecated,
        deterministic_continuous_out_2,
        deterministic_discrete_out_2,
    ) = out_tensors2
    # Stochastic samples should differ between draws (almost surely, for a
    # continuous distribution); use `not` rather than bitwise `~` so the
    # assertion reads as a boolean check instead of a tensor invert.
    assert not torch.all(torch.eq(continuous_out, continuous_out_2))
    # Deterministic output must be identical on every call.
    assert torch.all(
        torch.eq(deterministic_continuous_out, deterministic_continuous_out_2)
    )
28 changes: 25 additions & 3 deletions ml-agents/mlagents/trainers/torch/action_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents_envs.base_env import ActionSpec


EPSILON = 1e-7 # Small value to avoid divide by zero


Expand Down Expand Up @@ -161,23 +162,44 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
"""
dists = self._get_dists(inputs, masks)
continuous_out, discrete_out, action_out_deprecated = None, None, None
deterministic_continuous_out, deterministic_discrete_out = (
None,
None,
) # deterministic actions
if self.action_spec.continuous_size > 0 and dists.continuous is not None:
continuous_out = dists.continuous.exported_model_output()
action_out_deprecated = dists.continuous.exported_model_output()
action_out_deprecated = continuous_out
deterministic_continuous_out = dists.continuous.deterministic_sample()
if self._clip_action_on_export:
continuous_out = torch.clamp(continuous_out, -3, 3) / 3
action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3
action_out_deprecated = continuous_out
deterministic_continuous_out = (
torch.clamp(deterministic_continuous_out, -3, 3) / 3
)
if self.action_spec.discrete_size > 0 and dists.discrete is not None:
discrete_out_list = [
discrete_dist.exported_model_output()
for discrete_dist in dists.discrete
]
discrete_out = torch.cat(discrete_out_list, dim=1)
action_out_deprecated = torch.cat(discrete_out_list, dim=1)
deterministic_discrete_out_list = [
discrete_dist.deterministic_sample() for discrete_dist in dists.discrete
]
deterministic_discrete_out = torch.cat(
deterministic_discrete_out_list, dim=1
)

# deprecated action field does not support hybrid action
if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
action_out_deprecated = None
return continuous_out, discrete_out, action_out_deprecated
return (
continuous_out,
discrete_out,
action_out_deprecated,
deterministic_continuous_out,
deterministic_discrete_out,
)

def forward(
self, inputs: torch.Tensor, masks: torch.Tensor
Expand Down
13 changes: 13 additions & 0 deletions ml-agents/mlagents/trainers/torch/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ def sample(self) -> torch.Tensor:
"""
pass

@abc.abstractmethod
def deterministic_sample(self) -> torch.Tensor:
    """
    Return the most probable sample from this distribution (e.g. its mode),
    used for deterministic action selection during inference/export.
    """
    pass

@abc.abstractmethod
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -59,6 +66,9 @@ def sample(self):
sample = self.mean + torch.randn_like(self.mean) * self.std
return sample

def deterministic_sample(self):
    """Deterministic output: the distribution mean (the mode of a Gaussian)."""
    return self.mean

def log_prob(self, value):
var = self.std ** 2
log_scale = torch.log(self.std + EPSILON)
Expand Down Expand Up @@ -113,6 +123,9 @@ def __init__(self, logits):
def sample(self):
return torch.multinomial(self.probs, 1)

def deterministic_sample(self):
    """Deterministic output: the index of the highest-probability category,
    shaped (batch, 1) to match the stochastic sample."""
    most_likely = torch.argmax(self.probs, dim=1, keepdim=True)
    return most_likely

def pdf(self, value):
# This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]),
# but torch.diag is not supported by ONNX export.
Expand Down
5 changes: 5 additions & 0 deletions ml-agents/mlagents/trainers/torch/model_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,13 @@ class TensorNames:
recurrent_output = "recurrent_out"
memory_size = "memory_size"
version_number = "version_number"

continuous_action_output_shape = "continuous_action_output_shape"
discrete_action_output_shape = "discrete_action_output_shape"
continuous_action_output = "continuous_actions"
discrete_action_output = "discrete_actions"
deterministic_continuous_action_output = "deterministic_continuous_actions"
deterministic_discrete_action_output = "deterministic_discrete_actions"

# Deprecated TensorNames entries for backward compatibility
is_continuous_control_deprecated = "is_continuous_control"
Expand Down Expand Up @@ -122,6 +125,7 @@ def __init__(self, policy):
self.output_names += [
TensorNames.continuous_action_output,
TensorNames.continuous_action_output_shape,
TensorNames.deterministic_continuous_action_output,
]
self.dynamic_axes.update(
{TensorNames.continuous_action_output: {0: "batch"}}
Expand All @@ -130,6 +134,7 @@ def __init__(self, policy):
self.output_names += [
TensorNames.discrete_action_output,
TensorNames.discrete_action_output_shape,
TensorNames.deterministic_discrete_action_output,
]
self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}})

Expand Down
14 changes: 12 additions & 2 deletions ml-agents/mlagents/trainers/torch/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,12 +675,22 @@ def forward(
cont_action_out,
disc_action_out,
action_out_deprecated,
deterministic_cont_action_out,
deterministic_disc_action_out,
) = self.action_model.get_action_out(encoding, masks)
export_out = [self.version_number, self.memory_size_vector]
if self.action_spec.continuous_size > 0:
export_out += [cont_action_out, self.continuous_act_size_vector]
export_out += [
cont_action_out,
self.continuous_act_size_vector,
deterministic_cont_action_out,
]
if self.action_spec.discrete_size > 0:
export_out += [disc_action_out, self.discrete_act_size_vector]
export_out += [
disc_action_out,
self.discrete_act_size_vector,
deterministic_disc_action_out,
]
if self.network_body.memory_size > 0:
export_out += [memories_out]
return tuple(export_out)
Expand Down