Skip to content

Commit 0de327c

Browse files
maryamhonari and cmard authored
Deterministic actions python training (#5619)
* Progress on propagating the setting to the action model. * Added the _sample_action logic and tests. * Add information to the changelog. * Prioritize the CLI over the configuration file. * Update documentation for config file. * CR refactor. * Update docs/Training-Configuration-File.md Co-authored-by: Miguel Alonso Jr. <[email protected]> Update com.unity.ml-agents/CHANGELOG.md Co-authored-by: Miguel Alonso Jr. <[email protected]> Update com.unity.ml-agents/CHANGELOG.md Co-authored-by: Miguel Alonso Jr. <[email protected]> Update com.unity.ml-agents/CHANGELOG.md Co-authored-by: Maryam Honari <[email protected]> Update ml-agents/mlagents/trainers/settings.py Co-authored-by: Maryam Honari <[email protected]> Update ml-agents/mlagents/trainers/cli_utils.py Co-authored-by: Maryam Honari <[email protected]> * Fix CR requests * Add tests for discrete. * Update ml-agents/mlagents/trainers/torch/distributions.py Co-authored-by: Maryam Honari <[email protected]> * Added more stable test. * Return deterministic actions for training (#5615) * Added more stable test. * Fix the tests. * Fix pre-commit * Fix help line to pass precommit. * support for deterministic inference in onnx (#5593) * Init: actor.forward outputs separate deterministic actions * changelog * Renaming * Add more tests * Package changes to support deterministic inference (#5599) * Init: actor.forward outputs separate deterministic actions * fix tensor shape for discrete actions * Add test and editor flag - Add tests for deterministic sampling - update editor and tooltips * Reverting to "Deterministic Inference" * dissect tests * Update docs * Update CHANGELOG.md Co-authored-by: Chingiz Mardanov <[email protected]> Co-authored-by: cmard <[email protected]>
1 parent 348bc9d commit 0de327c

29 files changed

+469
-66
lines changed

com.unity.ml-agents/CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ and this project adheres to
3030
2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
3131
3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]
3232

33+
34+
- Deterministic action selection is now supported during training and inference. (#5619)
35+
- Added a new `--deterministic` CLI flag to deterministically select the most probable actions in the policy. The same thing can
36+
be achieved by adding `deterministic: true` under `network_settings` of the run options configuration. (#5597)
37+
- Extra tensors are now serialized to support deterministic action selection in onnx. (#5593)
38+
- Support inference with deterministic action selection in the editor. (#5599)
3339
### Bug Fixes
3440
- Fixed a bug where the critics were not being normalized during training. (#5595)
3541
- Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586)

com.unity.ml-agents/Editor/BehaviorParametersEditor.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
2525
const string k_BrainParametersName = "m_BrainParameters";
2626
const string k_ModelName = "m_Model";
2727
const string k_InferenceDeviceName = "m_InferenceDevice";
28+
const string k_DeterministicInference = "m_DeterministicInference";
2829
const string k_BehaviorTypeName = "m_BehaviorType";
2930
const string k_TeamIdName = "TeamId";
3031
const string k_UseChildSensorsName = "m_UseChildSensors";
@@ -68,6 +69,7 @@ public override void OnInspectorGUI()
6869
EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true);
6970
EditorGUI.indentLevel++;
7071
EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true);
72+
EditorGUILayout.PropertyField(so.FindProperty(k_DeterministicInference), true);
7173
EditorGUI.indentLevel--;
7274
}
7375
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
@@ -156,7 +158,7 @@ void DisplayFailedModelChecks()
156158
{
157159
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
158160
barracudaModel, brainParameters, sensors, actuatorComponents,
159-
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
161+
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType, behaviorParameters.DeterministicInference
160162
);
161163
foreach (var check in failedChecks)
162164
{

com.unity.ml-agents/Runtime/Academy.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -616,14 +616,16 @@ void EnvironmentReset()
616616
/// <param name="inferenceDevice">
617617
/// The inference device (CPU or GPU) the ModelRunner will use.
618618
/// </param>
619+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
620+
/// deterministic. </param>
619621
/// <returns> The ModelRunner compatible with the input settings.</returns>
620622
internal ModelRunner GetOrCreateModelRunner(
621-
NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice)
623+
NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice, bool deterministicInference = false)
622624
{
623625
var modelRunner = m_ModelRunners.Find(x => x.HasModel(model, inferenceDevice));
624626
if (modelRunner == null)
625627
{
626-
modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed);
628+
modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed, deterministicInference);
627629
m_ModelRunners.Add(modelRunner);
628630
m_InferenceSeed++;
629631
}

com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs

+80-27
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ public static int GetNumVisualInputs(this Model model)
112112
/// <param name="model">
113113
/// The Barracuda engine model for loading static parameters.
114114
/// </param>
115+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
116+
/// deterministic. </param>
115117
/// <returns>Array of the output tensor names of the model</returns>
116-
public static string[] GetOutputNames(this Model model)
118+
public static string[] GetOutputNames(this Model model, bool deterministicInference = false)
117119
{
118120
var names = new List<string>();
119121

@@ -122,13 +124,13 @@ public static string[] GetOutputNames(this Model model)
122124
return names.ToArray();
123125
}
124126

125-
if (model.HasContinuousOutputs())
127+
if (model.HasContinuousOutputs(deterministicInference))
126128
{
127-
names.Add(model.ContinuousOutputName());
129+
names.Add(model.ContinuousOutputName(deterministicInference));
128130
}
129-
if (model.HasDiscreteOutputs())
131+
if (model.HasDiscreteOutputs(deterministicInference))
130132
{
131-
names.Add(model.DiscreteOutputName());
133+
names.Add(model.DiscreteOutputName(deterministicInference));
132134
}
133135

134136
var modelVersion = model.GetVersion();
@@ -149,8 +151,10 @@ public static string[] GetOutputNames(this Model model)
149151
/// <param name="model">
150152
/// The Barracuda engine model for loading static parameters.
151153
/// </param>
154+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
155+
/// deterministic. </param>
152156
/// <returns>True if the model has continuous action outputs.</returns>
153-
public static bool HasContinuousOutputs(this Model model)
157+
public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false)
154158
{
155159
if (model == null)
156160
return false;
@@ -160,8 +164,13 @@ public static bool HasContinuousOutputs(this Model model)
160164
}
161165
else
162166
{
163-
return model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
164-
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
167+
bool hasStochasticOutput = !deterministicInference &&
168+
model.outputs.Contains(TensorNames.ContinuousActionOutput);
169+
bool hasDeterministicOutput = deterministicInference &&
170+
model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput);
171+
172+
return (hasStochasticOutput || hasDeterministicOutput) &&
173+
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
165174
}
166175
}
167176

