Skip to content

Commit 83b1daf

Browse files
committed
Merge branch 'PCAEstimator' of https://github.com/sfilipi/machinelearning-1 into PCAEstimator
2 parents 6708f7f + 431cfae commit 83b1daf

File tree

3 files changed

+93
-9
lines changed

3 files changed

+93
-9
lines changed

src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,6 @@ private static ColumnType GetPredColType(ColumnType scoreType, ISchemaBoundRowMa
298298
}
299299

300300
/// <summary>
/// Returns true when <paramref name="scoreType"/> is <see cref="NumberType.Float"/>.
/// </summary>
private static bool OutputTypeMatches(ColumnType scoreType)
{
    return scoreType == NumberType.Float;
}
302302
}
303303
}

src/Microsoft.ML.Data/Scorers/PredictionTransformer.cs

+89-5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
[assembly: LoadableClass(typeof(RankingPredictionTransformer<IPredictorProducing<float>>), typeof(RankingPredictionTransformer), null, typeof(SignatureLoadModel),
2121
"", RankingPredictionTransformer.LoaderSignature)]
2222

23+
[assembly: LoadableClass(typeof(AnomalyPredictionTransformer<IPredictorProducing<float>>), typeof(AnomalyPredictionTransformer), null, typeof(SignatureLoadModel),
24+
"", AnomalyPredictionTransformer.LoaderSignature)]
25+
2326
namespace Microsoft.ML.Runtime.Data
2427
{
2528

@@ -174,8 +177,6 @@ public SingleFeaturePredictionTransformerBase(IHost host, TModel model, ISchema
174177
FeatureColumnType = trainSchema.GetColumnType(col);
175178

176179
BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
177-
178-
GetScorer();
179180
}
180181

181182
internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
@@ -221,13 +222,80 @@ protected virtual void SaveCore(ModelSaveContext ctx)
221222
ctx.SaveStringOrNull(FeatureColumn);
222223
}
223224

224-
protected virtual GenericScorer GetScorer()
225+
/// <summary>
/// Builds a <see cref="GenericScorer"/> from the training schema and the feature column,
/// binding the bindable mapper against an empty data view over that schema.
/// </summary>
/// <returns>A new <see cref="GenericScorer"/> instance.</returns>
protected virtual GenericScorer GetGenericScorer()
{
    var roleMapped = new RoleMappedSchema(TrainSchema, null, FeatureColumn);

    // Arguments are constructed in the same order as the single-expression form would evaluate them.
    var scorerArgs = new GenericScorer.Arguments();
    var emptyInput = new EmptyDataView(Host, TrainSchema);
    var boundMapper = BindableMapper.Bind(Host, roleMapped);

    return new GenericScorer(Host, scorerArgs, emptyInput, boundMapper, roleMapped);
}
229230
}
230231

