Skip to content

Commit e5de547

Browse files
authored
add pipelineitem for Ova (#363)
add pipelineitem for Ova
1 parent 8b01fc5 commit e5de547

File tree

4 files changed

+125
-12
lines changed

4 files changed

+125
-12
lines changed

src/Microsoft.ML/Models/CrossValidator.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
using Microsoft.ML.Runtime;
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
26
using Microsoft.ML.Runtime.Api;
37
using Microsoft.ML.Runtime.Data;
48
using Microsoft.ML.Runtime.EntryPoints;
+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Microsoft.ML.Runtime;
7+
using Microsoft.ML.Runtime.Data;
8+
using Microsoft.ML.Runtime.EntryPoints;
9+
using static Microsoft.ML.Runtime.EntryPoints.CommonInputs;
10+
11+
namespace Microsoft.ML.Models
12+
{
13+
public sealed partial class OneVersusAll
14+
{
15+
/// <summary>
16+
/// Create OneVersusAll multiclass trainer.
17+
/// </summary>
18+
/// <param name="trainer">Underlying binary trainer</param>
19+
/// <param name="useProbabilities">"Use probabilities (vs. raw outputs) to identify top-score category</param>
20+
public static ILearningPipelineItem With(ITrainerInputWithLabel trainer, bool useProbabilities = true)
21+
{
22+
return new OvaPipelineItem(trainer, useProbabilities);
23+
}
24+
25+
private class OvaPipelineItem : ILearningPipelineItem
26+
{
27+
private Var<IDataView> _data;
28+
private ITrainerInputWithLabel _trainer;
29+
private bool _useProbabilities;
30+
31+
public OvaPipelineItem(ITrainerInputWithLabel trainer, bool useProbabilities)
32+
{
33+
_trainer = trainer;
34+
_useProbabilities = useProbabilities;
35+
}
36+
37+
public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
38+
{
39+
using (var env = new TlcEnvironment())
40+
{
41+
var subgraph = env.CreateExperiment();
42+
subgraph.Add(_trainer);
43+
var ova = new OneVersusAll();
44+
if (previousStep != null)
45+
{
46+
if (!(previousStep is ILearningPipelineDataStep dataStep))
47+
{
48+
throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
49+
}
50+
51+
_data = dataStep.Data;
52+
ova.TrainingData = dataStep.Data;
53+
ova.UseProbabilities = _useProbabilities;
54+
ova.Nodes = subgraph;
55+
}
56+
Output output = experiment.Add(ova);
57+
return new OvaPipelineStep(output);
58+
}
59+
}
60+
61+
public Var<IDataView> GetInputData() => _data;
62+
}
63+
64+
private class OvaPipelineStep : ILearningPipelinePredictorStep
65+
{
66+
public OvaPipelineStep(Output output)
67+
{
68+
Model = output.PredictorModel;
69+
}
70+
71+
public Var<IPredictorModel> Model { get; }
72+
}
73+
}
74+
}

src/Microsoft.ML/Models/TrainTestEvaluator.cs

+13-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
using Microsoft.ML.Runtime;
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
26
using Microsoft.ML.Runtime.Api;
37
using Microsoft.ML.Runtime.Data;
48
using Microsoft.ML.Runtime.EntryPoints;
@@ -110,7 +114,7 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOutp
110114
Inputs.Data = firstTransform.GetInputData();
111115
Outputs.PredictorModel = null;
112116
Outputs.TransformModel = lastTransformModel;
113-
var crossValidateOutput = experiment.Add(this);
117+
var trainTestNodeOutput = experiment.Add(this);
114118
experiment.Compile();
115119
foreach (ILearningPipelineLoader loader in loaders)
116120
loader.SetInput(environment, experiment);
@@ -124,35 +128,35 @@ public TrainTestEvaluatorOutput<TInput, TOutput> TrainTestEvaluate<TInput, TOutp
124128
{
125129
trainTestOutput.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics(
126130
environment,
127-
experiment.GetOutput(crossValidateOutput.OverallMetrics),
128-
experiment.GetOutput(crossValidateOutput.ConfusionMatrix)).FirstOrDefault();
131+
experiment.GetOutput(trainTestNodeOutput.OverallMetrics),
132+
experiment.GetOutput(trainTestNodeOutput.ConfusionMatrix)).FirstOrDefault();
129133
}
130134
else if (Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer)
131135
{
132136
trainTestOutput.ClassificationMetrics = ClassificationMetrics.FromMetrics(
133137
environment,
134-
experiment.GetOutput(crossValidateOutput.OverallMetrics),
135-
experiment.GetOutput(crossValidateOutput.ConfusionMatrix)).FirstOrDefault();
138+
experiment.GetOutput(trainTestNodeOutput.OverallMetrics),
139+
experiment.GetOutput(trainTestNodeOutput.ConfusionMatrix)).FirstOrDefault();
136140
}
137141
else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer)
138142
{
139143
trainTestOutput.RegressionMetrics = RegressionMetrics.FromOverallMetrics(
140144
environment,
141-
experiment.GetOutput(crossValidateOutput.OverallMetrics)).FirstOrDefault();
145+
experiment.GetOutput(trainTestNodeOutput.OverallMetrics)).FirstOrDefault();
142146
}
143147
else if (Kind == MacroUtilsTrainerKinds.SignatureClusteringTrainer)
144148
{
145149
trainTestOutput.ClusterMetrics = ClusterMetrics.FromOverallMetrics(
146150
environment,
147-
experiment.GetOutput(crossValidateOutput.OverallMetrics)).FirstOrDefault();
151+
experiment.GetOutput(trainTestNodeOutput.OverallMetrics)).FirstOrDefault();
148152
}
149153
else
150154
{
151155
//Implement metrics for ranking, clustering and anomaly detection.
152156
throw Contracts.Except($"{Kind.ToString()} is not supported at the moment.");
153157
}
154158

