Skip to content

Commit 215856d

Browse files
author
Pete Luferenko
committed
Added evaluation
1 parent 184027b commit 215856d

File tree

3 files changed

+68
-9
lines changed

3 files changed

+68
-9
lines changed

test/Microsoft.ML.Tests/Scenarios/Api/Evaluation.cs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
using Microsoft.ML.Runtime.Data;
1+
using Microsoft.ML.Runtime.Api;
2+
using Microsoft.ML.Runtime.Data;
23
using Microsoft.ML.Runtime.Learners;
34
using Xunit;
45
using Microsoft.ML.Models;
6+
using System.Linq;
57

68
namespace Microsoft.ML.Tests.Scenarios.Api
79
{
@@ -46,5 +48,34 @@ public void Evaluation()
4648
var metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
4749
}
4850
}
51+
52+
/// <summary>
53+
/// Evaluation: Similar to the simple train scenario, except instead of having some
54+
/// predictive structure, be able to score another "test" data file, run the result
55+
/// through an evaluator and get metrics like AUC, accuracy, PR curves, and whatnot.
56+
/// Getting metrics out of this should be as straightforward and unannoying as possible.
57+
/// </summary>
58+
[Fact]
59+
public void New_Evaluation()
60+
{
61+
// Training file and the separate held-out test file for the sentiment scenario.
var dataPath = GetDataPath(SentimentDataPath);
62+
var testDataPath = GetDataPath(SentimentTestPath);
63+
64+
// seed and conc are both fixed to 1 — presumably to keep the run reproducible; confirm.
using (var env = new TlcEnvironment(seed: 1, conc: 1))
65+
{
66+
// Pipeline.
67+
// Chain: text loader -> text transform -> SDCA trainer reading the "Features" and "Label" columns.
var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
68+
.Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()))
69+
.Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"));
70+
71+
// Train.
72+
var model = pipeline.Fit(new MultiFileSource(dataPath));
73+
74+
// Evaluate on the test set.
75+
// The fitted model reads the test file itself, so the data handed to the evaluator comes from the same model.
var dataEval = model.Read(new MultiFileSource(testDataPath));
76+
var evaluator = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { });
77+
// NOTE(review): 'metrics' is never read or asserted on; as written this test only checks
// that training and evaluation complete without throwing.
var metrics = evaluator.Evaluate(dataEval, labelColumn: "Label", probabilityColumn: "Probability");
78+
}
79+
}
4980
}
5081
}

test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ public void New_SimpleTrainAndPredict()
7171
{
7272
// Pipeline.
7373
var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
74-
.StartPipe() // Actually optional
7574
.Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()))
7675
.Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"));
7776

test/Microsoft.ML.Tests/Scenarios/Api/Wrappers.cs

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.ML.Core.Data;
2+
using Microsoft.ML.Models;
23
using Microsoft.ML.Runtime;
34
using Microsoft.ML.Runtime.Api;
45
using Microsoft.ML.Runtime.Data;
@@ -9,6 +10,8 @@
910
using System;
1011
using System.Collections.Generic;
1112
using System.IO;
13+
using System.Linq;
14+
1215
[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel),
1316
"Transform wrapper", TransformWrapper.LoaderSignature)]
1417
[assembly: LoadableClass(typeof(LoaderWrapper), null, typeof(SignatureLoadModel),
@@ -163,16 +166,16 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
163166
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input);
164167
}
165168

166-
public interface IPredictorTransformer<out TModel>: ITransformer
169+
public interface IPredictorTransformer<out TModel> : ITransformer
167170
{
168171
TModel TrainedModel { get; }
169172
}
170173

171-
public class ScorerWrapper<TModel>: TransformWrapper, IPredictorTransformer<TModel>
172-
where TModel: IPredictor
174+
public class ScorerWrapper<TModel> : TransformWrapper, IPredictorTransformer<TModel>
175+
where TModel : IPredictor
173176
{
174177
public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel)
175-
:base(env, scorer)
178+
: base(env, scorer)
176179
{
177180
Model = trainedModel;
178181
}
@@ -206,7 +209,7 @@ public SchemaShape GetOutputSchema()
206209
}
207210

