Skip to content

Commit 0f5cd2b

Browse files
authored
support for deterministic inference in onnx (#5593)
* Init: actor.forward outputs separate deterministic actions * changelog * Renaming * Add more tests
1 parent 176b268 commit 0f5cd2b

File tree

5 files changed

+76
-5
lines changed

5 files changed

+76
-5
lines changed

com.unity.ml-agents/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ and this project adheres to
3131

3232
- Added a new `--deterministic` cli flag to deterministically select the most probable actions in policy. The same thing can
3333
be achieved by adding `deterministic: true` under `network_settings` of the run options configuration.
34+
- Extra tensors are now serialized to support deterministic action selection in onnx. (#5597)
3435
### Bug Fixes
3536
- Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586)
3637

ml-agents/mlagents/trainers/tests/torch/test_action_model.py

+33
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,36 @@ def test_get_probs_and_entropy():
120120

121121
for ent, val in zip(entropies[0].tolist(), [1.4189, 0.6191, 0.6191]):
122122
assert ent == pytest.approx(val, abs=0.01)
123+
124+
125+
def test_get_onnx_deterministic_tensors():
    """ActionModel.get_action_out must emit 5 tensors for ONNX export:
    (continuous, discrete, deprecated, deterministic_continuous,
    deterministic_discrete). Deterministic outputs must be stable across
    repeated calls on the same input; stochastic ones must vary.
    """
    inp_size = 4
    act_size = 2
    action_model, masks = create_action_model(inp_size, act_size)
    sample_inp = torch.ones((1, inp_size))
    out_tensors = action_model.get_action_out(sample_inp, masks=masks)
    (
        continuous_out,
        discrete_out,
        action_out_deprecated,
        deterministic_continuous_out,
        deterministic_discrete_out,
    ) = out_tensors
    # create_action_model builds a hybrid spec with 2 continuous actions and
    # 2 discrete branches, so every head is (batch=1, 2).
    assert continuous_out.shape == (1, 2)
    assert discrete_out.shape == (1, 2)
    assert deterministic_discrete_out.shape == (1, 2)
    assert deterministic_continuous_out.shape == (1, 2)

    # Second sampling from the same distribution (same input, new call).
    out_tensors2 = action_model.get_action_out(sample_inp, masks=masks)
    (
        continuous_out_2,
        discrete_out_2,
        action_out_2_deprecated,
        deterministic_continuous_out_2,
        deterministic_discrete_out_2,
    ) = out_tensors2
    # Stochastic continuous samples should differ between draws, while the
    # deterministic head (distribution mode/mean) must be identical.
    # NOTE: `not torch.equal(...)` replaces the fragile `~torch.all(torch.eq(...))`
    # idiom — bitwise `~` on a 0-dim bool tensor only works by accident of
    # Tensor.__bool__; torch.equal states the intent explicitly.
    assert not torch.equal(continuous_out, continuous_out_2)
    assert torch.equal(
        deterministic_continuous_out, deterministic_continuous_out_2
    )

ml-agents/mlagents/trainers/torch/action_model.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
1111
from mlagents_envs.base_env import ActionSpec
1212

13+
1314
EPSILON = 1e-7 # Small value to avoid divide by zero
1415

1516

@@ -173,23 +174,44 @@ def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Ten
173174
"""
174175
dists = self._get_dists(inputs, masks)
175176
continuous_out, discrete_out, action_out_deprecated = None, None, None
177+
deterministic_continuous_out, deterministic_discrete_out = (
178+
None,
179+
None,
180+
) # deterministic actions
176181
if self.action_spec.continuous_size > 0 and dists.continuous is not None:
177182
continuous_out = dists.continuous.exported_model_output()
178-
action_out_deprecated = dists.continuous.exported_model_output()
183+
action_out_deprecated = continuous_out
184+
deterministic_continuous_out = dists.continuous.deterministic_sample()
179185
if self._clip_action_on_export:
180186
continuous_out = torch.clamp(continuous_out, -3, 3) / 3
181-
action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3
187+
action_out_deprecated = continuous_out
188+
deterministic_continuous_out = (
189+
torch.clamp(deterministic_continuous_out, -3, 3) / 3
190+
)
182191
if self.action_spec.discrete_size > 0 and dists.discrete is not None:
183192
discrete_out_list = [
184193
discrete_dist.exported_model_output()
185194
for discrete_dist in dists.discrete
186195
]
187196
discrete_out = torch.cat(discrete_out_list, dim=1)
188197
action_out_deprecated = torch.cat(discrete_out_list, dim=1)
198+
deterministic_discrete_out_list = [
199+
discrete_dist.deterministic_sample() for discrete_dist in dists.discrete
200+
]
201+
deterministic_discrete_out = torch.cat(
202+
deterministic_discrete_out_list, dim=1
203+
)
204+
189205
# deprecated action field does not support hybrid action
190206
if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
191207
action_out_deprecated = None
192-
return continuous_out, discrete_out, action_out_deprecated
208+
return (
209+
continuous_out,
210+
discrete_out,
211+
action_out_deprecated,
212+
deterministic_continuous_out,
213+
deterministic_discrete_out,
214+
)
193215

194216
def forward(
195217
self, inputs: torch.Tensor, masks: torch.Tensor

ml-agents/mlagents/trainers/torch/model_serialization.py

+5
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,13 @@ class TensorNames:
5656
recurrent_output = "recurrent_out"
5757
memory_size = "memory_size"
5858
version_number = "version_number"
59+
5960
continuous_action_output_shape = "continuous_action_output_shape"
6061
discrete_action_output_shape = "discrete_action_output_shape"
6162
continuous_action_output = "continuous_actions"
6263
discrete_action_output = "discrete_actions"
64+
deterministic_continuous_action_output = "deterministic_continuous_actions"
65+
deterministic_discrete_action_output = "deterministic_discrete_actions"
6366

6467
# Deprecated TensorNames entries for backward compatibility
6568
is_continuous_control_deprecated = "is_continuous_control"
@@ -122,6 +125,7 @@ def __init__(self, policy):
122125
self.output_names += [
123126
TensorNames.continuous_action_output,
124127
TensorNames.continuous_action_output_shape,
128+
TensorNames.deterministic_continuous_action_output,
125129
]
126130
self.dynamic_axes.update(
127131
{TensorNames.continuous_action_output: {0: "batch"}}
@@ -130,6 +134,7 @@ def __init__(self, policy):
130134
self.output_names += [
131135
TensorNames.discrete_action_output,
132136
TensorNames.discrete_action_output_shape,
137+
TensorNames.deterministic_discrete_action_output,
133138
]
134139
self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}})
135140

ml-agents/mlagents/trainers/torch/networks.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -676,12 +676,22 @@ def forward(
676676
cont_action_out,
677677
disc_action_out,
678678
action_out_deprecated,
679+
deterministic_cont_action_out,
680+
deterministic_disc_action_out,
679681
) = self.action_model.get_action_out(encoding, masks)
680682
export_out = [self.version_number, self.memory_size_vector]
681683
if self.action_spec.continuous_size > 0:
682-
export_out += [cont_action_out, self.continuous_act_size_vector]
684+
export_out += [
685+
cont_action_out,
686+
self.continuous_act_size_vector,
687+
deterministic_cont_action_out,
688+
]
683689
if self.action_spec.discrete_size > 0:
684-
export_out += [disc_action_out, self.discrete_act_size_vector]
690+
export_out += [
691+
disc_action_out,
692+
self.discrete_act_size_vector,
693+
deterministic_disc_action_out,
694+
]
685695
if self.network_body.memory_size > 0:
686696
export_out += [memories_out]
687697
return tuple(export_out)

0 commit comments

Comments
 (0)