155-
ITransformModel model = experiment.GetOutput(crossValidateOutput.TransformModel);
159+
ITransformModel model = experiment.GetOutput(trainTestNodeOutput.TransformModel);
156160
BatchPredictionEngine<TInput, TOutput> predictor;
157161
using (var memoryStream = new MemoryStream())
158162
{

test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs

+33-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public void TrainAndPredictIrisModelTest()
1818
{
1919
string dataPath = GetDataPath("iris.txt");
2020

21-
var pipeline = new LearningPipeline(seed:1, conc:1);
21+
var pipeline = new LearningPipeline(seed: 1, conc: 1);
2222

2323
pipeline.Add(new TextLoader(dataPath).CreateFrom<IrisData>(useHeader: false));
2424
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
@@ -33,7 +33,7 @@ public void TrainAndPredictIrisModelTest()
3333
SepalLength = 3.3f,
3434
SepalWidth = 1.6f,
3535
PetalLength = 0.2f,
36-
PetalWidth= 5.1f,
36+
PetalWidth = 5.1f,
3737
});
3838

3939
Assert.Equal(1, prediction.PredictedLabels[0], 2);
@@ -136,6 +136,37 @@ public class IrisPrediction
136136
[ColumnName("Score")]
137137
public float[] PredictedLabels;
138138
}
139+
140+
[Fact]
141+
public void TrainOva()
142+
{
143+
string dataPath = GetDataPath("iris.txt");
144+
145+
var pipeline = new LearningPipeline(seed: 1, conc: 1);
146+
pipeline.Add(new TextLoader(dataPath).CreateFrom<IrisData>(useHeader: false));
147+
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
148+
"SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
149+
150+
pipeline.Add(OneVersusAll.With(new StochasticDualCoordinateAscentBinaryClassifier()));
151+
152+
var model = pipeline.Train<IrisData, IrisPrediction>();
153+
154+
var testData = new TextLoader(dataPath).CreateFrom<IrisData>(useHeader: false);
155+
var evaluator = new ClassificationEvaluator();
156+
ClassificationMetrics metrics = evaluator.Evaluate(model, testData);
157+
CheckMetrics(metrics);
158+
159+
var trainTest = new TrainTestEvaluator() { Kind = MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer }.TrainTestEvaluate<IrisData, IrisPrediction>(pipeline, testData);
160+
CheckMetrics(trainTest.ClassificationMetrics);
161+
}
162+
163+
private void CheckMetrics(ClassificationMetrics metrics)
164+
{
165+
Assert.Equal(.96, metrics.AccuracyMacro, 2);
166+
Assert.Equal(.96, metrics.AccuracyMicro, 2);
167+
Assert.Equal(.19, metrics.LogLoss, 1);
168+
Assert.InRange(metrics.LogLossReduction, 80, 84);
169+
}
139170
}
140171
}
141172

0 commit comments

Comments
 (0)