208211
public abstract class TrainerBase<TModel> : IEstimator<ScorerWrapper<TModel>>
209-
where TModel: IPredictor
212+
where TModel : IPredictor
210213
{
211214
protected readonly IHostEnvironment _env;
212215
private readonly string _featureCol;
@@ -298,12 +301,12 @@ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args,
298301
public ITransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData);
299302
}
300303

301-
public sealed class MyAveragedPerceptron: TrainerBase<IPredictor>
304+
public sealed class MyAveragedPerceptron : TrainerBase<IPredictor>
302305
{
303306
private readonly AveragedPerceptronTrainer _trainer;
304307

305308
public MyAveragedPerceptron(IHostEnvironment env, AveragedPerceptronTrainer.Arguments args, string featureCol, string labelCol)
306-
:base(env, false, featureCol, labelCol)
309+
: base(env, false, featureCol, labelCol)
307310
{
308311
_trainer = new AveragedPerceptronTrainer(env, args);
309312
}
@@ -334,5 +337,31 @@ public TDst Predict(TSrc example)
334337
}
335338
}
336339

340+
/// <summary>
/// Thin test wrapper around <see cref="BinaryClassifierEvaluator"/> that evaluates a data view
/// and returns the result as a single <see cref="BinaryClassificationMetrics"/> object.
/// </summary>
public sealed class MyBinaryClassifierEvaluator
341+
{
342+
// Environment passed to the helper calls made in Evaluate.
private readonly IHostEnvironment _env;
343+
// The wrapped evaluator that computes the metrics.
private readonly BinaryClassifierEvaluator _evaluator;
344+
345+
/// <summary>
/// Stores the environment and constructs the wrapped evaluator from the given arguments.
/// </summary>
/// <param name="env">Host environment used for the evaluator and for later helper calls.</param>
/// <param name="args">Arguments forwarded unchanged to <see cref="BinaryClassifierEvaluator"/>.</param>
public MyBinaryClassifierEvaluator(IHostEnvironment env, BinaryClassifierEvaluator.Arguments args)
346+
{
347+
_env = env;
348+
_evaluator = new BinaryClassifierEvaluator(env, args);
349+
}
350+
351+
/// <summary>
/// Evaluates <paramref name="data"/> with the wrapped evaluator and converts the result
/// into a single metrics object.
/// </summary>
/// <param name="data">Data view to evaluate; its schema is searched for the score column.</param>
/// <param name="labelColumn">Name of the label column.</param>
/// <param name="probabilityColumn">Name of the column mapped to the probability score kind.</param>
/// <returns>The one metrics object produced from the "OverallMetrics" and "ConfusionMatrix" results.</returns>
public BinaryClassificationMetrics Evaluate(IDataView data, string labelColumn, string probabilityColumn)
352+
{
353+
// Look up the score column in the schema. No explicit name is supplied (null), and "Score"
// is passed alongside the binary-classification score kind — presumably as the default
// column name; confirm against EvaluateUtils.GetScoreColumnInfo.
var ci = EvaluateUtils.GetScoreColumnInfo(_env, data.Schema, null, "Score", MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
354+
// Custom role pairs: the caller-supplied probability column and the score column found above.
var map = new KeyValuePair<RoleMappedSchema.ColumnRole, string>[]
355+
{
356+
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probabilityColumn),
357+
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, ci.Name)
358+
};
359+
// NOTE(review): "Features" is hard-coded here rather than taken from the caller; it is passed
// right after the label column, presumably as the feature column name. opt: true presumably
// makes the named columns optional — confirm against RoleMappedData.
var rmd = new RoleMappedData(data, labelColumn, "Features", opt: true, custom: map);
360+
361+
var metricsDict = _evaluator.Evaluate(rmd);
362+
// Single() throws unless FromMetrics yields exactly one element.
return BinaryClassificationMetrics.FromMetrics(_env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"]).Single();
363+
}
364+
}
365+
337366

338367
}

0 commit comments

Comments
 (0)