1
+ using System ;
1
2
using System . Linq ;
2
3
using NUnit . Framework ;
3
4
using UnityEngine ;
6
7
using Unity . MLAgents . Actuators ;
7
8
using Unity . MLAgents . Inference ;
8
9
using Unity . MLAgents . Policies ;
10
+ using System . Collections ;
11
+ using System . Collections . Generic ;
12
+ using UnityEngine . Assertions . Comparers ;
9
13
10
14
namespace Unity . MLAgents . Tests
11
15
{
16
/// <summary>
/// Equality comparer that treats two floats as equal when they are exactly equal
/// or when their absolute difference is strictly below a fixed threshold.
/// Intended for element-wise comparisons such as <c>Enumerable.SequenceEqual</c>.
/// </summary>
public class FloatThresholdComparer : IEqualityComparer<float>
{
    private readonly float _threshold;

    /// <summary>
    /// Creates a comparer with the given tolerance.
    /// </summary>
    /// <param name="threshold">
    /// Exclusive upper bound on <c>|x - y|</c> for two different values to be considered equal.
    /// </param>
    public FloatThresholdComparer(float threshold)
    {
        _threshold = threshold;
    }

    /// <summary>
    /// Compares two floats using the configured threshold.
    /// </summary>
    /// <param name="x">First value.</param>
    /// <param name="y">Second value.</param>
    /// <returns>
    /// True when the values are identical or differ by less than the threshold.
    /// NaN never compares equal to anything, including another NaN.
    /// </returns>
    public bool Equals(float x, float y)
    {
        // Exact match first: for two equal infinities x - y is NaN, which would
        // fail the threshold test even though the values are identical.
        if (x == y)
        {
            return true;
        }

        return Math.Abs(x - y) < _threshold;
    }

    /// <summary>
    /// Not supported: threshold equality is not transitive, so no hash code can be
    /// consistent with <see cref="Equals(float, float)"/>.
    /// </summary>
    /// <param name="f">Ignored.</param>
    /// <returns>Never returns.</returns>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public int GetHashCode(float f)
    {
        throw new NotImplementedException("Unable to generate a hash code for threshold floats, do not use this method");
    }
}
34
+
12
35
[ TestFixture ]
13
36
public class ModelRunnerTest
14
37
{
@@ -19,13 +42,18 @@ public class ModelRunnerTest
19
42
const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction_v1_0.onnx" ;
20
43
const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated_v1_0.nn" ;
21
44
const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated_v1_0.nn" ;
45
+ // models with deterministic action tensors
46
+ private const string k_deter_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterDiscrete1obs3action_v2_0.onnx" ;
47
+ private const string k_deter_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/deterContinuous2vis8vec2action_v2_0.onnx" ;
22
48
23
49
NNModel hybridONNXModelV2 ;
24
50
NNModel continuousONNXModel ;
25
51
NNModel discreteONNXModel ;
26
52
NNModel hybridONNXModel ;
27
53
NNModel continuousNNModel ;
28
54
NNModel discreteNNModel ;
55
+ NNModel deterDiscreteNNModel ;
56
+ NNModel deterContinuousNNModel ;
29
57
Test3DSensorComponent sensor_21_20_3 ;
30
58
Test3DSensorComponent sensor_20_22_3 ;
31
59
@@ -55,6 +83,8 @@ public void SetUp()
55
83
hybridONNXModel = ( NNModel ) AssetDatabase . LoadAssetAtPath ( k_hybridONNXPath , typeof ( NNModel ) ) ;
56
84
continuousNNModel = ( NNModel ) AssetDatabase . LoadAssetAtPath ( k_continuousNNPath , typeof ( NNModel ) ) ;
57
85
discreteNNModel = ( NNModel ) AssetDatabase . LoadAssetAtPath ( k_discreteNNPath , typeof ( NNModel ) ) ;
86
+ deterDiscreteNNModel = ( NNModel ) AssetDatabase . LoadAssetAtPath ( k_deter_discreteNNPath , typeof ( NNModel ) ) ;
87
+ deterContinuousNNModel = ( NNModel ) AssetDatabase . LoadAssetAtPath ( k_deter_continuousNNPath , typeof ( NNModel ) ) ;
58
88
var go = new GameObject ( "SensorA" ) ;
59
89
sensor_21_20_3 = go . AddComponent < Test3DSensorComponent > ( ) ;
60
90
sensor_21_20_3 . Sensor = new Test3DSensor ( "SensorA" , 21 , 20 , 3 ) ;
@@ -71,6 +101,8 @@ public void TestModelExist()
71
101
Assert . IsNotNull ( continuousNNModel ) ;
72
102
Assert . IsNotNull ( discreteNNModel ) ;
73
103
Assert . IsNotNull ( hybridONNXModelV2 ) ;
104
+ Assert . IsNotNull ( deterDiscreteNNModel ) ;
105
+ Assert . IsNotNull ( deterContinuousNNModel ) ;
74
106
}
75
107
76
108
[ Test ]
@@ -99,6 +131,15 @@ public void TestCreation()
99
131
// This one was trained with 2.0 so it should not raise an error:
100
132
modelRunner = new ModelRunner ( hybridONNXModelV2 , new ActionSpec ( 2 , new [ ] { 2 , 3 } ) , inferenceDevice ) ;
101
133
modelRunner . Dispose ( ) ;
134
+
135
+ // V2.0 Model that has serialized deterministic action tensors, discrete
136
+ modelRunner = new ModelRunner ( deterDiscreteNNModel , new ActionSpec ( 0 , new [ ] { 7 } ) , inferenceDevice ) ;
137
+ modelRunner . Dispose ( ) ;
138
+ // V2.0 Model that has serialized deterministic action tensors, continuous
139
+ modelRunner = new ModelRunner ( deterContinuousNNModel ,
140
+ GetContinuous2vis8vec2actionActionSpec ( ) , inferenceDevice ,
141
+ stochasticInference : false ) ;
142
+ modelRunner . Dispose ( ) ;
102
143
}
103
144
104
145
[ Test ]
@@ -138,5 +179,46 @@ public void TestRunModel()
138
179
Assert . AreEqual ( actionSpec . NumDiscreteActions , modelRunner . GetAction ( 1 ) . DiscreteActions . Length ) ;
139
180
modelRunner . Dispose ( ) ;
140
181
}
182
+
183
+
184
/// <summary>
/// Feeds the same observations twice to two runners built from the deterministic-capable
/// continuous model: one created without the <c>stochasticInference</c> argument and one
/// created with <c>stochasticInference: false</c>. The first is expected to produce
/// different continuous actions on each decision, the second identical ones
/// (both compared with a 0.001 tolerance).
/// </summary>
[Test]
public void TestRunModel_deterministic()
{
    var actionSpec = GetContinuous2vis8vec2actionActionSpec();
    var modelRunner = new ModelRunner(deterContinuousNNModel, actionSpec, InferenceDevice.Burst);
    ModelRunner deterModelRunner = null;
    try
    {
        var sensor_8 = new Sensors.VectorSensor(8, "VectorSensor8");
        var info1 = new AgentInfo();
        var obs = new[]
        {
            sensor_8,
            sensor_21_20_3.CreateSensors()[0],
            sensor_20_22_3.CreateSensors()[0]
        }.ToList();
        info1.episodeId = 1;

        // Clone the action buffers so a later decision cannot overwrite the captured values.
        modelRunner.PutObservations(info1, obs);
        modelRunner.DecideBatch();
        var stochAction1 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone();

        modelRunner.PutObservations(info1, obs);
        modelRunner.DecideBatch();
        var stochAction2 = (float[])modelRunner.GetAction(1).ContinuousActions.Array.Clone();
        // Stochastic action selection should output randomly different action values with same obs
        Assert.IsFalse(Enumerable.SequenceEqual(stochAction1, stochAction2, new FloatThresholdComparer(0.001f)));

        deterModelRunner = new ModelRunner(deterContinuousNNModel, actionSpec, InferenceDevice.Burst,
            stochasticInference: false);
        info1.episodeId = 1;
        deterModelRunner.PutObservations(info1, obs);
        deterModelRunner.DecideBatch();
        var deterAction1 = (float[])deterModelRunner.GetAction(1).ContinuousActions.Array.Clone();

        deterModelRunner.PutObservations(info1, obs);
        deterModelRunner.DecideBatch();
        var deterAction2 = (float[])deterModelRunner.GetAction(1).ContinuousActions.Array.Clone();
        // Deterministic action selection should output the same action every time
        Assert.IsTrue(Enumerable.SequenceEqual(deterAction1, deterAction2, new FloatThresholdComparer(0.001f)));
    }
    finally
    {
        // Release both runners even when an assertion above fails; previously only
        // modelRunner was disposed, and only on the success path.
        modelRunner.Dispose();
        if (deterModelRunner != null)
        {
            deterModelRunner.Dispose();
        }
    }
}
141
223
}
142
224
}
0 commit comments