diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs
index ab84f8a715..5f6e374e11 100644
--- a/src/Microsoft.ML/Models/CrossValidator.cs
+++ b/src/Microsoft.ML/Models/CrossValidator.cs
@@ -1,4 +1,8 @@
-using Microsoft.ML.Runtime;
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
diff --git a/src/Microsoft.ML/Models/OneVersusAll.cs b/src/Microsoft.ML/Models/OneVersusAll.cs
new file mode 100644
index 0000000000..2f0a265dad
--- /dev/null
+++ b/src/Microsoft.ML/Models/OneVersusAll.cs
@@ -0,0 +1,74 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.EntryPoints;
+using static Microsoft.ML.Runtime.EntryPoints.CommonInputs;
+
+namespace Microsoft.ML.Models
+{
+ public sealed partial class OneVersusAll
+ {
+ ///
+ /// Create OneVersusAll multiclass trainer.
+ ///
+ /// Underlying binary trainer
+ /// "Use probabilities (vs. raw outputs) to identify top-score category
+ public static ILearningPipelineItem With(ITrainerInputWithLabel trainer, bool useProbabilities = true)
+ {
+ return new OvaPipelineItem(trainer, useProbabilities);
+ }
+
+ private class OvaPipelineItem : ILearningPipelineItem
+ {
+ private Var _data;
+ private ITrainerInputWithLabel _trainer;
+ private bool _useProbabilities;
+
+ public OvaPipelineItem(ITrainerInputWithLabel trainer, bool useProbabilities)
+ {
+ _trainer = trainer;
+ _useProbabilities = useProbabilities;
+ }
+
+ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)
+ {
+ using (var env = new TlcEnvironment())
+ {
+ var subgraph = env.CreateExperiment();
+ subgraph.Add(_trainer);
+ var ova = new OneVersusAll();
+ if (previousStep != null)
+ {
+ if (!(previousStep is ILearningPipelineDataStep dataStep))
+ {
+ throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input.");
+ }
+
+ _data = dataStep.Data;
+ ova.TrainingData = dataStep.Data;
+ ova.UseProbabilities = _useProbabilities;
+ ova.Nodes = subgraph;
+ }
+ Output output = experiment.Add(ova);
+ return new OvaPipelineStep(output);
+ }
+ }
+
+ public Var GetInputData() => _data;
+ }
+
+ private class OvaPipelineStep : ILearningPipelinePredictorStep
+ {
+ public OvaPipelineStep(Output output)
+ {
+ Model = output.PredictorModel;
+ }
+
+ public Var Model { get; }
+ }
+ }
+}
diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs
index ae00a34de6..e3de7a4e50 100644
--- a/src/Microsoft.ML/Models/TrainTestEvaluator.cs
+++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs
@@ -1,4 +1,8 @@
-using Microsoft.ML.Runtime;
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
@@ -110,7 +114,7 @@ public TrainTestEvaluatorOutput TrainTestEvaluate TrainTestEvaluate TrainTestEvaluate predictor;
using (var memoryStream = new MemoryStream())
{
diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs
index 6612dfea69..696fb9e92d 100644
--- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs
@@ -18,7 +18,7 @@ public void TrainAndPredictIrisModelTest()
{
string dataPath = GetDataPath("iris.txt");
- var pipeline = new LearningPipeline(seed:1, conc:1);
+ var pipeline = new LearningPipeline(seed: 1, conc: 1);
pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false));
pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
@@ -33,7 +33,7 @@ public void TrainAndPredictIrisModelTest()
SepalLength = 3.3f,
SepalWidth = 1.6f,
PetalLength = 0.2f,
- PetalWidth= 5.1f,
+ PetalWidth = 5.1f,
});
Assert.Equal(1, prediction.PredictedLabels[0], 2);
@@ -136,6 +136,37 @@ public class IrisPrediction
[ColumnName("Score")]
public float[] PredictedLabels;
}
+
+ [Fact]
+ public void TrainOva()
+ {
+ string dataPath = GetDataPath("iris.txt");
+
+ var pipeline = new LearningPipeline(seed: 1, conc: 1);
+ pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false));
+ pipeline.Add(new ColumnConcatenator(outputColumn: "Features",
+ "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"));
+
+ pipeline.Add(OneVersusAll.With(new StochasticDualCoordinateAscentBinaryClassifier()));
+
+ var model = pipeline.Train();
+
+ var testData = new TextLoader(dataPath).CreateFrom(useHeader: false);
+ var evaluator = new ClassificationEvaluator();
+ ClassificationMetrics metrics = evaluator.Evaluate(model, testData);
+ CheckMetrics(metrics);
+
+ var trainTest = new TrainTestEvaluator() { Kind = MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer }.TrainTestEvaluate(pipeline, testData);
+ CheckMetrics(trainTest.ClassificationMetrics);
+ }
+
+ private void CheckMetrics(ClassificationMetrics metrics)
+ {
+ Assert.Equal(.96, metrics.AccuracyMacro, 2);
+ Assert.Equal(.96, metrics.AccuracyMicro, 2);
+ Assert.Equal(.19, metrics.LogLoss, 1);
+ Assert.InRange(metrics.LogLossReduction, 80, 84);
+ }
}
}