232+
/// <summary>
/// An <see cref="ISingleFeaturePredictionTransformer{TModel}"/> for anomaly detection tasks.
/// It scores with a <see cref="BinaryClassifierScorer"/> configured with <see cref="Threshold"/>
/// and <see cref="ThresholdColumn"/>.
/// </summary>
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
public sealed class AnomalyPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel, BinaryClassifierScorer>
    where TModel : class, IPredictorProducing<float>
{
    /// <summary>Name of the column handed to the scorer as its threshold column.</summary>
    public readonly string ThresholdColumn;

    /// <summary>Threshold value handed to the scorer.</summary>
    public readonly float Threshold;

    /// <summary>
    /// Creates the transformer from a trained model.
    /// </summary>
    /// <param name="env">The host environment; must not be null.</param>
    /// <param name="model">The trained model.</param>
    /// <param name="inputSchema">The schema the model was trained on.</param>
    /// <param name="featureColumn">Name of the feature column.</param>
    /// <param name="threshold">Threshold passed to the scorer.</param>
    /// <param name="thresholdColumn">Threshold column passed to the scorer; must be non-empty.</param>
    public AnomalyPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn,
        float threshold = 0f, string thresholdColumn = DefaultColumnNames.Score)
        // Register the host under this type's own name (previously it used BinaryPredictionTransformer's name).
        : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer<TModel>)), model, inputSchema, featureColumn)
    {
        Host.CheckNonEmpty(thresholdColumn, nameof(thresholdColumn));
        Threshold = threshold;
        ThresholdColumn = thresholdColumn;

        SetScorer();
    }

    /// <summary>
    /// Creates the transformer from a saved model.
    /// </summary>
    /// <param name="env">The host environment; must not be null.</param>
    /// <param name="ctx">The load context positioned at this transformer's data.</param>
    public AnomalyPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
        : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(AnomalyPredictionTransformer<TModel>)), ctx)
    {
        // *** Binary format ***
        // <base info>
        // float: scorer threshold
        // id of string: scorer threshold column

        Threshold = ctx.Reader.ReadSingle();
        ThresholdColumn = ctx.LoadString();
        SetScorer();
    }

    /// <summary>
    /// Builds the <see cref="BinaryClassifierScorer"/> from the current threshold settings and
    /// assigns it to <c>Scorer</c>. Must run after <see cref="Threshold"/> and
    /// <see cref="ThresholdColumn"/> have been assigned.
    /// </summary>
    private void SetScorer()
    {
        var schema = new RoleMappedSchema(TrainSchema, null, FeatureColumn);
        var args = new BinaryClassifierScorer.Arguments { Threshold = Threshold, ThresholdColumn = ThresholdColumn };
        Scorer = new BinaryClassifierScorer(Host, args, new EmptyDataView(Host, TrainSchema), BindableMapper.Bind(Host, schema), schema);
    }

    /// <summary>
    /// Writes the version info, the base class data, then the threshold and threshold column.
    /// The write order mirrors the read order of the loading constructor.
    /// </summary>
    protected override void SaveCore(ModelSaveContext ctx)
    {
        Contracts.AssertValue(ctx);
        ctx.SetVersionInfo(GetVersionInfo());

        // *** Binary format ***
        // <base info>
        // float: scorer threshold
        // id of string: scorer threshold column
        base.SaveCore(ctx);

        ctx.Writer.Write(Threshold);
        ctx.SaveString(ThresholdColumn);
    }

    /// <summary>
    /// Version information recorded by <see cref="SaveCore"/> for this transformer.
    /// </summary>
    private static VersionInfo GetVersionInfo()
    {
        return new VersionInfo(
            modelSignature: "ANOMPRED",
            verWrittenCur: 0x00010001, // Initial
            verReadableCur: 0x00010001,
            verWeCanReadBack: 0x00010001,
            loaderSignature: AnomalyPredictionTransformer.LoaderSignature);
    }
}
298+
231299
/// <summary>
232300
/// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on binary classification tasks.
233301
/// </summary>
@@ -367,11 +435,13 @@ public sealed class RegressionPredictionTransformer<TModel> : SingleFeaturePredi
367435
/// <summary>
/// Creates the regression transformer from a trained model and assigns its generic scorer.
/// </summary>
public RegressionPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer<TModel>)),
        model,
        inputSchema,
        featureColumn)
{
    // The scorer is assigned here, in the derived constructor, after the base constructor has run.
    Scorer = GetGenericScorer();
}
371440

372441
/// <summary>
/// Creates the regression transformer from a model load context and assigns its generic scorer.
/// </summary>
internal RegressionPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(nameof(RegressionPredictionTransformer<TModel>)),
        ctx)
{
    // The scorer is assigned here, in the derived constructor, after the base constructor has run.
    Scorer = GetGenericScorer();
}
376446

377447
protected override void SaveCore(ModelSaveContext ctx)
@@ -387,7 +457,7 @@ protected override void SaveCore(ModelSaveContext ctx)
387457
private static VersionInfo GetVersionInfo()
388458
{
389459
return new VersionInfo(
390-
modelSignature: "MC PRED",
460+
modelSignature: "REG PRED",
391461
verWrittenCur: 0x00010001, // Initial
392462
verReadableCur: 0x00010001,
393463
verWeCanReadBack: 0x00010001,
@@ -396,17 +466,23 @@ private static VersionInfo GetVersionInfo()
396466
}
397467
}
398468

