20
20
[ assembly: LoadableClass ( typeof ( RankingPredictionTransformer < IPredictorProducing < float > > ) , typeof ( RankingPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
21
21
"" , RankingPredictionTransformer . LoaderSignature ) ]
22
22
23
+ [ assembly: LoadableClass ( typeof ( AnomalyPredictionTransformer < IPredictorProducing < float > > ) , typeof ( AnomalyPredictionTransformer ) , null , typeof ( SignatureLoadModel ) ,
24
+ "" , AnomalyPredictionTransformer . LoaderSignature ) ]
25
+
23
26
namespace Microsoft . ML . Runtime . Data
24
27
{
25
28
@@ -174,8 +177,6 @@ public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema
174
177
FeatureColumnType = trainSchema . GetColumnType ( col ) ;
175
178
176
179
BindableMapper = ScoreUtils . GetSchemaBindableMapper ( Host , model ) ;
177
-
178
- GetScorer ( ) ;
179
180
}
180
181
181
182
internal SingleFeaturePredictionTransformerBase ( IHost host , ModelLoadContext ctx )
@@ -221,13 +222,80 @@ protected virtual void SaveCore(ModelSaveContext ctx)
221
222
ctx . SaveStringOrNull ( FeatureColumn ) ;
222
223
}
223
224
224
- protected virtual GenericScorer GetScorer ( )
225
+ protected virtual GenericScorer GetGenericScorer ( )
225
226
{
226
227
var schema = new RoleMappedSchema ( TrainSchema , null , FeatureColumn ) ;
227
228
return new GenericScorer ( Host , new GenericScorer . Arguments ( ) , new EmptyDataView ( Host , TrainSchema ) , BindableMapper . Bind ( Host , schema ) , schema ) ;
228
229
}
229
230
}
230
231
232
+ /// <summary>
233
+ /// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on anomaly detection tasks.
234
+ /// </summary>
235
+ /// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
236
+ public sealed class AnomalyPredictionTransformer < TModel > : SingleFeaturePredictionTransformerBase < TModel , BinaryClassifierScorer >
237
+ where TModel : class , IPredictorProducing < float >
238
+ {
239
+ public readonly string ThresholdColumn ;
240
+ public readonly float Threshold ;
241
+
242
+ public AnomalyPredictionTransformer ( IHostEnvironment env , TModel model , ISchema inputSchema , string featureColumn ,
243
+ float threshold = 0f , string thresholdColumn = DefaultColumnNames . Score )
244
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( BinaryPredictionTransformer < TModel > ) ) , model , inputSchema , featureColumn )
245
+ {
246
+ Host . CheckNonEmpty ( thresholdColumn , nameof ( thresholdColumn ) ) ;
247
+ Threshold = threshold ;
248
+ ThresholdColumn = thresholdColumn ;
249
+
250
+ SetScorer ( ) ;
251
+ }
252
+
253
+ public AnomalyPredictionTransformer ( IHostEnvironment env , ModelLoadContext ctx )
254
+ : base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( BinaryPredictionTransformer < TModel > ) ) , ctx )
255
+ {
256
+ // *** Binary format ***
257
+ // <base info>
258
+ // float: scorer threshold
259
+ // id of string: scorer threshold column
260
+
261
+ Threshold = ctx . Reader . ReadSingle ( ) ;
262
+ ThresholdColumn = ctx . LoadString ( ) ;
263
+ SetScorer ( ) ;
264
+ }
265
+
266
+ private void SetScorer ( )
267
+ {
268
+ var schema = new RoleMappedSchema ( TrainSchema , null , FeatureColumn ) ;
269
+ var args = new BinaryClassifierScorer . Arguments { Threshold = Threshold , ThresholdColumn = ThresholdColumn } ;
270
+ Scorer = new BinaryClassifierScorer ( Host , args , new EmptyDataView ( Host , TrainSchema ) , BindableMapper . Bind ( Host , schema ) , schema ) ;
271
+ }
272
+
273
+ protected override void SaveCore ( ModelSaveContext ctx )
274
+ {
275
+ Contracts . AssertValue ( ctx ) ;
276
+ ctx . SetVersionInfo ( GetVersionInfo ( ) ) ;
277
+
278
+ // *** Binary format ***
279
+ // <base info>
280
+ // float: scorer threshold
281
+ // id of string: scorer threshold column
282
+ base . SaveCore ( ctx ) ;
283
+
284
+ ctx . Writer . Write ( Threshold ) ;
285
+ ctx . SaveString ( ThresholdColumn ) ;
286
+ }
287
+
288
+ private static VersionInfo GetVersionInfo ( )
289
+ {
290
+ return new VersionInfo (
291
+ modelSignature : "ANOMPRED" ,
292
+ verWrittenCur : 0x00010001 , // Initial
293
+ verReadableCur : 0x00010001 ,
294
+ verWeCanReadBack : 0x00010001 ,
295
+ loaderSignature : AnomalyPredictionTransformer . LoaderSignature ) ;
296
+ }
297
+ }
298
+
231
299
/// <summary>
232
300
/// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on binary classification tasks.
233
301
/// </summary>
@@ -367,11 +435,13 @@ public sealed class RegressionPredictionTransformer<TModel> : SingleFeaturePredi
367
435
public RegressionPredictionTransformer ( IHostEnvironment env , TModel model , ISchema inputSchema , string featureColumn )
368
436
: base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( RegressionPredictionTransformer < TModel > ) ) , model , inputSchema , featureColumn )
369
437
{
438
+ Scorer = GetGenericScorer ( ) ;
370
439
}
371
440
372
441
internal RegressionPredictionTransformer ( IHostEnvironment env , ModelLoadContext ctx )
373
442
: base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( RegressionPredictionTransformer < TModel > ) ) , ctx )
374
443
{
444
+ Scorer = GetGenericScorer ( ) ;
375
445
}
376
446
377
447
protected override void SaveCore ( ModelSaveContext ctx )
@@ -387,7 +457,7 @@ protected override void SaveCore(ModelSaveContext ctx)
387
457
private static VersionInfo GetVersionInfo ( )
388
458
{
389
459
return new VersionInfo (
390
- modelSignature : "MC PRED" ,
460
+ modelSignature : "REG PRED" ,
391
461
verWrittenCur : 0x00010001 , // Initial
392
462
verReadableCur : 0x00010001 ,
393
463
verWeCanReadBack : 0x00010001 ,
@@ -396,17 +466,23 @@ private static VersionInfo GetVersionInfo()
396
466
}
397
467
}
398
468
469
+ /// <summary>
470
+ /// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on ranking tasks.
471
+ /// </summary>
472
+ /// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
399
473
public sealed class RankingPredictionTransformer < TModel > : SingleFeaturePredictionTransformerBase < TModel , GenericScorer >
400
474
where TModel : class , IPredictorProducing < float >
401
475
{
402
476
public RankingPredictionTransformer ( IHostEnvironment env , TModel model , ISchema inputSchema , string featureColumn )
403
477
: base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( RankingPredictionTransformer < TModel > ) ) , model , inputSchema , featureColumn )
404
478
{
479
+ Scorer = GetGenericScorer ( ) ;
405
480
}
406
481
407
482
internal RankingPredictionTransformer ( IHostEnvironment env , ModelLoadContext ctx )
408
483
: base ( Contracts . CheckRef ( env , nameof ( env ) ) . Register ( nameof ( RankingPredictionTransformer < TModel > ) ) , ctx )
409
484
{
485
+ Scorer = GetGenericScorer ( ) ;
410
486
}
411
487
412
488
protected override void SaveCore ( ModelSaveContext ctx )
@@ -422,7 +498,7 @@ protected override void SaveCore(ModelSaveContext ctx)
422
498
private static VersionInfo GetVersionInfo ( )
423
499
{
424
500
return new VersionInfo (
425
- modelSignature : "MC RANK" ,
501
+ modelSignature : "RANK PRED " ,
426
502
verWrittenCur : 0x00010001 , // Initial
427
503
verReadableCur : 0x00010001 ,
428
504
verWeCanReadBack : 0x00010001 ,
@@ -462,4 +538,12 @@ internal static class RankingPredictionTransformer
462
538
public static RankingPredictionTransformer < IPredictorProducing < float > > Create ( IHostEnvironment env , ModelLoadContext ctx )
463
539
=> new RankingPredictionTransformer < IPredictorProducing < float > > ( env , ctx ) ;
464
540
}
541
+
542
+ internal static class AnomalyPredictionTransformer
543
+ {
544
+ public const string LoaderSignature = "AnomalyPredXfer" ;
545
+
546
+ public static AnomalyPredictionTransformer < IPredictorProducing < float > > Create ( IHostEnvironment env , ModelLoadContext ctx )
547
+ => new AnomalyPredictionTransformer < IPredictorProducing < float > > ( env , ctx ) ;
548
+ }
465
549
}
0 commit comments