@@ -112,8 +112,10 @@ public static int GetNumVisualInputs(this Model model)
112
112
/// <param name="model">
113
113
/// The Barracuda engine model for loading static parameters.
114
114
/// </param>
115
+ /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
116
+ /// deterministic. </param>
115
117
/// <returns>Array of the output tensor names of the model</returns>
116
- public static string [ ] GetOutputNames ( this Model model )
118
+ public static string [ ] GetOutputNames ( this Model model , bool deterministicInference = false )
117
119
{
118
120
var names = new List < string > ( ) ;
119
121
@@ -122,13 +124,13 @@ public static string[] GetOutputNames(this Model model)
122
124
return names . ToArray ( ) ;
123
125
}
124
126
125
- if ( model . HasContinuousOutputs ( ) )
127
+ if ( model . HasContinuousOutputs ( deterministicInference ) )
126
128
{
127
- names . Add ( model . ContinuousOutputName ( ) ) ;
129
+ names . Add ( model . ContinuousOutputName ( deterministicInference ) ) ;
128
130
}
129
- if ( model . HasDiscreteOutputs ( ) )
131
+ if ( model . HasDiscreteOutputs ( deterministicInference ) )
130
132
{
131
- names . Add ( model . DiscreteOutputName ( ) ) ;
133
+ names . Add ( model . DiscreteOutputName ( deterministicInference ) ) ;
132
134
}
133
135
134
136
var modelVersion = model . GetVersion ( ) ;
@@ -149,8 +151,10 @@ public static string[] GetOutputNames(this Model model)
149
151
/// <param name="model">
150
152
/// The Barracuda engine model for loading static parameters.
151
153
/// </param>
154
+ /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
155
+ /// deterministic. </param>
152
156
/// <returns>True if the model has continuous action outputs.</returns>
153
- public static bool HasContinuousOutputs ( this Model model )
157
+ public static bool HasContinuousOutputs ( this Model model , bool deterministicInference = false )
154
158
{
155
159
if ( model == null )
156
160
return false ;
@@ -160,8 +164,13 @@ public static bool HasContinuousOutputs(this Model model)
160
164
}
161
165
else
162
166
{
163
- return model . outputs . Contains ( TensorNames . ContinuousActionOutput ) &&
164
- ( int ) model . GetTensorByName ( TensorNames . ContinuousActionOutputShape ) [ 0 ] > 0 ;
167
+ bool hasStochasticOutput = ! deterministicInference &&
168
+ model . outputs . Contains ( TensorNames . ContinuousActionOutput ) ;
169
+ bool hasDeterministicOutput = deterministicInference &&
170
+ model . outputs . Contains ( TensorNames . DeterministicContinuousActionOutput ) ;
171
+
172
+ return ( hasStochasticOutput || hasDeterministicOutput ) &&
173
+ ( int ) model . GetTensorByName ( TensorNames . ContinuousActionOutputShape ) [ 0 ] > 0 ;
165
174
}
166
175
}
167
176
@@ -194,8 +203,10 @@ public static int ContinuousOutputSize(this Model model)
194
203
/// <param name="model">
195
204
/// The Barracuda engine model for loading static parameters.
196
205
/// </param>
206
+ /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
207
+ /// deterministic. </param>
197
208
/// <returns>Tensor name of continuous action output.</returns>
198
- public static string ContinuousOutputName ( this Model model )
209
+ public static string ContinuousOutputName ( this Model model , bool deterministicInference = false )
199
210
{
200
211
if ( model == null )
201
212
return null ;
@@ -205,7 +216,7 @@ public static string ContinuousOutputName(this Model model)
205
216
}
206
217
else
207
218
{
208
- return TensorNames . ContinuousActionOutput ;
219
+ return deterministicInference ? TensorNames . DeterministicContinuousActionOutput : TensorNames . ContinuousActionOutput ;
209
220
}
210
221
}
211
222
@@ -215,8 +226,10 @@ public static string ContinuousOutputName(this Model model)
215
226
/// <param name="model">
216
227
/// The Barracuda engine model for loading static parameters.
217
228
/// </param>
229
+ /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
230
+ /// deterministic. </param>
218
231
/// <returns>True if the model has discrete action outputs.</returns>
219
- public static bool HasDiscreteOutputs ( this Model model )
232
+ public static bool HasDiscreteOutputs ( this Model model , bool deterministicInference = false )
220
233
{
221
234
if ( model == null )
222
235
return false ;
@@ -226,7 +239,12 @@ public static bool HasDiscreteOutputs(this Model model)
226
239
}
227
240
else
228
241
{
229
- return model . outputs . Contains ( TensorNames . DiscreteActionOutput ) && model . DiscreteOutputSize ( ) > 0 ;
242
+ bool hasStochasticOutput = ! deterministicInference &&
243
+ model . outputs . Contains ( TensorNames . DiscreteActionOutput ) ;
244
+ bool hasDeterministicOutput = deterministicInference &&
245
+ model . outputs . Contains ( TensorNames . DeterministicDiscreteActionOutput ) ;
246
+ return ( hasStochasticOutput || hasDeterministicOutput ) &&
247
+ model . DiscreteOutputSize ( ) > 0 ;
230
248
}
231
249
}
232
250
@@ -279,8 +297,10 @@ public static int DiscreteOutputSize(this Model model)
279
297
/// <param name="model">
280
298
/// The Barracuda engine model for loading static parameters.
281
299
/// </param>
300
+ /// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
301
+ /// deterministic. </param>
282
302
/// <returns>Tensor name of discrete action output.</returns>
283
- public static string DiscreteOutputName ( this Model model )
303
+ public static string DiscreteOutputName ( this Model model , bool deterministicInference = false )
284
304
{
285
305
if ( model == null )
286
306
return null ;
@@ -290,7 +310,7 @@ public static string DiscreteOutputName(this Model model)
290
310
}
291
311
else
292
312
{
293
- return TensorNames . DiscreteActionOutput ;
313
+ return deterministicInference ? TensorNames . DeterministicDiscreteActionOutput : TensorNames . DiscreteActionOutput ;
294
314
}
295
315
}
296
316
@@ -316,9 +336,11 @@ public static bool SupportsContinuousAndDiscrete(this Model model)
316
336
/// The Barracuda engine model for loading static parameters.
317
337
/// </param>
318
338
/// <param name="failedModelChecks">Output list of failure messages</param>
319
- ///
339
+ ///<param name="deterministicInference"> Inference only: set to true if the action selection from model should be
340
+ /// deterministic. </param>
320
341
/// <returns>True if the model contains all the expected tensors.</returns>
321
- public static bool CheckExpectedTensors ( this Model model , List < FailedCheck > failedModelChecks )
342
+ /// TODO: add checks for deterministic actions
343
+ public static bool CheckExpectedTensors ( this Model model , List < FailedCheck > failedModelChecks , bool deterministicInference = false )
322
344
{
323
345
// Check the presence of model version
324
346
var modelApiVersionTensor = model . GetTensorByName ( TensorNames . VersionNumber ) ;
@@ -343,7 +365,9 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
343
365
// Check the presence of action output tensor
344
366
if ( ! model . outputs . Contains ( TensorNames . ActionOutputDeprecated ) &&
345
367
! model . outputs . Contains ( TensorNames . ContinuousActionOutput ) &&
346
- ! model . outputs . Contains ( TensorNames . DiscreteActionOutput ) )
368
+ ! model . outputs . Contains ( TensorNames . DiscreteActionOutput ) &&
369
+ ! model . outputs . Contains ( TensorNames . DeterministicContinuousActionOutput ) &&
370
+ ! model . outputs . Contains ( TensorNames . DeterministicDiscreteActionOutput ) )
347
371
{
348
372
failedModelChecks . Add (
349
373
FailedCheck . Warning ( "The model does not contain any Action Output Node." )
@@ -373,22 +397,51 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
373
397
}
374
398
else
375
399
{
376
- if ( model . outputs . Contains ( TensorNames . ContinuousActionOutput ) &&
377
- model . GetTensorByName ( TensorNames . ContinuousActionOutputShape ) == null )
400
+ if ( model . outputs . Contains ( TensorNames . ContinuousActionOutput ) )
378
401
{
379
- failedModelChecks . Add (
380
- FailedCheck . Warning ( "The model uses continuous action but does not contain Continuous Action Output Shape Node." )
402
+ if ( model . GetTensorByName ( TensorNames . ContinuousActionOutputShape ) == null )
403
+ {
404
+ failedModelChecks . Add (
405
+ FailedCheck . Warning ( "The model uses continuous action but does not contain Continuous Action Output Shape Node." )
381
406
) ;
382
- return false ;
407
+ return false ;
408
+ }
409
+
410
+ else if ( ! model . HasContinuousOutputs ( deterministicInference ) )
411
+ {
412
+ var actionType = deterministicInference ? "deterministic" : "stochastic" ;
413
+ var actionName = deterministicInference ? "Deterministic" : "" ;
414
+ failedModelChecks . Add (
415
+ FailedCheck . Warning ( $ "The model uses { actionType } inference but does not contain { actionName } Continuous Action Output Tensor. Uncheck `Deterministic inference` flag..")
416
+ ) ;
417
+ return false ;
418
+ }
383
419
}
384
- if ( model . outputs . Contains ( TensorNames . DiscreteActionOutput ) &&
385
- model . GetTensorByName ( TensorNames . DiscreteActionOutputShape ) == null )
420
+
421
+ if ( model . outputs . Contains ( TensorNames . DiscreteActionOutput ) )
386
422
{
387
- failedModelChecks . Add (
388
- FailedCheck . Warning ( "The model uses discrete action but does not contain Discrete Action Output Shape Node." )
423
+ if ( model . GetTensorByName ( TensorNames . DiscreteActionOutputShape ) == null )
424
+ {
425
+ failedModelChecks . Add (
426
+ FailedCheck . Warning ( "The model uses discrete action but does not contain Discrete Action Output Shape Node." )
389
427
) ;
390
- return false ;
428
+ return false ;
429
+ }
430
+ else if ( ! model . HasDiscreteOutputs ( deterministicInference ) )
431
+ {
432
+ var actionType = deterministicInference ? "deterministic" : "stochastic" ;
433
+ var actionName = deterministicInference ? "Deterministic" : "" ;
434
+ failedModelChecks . Add (
435
+ FailedCheck . Warning ( $ "The model uses { actionType } inference but does not contain { actionName } Discrete Action Output Tensor. Uncheck `Deterministic inference` flag.")
436
+ ) ;
437
+ return false ;
438
+ }
439
+
391
440
}
441
+
442
+
443
+
444
+
392
445
}
393
446
return true ;
394
447
}
0 commit comments