Skip to content

Commit 6d609e2

Browse files
committed
Add test and editor flag
- Add tests for deterministic sampling - update editor and tooltips
1 parent a3c8857 commit 6d609e2

7 files changed

+115
-3
lines changed

com.unity.ml-agents/Editor/BehaviorParametersEditor.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
2525
const string k_BrainParametersName = "m_BrainParameters";
2626
const string k_ModelName = "m_Model";
2727
const string k_InferenceDeviceName = "m_InferenceDevice";
28+
const string k_StochasticInference = "m_stochasticInference";
2829
const string k_BehaviorTypeName = "m_BehaviorType";
2930
const string k_TeamIdName = "TeamId";
3031
const string k_UseChildSensorsName = "m_UseChildSensors";
@@ -68,6 +69,7 @@ public override void OnInspectorGUI()
6869
EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true);
6970
EditorGUI.indentLevel++;
7071
EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true);
72+
EditorGUILayout.PropertyField(so.FindProperty(k_StochasticInference), true);
7173
EditorGUI.indentLevel--;
7274
}
7375
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
@@ -156,7 +158,7 @@ void DisplayFailedModelChecks()
156158
{
157159
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
158160
barracudaModel, brainParameters, sensors, actuatorComponents,
159-
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
161+
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType, behaviorParameters.StochasticInference
160162
);
161163
foreach (var check in failedChecks)
162164
{

com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
402402
{
403403
if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
404404
{
405-
if(model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
405+
if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
406406
{
407407
failedModelChecks.Add(
408408
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
@@ -423,7 +423,7 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
423423

424424
if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
425425
{
426-
if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null )
426+
if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
427427
{
428428
failedModelChecks.Add(
429429
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")

com.unity.ml-agents/Tests/Editor/Inference/ModelRunnerTest.cs

+82
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using System.Linq;
23
using NUnit.Framework;
34
using UnityEngine;
@@ -6,9 +7,31 @@
67
using Unity.MLAgents.Actuators;
78
using Unity.MLAgents.Inference;
89
using Unity.MLAgents.Policies;
10+
using System.Collections;
11+
using System.Collections.Generic;
12+
using UnityEngine.Assertions.Comparers;
913

1014
namespace Unity.MLAgents.Tests
1115
{
16+
public class FloatThresholdComparer : IEqualityComparer<float>
17+
{
18+
private readonly float _threshold;
19+
public FloatThresholdComparer(float threshold)
20+
{
21+
_threshold = threshold;
22+
}
23+
24+
public bool Equals(float x, float y)
25+
{
26+
return Math.Abs(x - y) < _threshold;
27+
}
28+
29+
public int GetHashCode(float f)
30+
{
31+
throw new NotImplementedException("Unable to generate a hash code for threshold floats, do not use this method");
32+
}
33+
}
34+
1235
[TestFixture]
1336
public class ModelRunnerTest
1437
{
@@ -19,13 +42,18 @@ public class ModelRunnerTest
1942
const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx";
2043
const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn";
2144
const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn";
45+
// models with deterministic action tensors
46+
private const string k_deter_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx";
47+
private const string k_deter_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx";
2248

2349
NNModel hybridONNXModelV2;
2450
NNModel continuousONNXModel;
2551
NNModel discreteONNXModel;
2652
NNModel hybridONNXModel;
2753
NNModel continuousNNModel;
2854
NNModel discreteNNModel;
55+
NNModel deterDiscreteNNModel;
56+
NNModel deterContinuousNNModel;
2957
Test3DSensorComponent sensor_21_20_3;
3058
Test3DSensorComponent sensor_20_22_3;
3159

@@ -55,6 +83,8 @@ public void SetUp()
5583
hybridONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybridONNXPath, typeof(NNModel));
5684
continuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousNNPath, typeof(NNModel));
5785
discreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteNNPath, typeof(NNModel));
86+
deterDiscreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_deter_discreteNNPath, typeof(NNModel));
87+
deterContinuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_deter_continuousNNPath, typeof(NNModel));
5888
var go = new GameObject("SensorA");
5989
sensor_21_20_3 = go.AddComponent<Test3DSensorComponent>();
6090
sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3);
@@ -71,6 +101,8 @@ public void TestModelExist()
71101
Assert.IsNotNull(continuousNNModel);
72102
Assert.IsNotNull(discreteNNModel);
73103
Assert.IsNotNull(hybridONNXModelV2);
104+
Assert.IsNotNull(deterDiscreteNNModel);
105+
Assert.IsNotNull(deterContinuousNNModel);
74106
}
75107

76108
[Test]
@@ -99,6 +131,15 @@ public void TestCreation()
99131
// This one was trained with 2.0 so it should not raise an error:
100132
modelRunner = new ModelRunner(hybridONNXModelV2, new ActionSpec(2, new[] { 2, 3 }), inferenceDevice);
101133
modelRunner.Dispose();
134+
135+
// V2.0 Model that has serialized deterministic action tensors, discrete
136+
modelRunner = new ModelRunner(deterDiscreteNNModel, new ActionSpec(0, new[] { 7 }), inferenceDevice);
137+
modelRunner.Dispose();
138+
// V2.0 Model that has serialized deterministic action tensors, continuous
139+
modelRunner = new ModelRunner(deterContinuousNNModel,
140+
GetContinuous2vis8vec2actionActionSpec(), inferenceDevice,
141+
stochasticInference: false);
142+
modelRunner.Dispose();
102143
}
103144

104145
[Test]
@@ -138,5 +179,46 @@ public void TestRunModel()
138179
Assert.AreEqual(actionSpec.NumDiscreteActions, modelRunner.GetAction(1).DiscreteActions.Length);
139180
modelRunner.Dispose();
140181
}
182+
183+
184+
[Test]
185+
public void TestRunModel_deterministic()
186+
{
187+
var actionSpec = GetContinuous2vis8vec2actionActionSpec();
188+
var modelRunner = new ModelRunner(deterContinuousNNModel, actionSpec, InferenceDevice.Burst);
189+
var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8");
190+
var info1 = new AgentInfo();
191+
var obs = new[]
192+
{
193+
sensor_8,
194+
sensor_21_20_3.CreateSensors()[0],
195+
sensor_20_22_3.CreateSensors()[0]
196+
}.ToList();
197+
info1.episodeId = 1;
198+
modelRunner.PutObservations(info1, obs);
199+
modelRunner.DecideBatch();
200+
var stochAction1 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone();
201+
202+
modelRunner.PutObservations(info1, obs);
203+
modelRunner.DecideBatch();
204+
var stochAction2 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone();
205+
// Stochastic action selection should output randomly different action values with same obs
206+
Assert.IsFalse(Enumerable.SequenceEqual(stochAction1, stochAction2, new FloatThresholdComparer(0.001f)));
207+
208+
209+
var deterModelRunner = new ModelRunner(deterContinuousNNModel, actionSpec, InferenceDevice.Burst,
210+
stochasticInference: false);
211+
info1.episodeId = 1;
212+
deterModelRunner.PutObservations(info1, obs);
213+
deterModelRunner.DecideBatch();
214+
var deterAction1 = (float[])deterModelRunner.GetAction(1).ContinuousActions.Array.Clone();
215+
216+
deterModelRunner.PutObservations(info1, obs);
217+
deterModelRunner.DecideBatch();
218+
var deterAction2 = (float[])deterModelRunner.GetAction(1).ContinuousActions.Array.Clone();
219+
// Deterministic action selection should output same action everytime
220+
Assert.IsTrue(Enumerable.SequenceEqual(deterAction1, deterAction2, new FloatThresholdComparer(0.001f)));
221+
modelRunner.Dispose();
222+
}
141223
}
142224
}

com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx.meta

+14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Binary file not shown.

com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx.meta

+14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)