469+
/// <summary>
470+
/// Base class for the <see cref="ISingleFeaturePredictionTransformer{TModel}"/> working on ranking tasks.
471+
/// </summary>
472+
/// <typeparam name="TModel">An implementation of the <see cref="IPredictorProducing{TResult}"/></typeparam>
399473
public sealed class RankingPredictionTransformer<TModel> : SingleFeaturePredictionTransformerBase<TModel, GenericScorer>
400474
where TModel : class, IPredictorProducing<float>
401475
{
402476
/// <summary>
/// Creates the ranking transformer from a trained model and assigns its generic scorer.
/// </summary>
public RankingPredictionTransformer(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)),
        model,
        inputSchema,
        featureColumn)
{
    // The scorer is assigned here, in the derived constructor, after the base constructor has run.
    Scorer = GetGenericScorer();
}
406481

407482
/// <summary>
/// Creates the ranking transformer from a model load context and assigns its generic scorer.
/// </summary>
internal RankingPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
    : base(
        Contracts.CheckRef(env, nameof(env)).Register(nameof(RankingPredictionTransformer<TModel>)),
        ctx)
{
    // The scorer is assigned here, in the derived constructor, after the base constructor has run.
    Scorer = GetGenericScorer();
}
411487

412488
protected override void SaveCore(ModelSaveContext ctx)
@@ -422,7 +498,7 @@ protected override void SaveCore(ModelSaveContext ctx)
422498
/// <summary>
/// Version information identifying the ranking prediction transformer's model format.
/// </summary>
private static VersionInfo GetVersionInfo()
{
    return new VersionInfo(
        // VersionInfo requires an eight-character model signature; "RANK PRED" was nine characters.
        modelSignature: "RANKPRED",
        verWrittenCur: 0x00010001, // Initial
        verReadableCur: 0x00010001,
        verWeCanReadBack: 0x00010001,
@@ -462,4 +538,12 @@ internal static class RankingPredictionTransformer
462538
public static RankingPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
463539
=> new RankingPredictionTransformer<IPredictorProducing<float>>(env, ctx);
464540
}
541+
542+
/// <summary>
/// Non-generic companion of <see cref="AnomalyPredictionTransformer{TModel}"/> that holds the
/// loader signature and the factory method that takes a <see cref="ModelLoadContext"/>.
/// </summary>
internal static class AnomalyPredictionTransformer
{
    /// <summary>Loader signature used to identify saved anomaly prediction transformers.</summary>
    public const string LoaderSignature = "AnomalyPredXfer";

    /// <summary>
    /// Instantiates an <see cref="AnomalyPredictionTransformer{TModel}"/> from a model load context.
    /// </summary>
    public static AnomalyPredictionTransformer<IPredictorProducing<float>> Create(IHostEnvironment env, ModelLoadContext ctx)
    {
        return new AnomalyPredictionTransformer<IPredictorProducing<float>>(env, ctx);
    }
}
465549
}

src/Microsoft.ML.PCA/PcaTrainer.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ namespace Microsoft.ML.Runtime.PCA
4242
/// <remarks>
4343
/// This PCA can be made into Kernel PCA by using Random Fourier Features transform
4444
/// </remarks>
45-
public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<BinaryPredictionTransformer<PcaPredictor>, PcaPredictor>
45+
public sealed class RandomizedPcaTrainer : TrainerEstimatorBase<AnomalyPredictionTransformer<PcaPredictor>, PcaPredictor>
4646
{
4747
public const string LoadNameValue = "pcaAnomaly";
4848
internal const string UserNameValue = "PCA Anomaly Detector";
@@ -335,8 +335,8 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
335335
};
336336
}
337337

338-
protected override BinaryPredictionTransformer<PcaPredictor> MakeTransformer(PcaPredictor model, ISchema trainSchema)
339-
=> new BinaryPredictionTransformer<PcaPredictor>(Host, model, trainSchema, _featureColumn);
338+
/// <summary>
/// Wraps the trained <see cref="PcaPredictor"/> in an anomaly prediction transformer
/// using the trainer's host and feature column.
/// </summary>
protected override AnomalyPredictionTransformer<PcaPredictor> MakeTransformer(PcaPredictor model, ISchema trainSchema)
{
    return new AnomalyPredictionTransformer<PcaPredictor>(Host, model, trainSchema, _featureColumn);
}
340340

341341
[TlcModule.EntryPoint(Name = "Trainers.PcaAnomalyDetector",
342342
Desc = "Train an PCA Anomaly model.",

0 commit comments

Comments
 (0)