@@ -194,8 +203,10 @@ public static int ContinuousOutputSize(this Model model)
194203
/// <param name="model">
195204
/// The Barracuda engine model for loading static parameters.
196205
/// </param>
206+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
207+
/// deterministic. </param>
197208
/// <returns>Tensor name of continuous action output.</returns>
198-
public static string ContinuousOutputName(this Model model)
209+
public static string ContinuousOutputName(this Model model, bool deterministicInference = false)
199210
{
200211
if (model == null)
201212
return null;
@@ -205,7 +216,7 @@ public static string ContinuousOutputName(this Model model)
205216
}
206217
else
207218
{
208-
return TensorNames.ContinuousActionOutput;
219+
return deterministicInference ? TensorNames.DeterministicContinuousActionOutput : TensorNames.ContinuousActionOutput;
209220
}
210221
}
211222

@@ -215,8 +226,10 @@ public static string ContinuousOutputName(this Model model)
215226
/// <param name="model">
216227
/// The Barracuda engine model for loading static parameters.
217228
/// </param>
229+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
230+
/// deterministic. </param>
218231
/// <returns>True if the model has discrete action outputs.</returns>
219-
public static bool HasDiscreteOutputs(this Model model)
232+
public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false)
220233
{
221234
if (model == null)
222235
return false;
@@ -226,7 +239,12 @@ public static bool HasDiscreteOutputs(this Model model)
226239
}
227240
else
228241
{
229-
return model.outputs.Contains(TensorNames.DiscreteActionOutput) && model.DiscreteOutputSize() > 0;
242+
bool hasStochasticOutput = !deterministicInference &&
243+
model.outputs.Contains(TensorNames.DiscreteActionOutput);
244+
bool hasDeterministicOutput = deterministicInference &&
245+
model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput);
246+
return (hasStochasticOutput || hasDeterministicOutput) &&
247+
model.DiscreteOutputSize() > 0;
230248
}
231249
}
232250

@@ -279,8 +297,10 @@ public static int DiscreteOutputSize(this Model model)
279297
/// <param name="model">
280298
/// The Barracuda engine model for loading static parameters.
281299
/// </param>
300+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
301+
/// deterministic. </param>
282302
/// <returns>Tensor name of discrete action output.</returns>
283-
public static string DiscreteOutputName(this Model model)
303+
public static string DiscreteOutputName(this Model model, bool deterministicInference = false)
284304
{
285305
if (model == null)
286306
return null;
@@ -290,7 +310,7 @@ public static string DiscreteOutputName(this Model model)
290310
}
291311
else
292312
{
293-
return TensorNames.DiscreteActionOutput;
313+
return deterministicInference ? TensorNames.DeterministicDiscreteActionOutput : TensorNames.DiscreteActionOutput;
294314
}
295315
}
296316

@@ -316,9 +336,11 @@ public static bool SupportsContinuousAndDiscrete(this Model model)
316336
/// The Barracuda engine model for loading static parameters.
317337
/// </param>
318338
/// <param name="failedModelChecks">Output list of failure messages</param>
319-
///
339+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
340+
/// deterministic. </param>
320341
/// <returns>True if the model contains all the expected tensors.</returns>
321-
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
342+
/// TODO: add checks for deterministic actions
343+
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks, bool deterministicInference = false)
322344
{
323345
// Check the presence of model version
324346
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
@@ -343,7 +365,9 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
343365
// Check the presence of action output tensor
344366
if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
345367
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
346-
!model.outputs.Contains(TensorNames.DiscreteActionOutput))
368+
!model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
369+
!model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
370+
!model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
347371
{
348372
failedModelChecks.Add(
349373
FailedCheck.Warning("The model does not contain any Action Output Node.")
@@ -373,22 +397,51 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
373397
}
374398
else
375399
{
376-
if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
377-
model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
400+
if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
378401
{
379-
failedModelChecks.Add(
380-
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
402+
if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
403+
{
404+
failedModelChecks.Add(
405+
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
381406
);
382-
return false;
407+
return false;
408+
}
409+
410+
else if (!model.HasContinuousOutputs(deterministicInference))
411+
{
412+
var actionType = deterministicInference ? "deterministic" : "stochastic";
413+
var actionName = deterministicInference ? "Deterministic" : "";
414+
failedModelChecks.Add(
415+
FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Continuous Action Output Tensor. Uncheck `Deterministic inference` flag..")
416+
);
417+
return false;
418+
}
383419
}
384-
if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
385-
model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
420+
421+
if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
386422
{
387-
failedModelChecks.Add(
388-
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
423+
if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
424+
{
425+
failedModelChecks.Add(
426+
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
389427
);
390-
return false;
428+
return false;
429+
}
430+
else if (!model.HasDiscreteOutputs(deterministicInference))
431+
{
432+
var actionType = deterministicInference ? "deterministic" : "stochastic";
433+
var actionName = deterministicInference ? "Deterministic" : "";
434+
failedModelChecks.Add(
435+
FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Discrete Action Output Tensor. Uncheck `Deterministic inference` flag.")
436+
);
437+
return false;
438+
}
439+
391440
}
441+
442+
443+
444+
392445
}
393446
return true;
394447
}

com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs

+16-8
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,17 @@ public static FailedCheck CheckModelVersion(Model model)
122122
/// <param name="actuatorComponents">Attached actuator components</param>
123123
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
124124
/// <param name="behaviorType">BehaviorType or the Agent to check.</param>
125+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
126+
/// deterministic. </param>
125127
/// <returns>A IEnumerable of the checks that failed</returns>
126128
public static IEnumerable<FailedCheck> CheckModel(
127129
Model model,
128130
BrainParameters brainParameters,
129131
ISensor[] sensors,
130132
ActuatorComponent[] actuatorComponents,
131133
int observableAttributeTotalSize = 0,
132-
BehaviorType behaviorType = BehaviorType.Default
134+
BehaviorType behaviorType = BehaviorType.Default,
135+
bool deterministicInference = false
133136
)
134137
{
135138
List<FailedCheck> failedModelChecks = new List<FailedCheck>();
@@ -148,7 +151,7 @@ public static IEnumerable<FailedCheck> CheckModel(
148151
return failedModelChecks;
149152
}
150153

151-
var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks);
154+
var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks, deterministicInference);
152155
if (!hasExpectedTensors)
153156
{
154157
return failedModelChecks;
@@ -181,7 +184,7 @@ public static IEnumerable<FailedCheck> CheckModel(
181184
else if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
182185
{
183186
failedModelChecks.AddRange(
184-
CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
187+
CheckInputTensorPresence(model, brainParameters, memorySize, sensors, deterministicInference)
185188
);
186189
failedModelChecks.AddRange(
187190
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
@@ -195,7 +198,7 @@ public static IEnumerable<FailedCheck> CheckModel(
195198
);
196199

197200
failedModelChecks.AddRange(
198-
CheckOutputTensorPresence(model, memorySize)
201+
CheckOutputTensorPresence(model, memorySize, deterministicInference)
199202
);
200203
return failedModelChecks;
201204
}
@@ -318,14 +321,17 @@ ISensor[] sensors
318321
/// The memory size that the model is expecting.
319322
/// </param>
320323
/// <param name="sensors">Array of attached sensor components</param>
324+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
325+
/// deterministic. </param>
321326
/// <returns>
322327
/// A IEnumerable of the checks that failed
323328
/// </returns>
324329
static IEnumerable<FailedCheck> CheckInputTensorPresence(
325330
Model model,
326331
BrainParameters brainParameters,
327332
int memory,
328-
ISensor[] sensors
333+
ISensor[] sensors,
334+
bool deterministicInference = false
329335
)
330336
{
331337
var failedModelChecks = new List<FailedCheck>();
@@ -356,7 +362,7 @@ ISensor[] sensors
356362
}
357363

358364
// If the model uses discrete control but does not have an input for action masks
359-
if (model.HasDiscreteOutputs())
365+
if (model.HasDiscreteOutputs(deterministicInference))
360366
{
361367
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
362368
{
@@ -376,17 +382,19 @@ ISensor[] sensors
376382
/// The Barracuda engine model for loading static parameters
377383
/// </param>
378384
/// <param name="memory">The memory size that the model is expecting.</param>
385+
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
386+
/// deterministic. </param>
379387
/// <returns>
380388
/// A IEnumerable of the checks that failed
381389
/// </returns>
382-
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory)
390+
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory, bool deterministicInference = false)
383391
{
384392
var failedModelChecks = new List<FailedCheck>();
385393

386394
// If there is no Recurrent Output but the model is Recurrent.
387395
if (memory > 0)
388396
{
389-
var allOutputs = model.GetOutputNames().ToList();
397+
var allOutputs = model.GetOutputNames(deterministicInference).ToList();
390398
if (!allOutputs.Any(x => x == TensorNames.RecurrentOutput))
391399
{
392400
failedModelChecks.Add(

0 commit comments

Comments
 (0)