diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs index ab84f8a715..5f6e374e11 100644 --- a/src/Microsoft.ML/Models/CrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -1,4 +1,8 @@ -using Microsoft.ML.Runtime; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; diff --git a/src/Microsoft.ML/Models/OneVersusAll.cs b/src/Microsoft.ML/Models/OneVersusAll.cs new file mode 100644 index 0000000000..2f0a265dad --- /dev/null +++ b/src/Microsoft.ML/Models/OneVersusAll.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; +using static Microsoft.ML.Runtime.EntryPoints.CommonInputs; + +namespace Microsoft.ML.Models +{ + public sealed partial class OneVersusAll + { + /// + /// Create OneVersusAll multiclass trainer. + /// + /// Underlying binary trainer + /// Use probabilities (vs. 
raw outputs) to identify top-score category + public static ILearningPipelineItem With(ITrainerInputWithLabel trainer, bool useProbabilities = true) + { + return new OvaPipelineItem(trainer, useProbabilities); + } + + private class OvaPipelineItem : ILearningPipelineItem + { + private Var _data; + private ITrainerInputWithLabel _trainer; + private bool _useProbabilities; + + public OvaPipelineItem(ITrainerInputWithLabel trainer, bool useProbabilities) + { + _trainer = trainer; + _useProbabilities = useProbabilities; + } + + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + using (var env = new TlcEnvironment()) + { + var subgraph = env.CreateExperiment(); + subgraph.Add(_trainer); + var ova = new OneVersusAll(); + if (previousStep != null) + { + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } + + _data = dataStep.Data; + ova.TrainingData = dataStep.Data; + ova.UseProbabilities = _useProbabilities; + ova.Nodes = subgraph; + } + Output output = experiment.Add(ova); + return new OvaPipelineStep(output); + } + } + + public Var GetInputData() => _data; + } + + private class OvaPipelineStep : ILearningPipelinePredictorStep + { + public OvaPipelineStep(Output output) + { + Model = output.PredictorModel; + } + + public Var Model { get; } + } + } +} diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs index ae00a34de6..e3de7a4e50 100644 --- a/src/Microsoft.ML/Models/TrainTestEvaluator.cs +++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs @@ -1,4 +1,8 @@ -using Microsoft.ML.Runtime; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; @@ -110,7 +114,7 @@ public TrainTestEvaluatorOutput TrainTestEvaluate TrainTestEvaluate TrainTestEvaluate predictor; using (var memoryStream = new MemoryStream()) { diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index 6612dfea69..696fb9e92d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -18,7 +18,7 @@ public void TrainAndPredictIrisModelTest() { string dataPath = GetDataPath("iris.txt"); - var pipeline = new LearningPipeline(seed:1, conc:1); + var pipeline = new LearningPipeline(seed: 1, conc: 1); pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false)); pipeline.Add(new ColumnConcatenator(outputColumn: "Features", @@ -33,7 +33,7 @@ public void TrainAndPredictIrisModelTest() SepalLength = 3.3f, SepalWidth = 1.6f, PetalLength = 0.2f, - PetalWidth= 5.1f, + PetalWidth = 5.1f, }); Assert.Equal(1, prediction.PredictedLabels[0], 2); @@ -136,6 +136,37 @@ public class IrisPrediction [ColumnName("Score")] public float[] PredictedLabels; } + + [Fact] + public void TrainOva() + { + string dataPath = GetDataPath("iris.txt"); + + var pipeline = new LearningPipeline(seed: 1, conc: 1); + pipeline.Add(new TextLoader(dataPath).CreateFrom(useHeader: false)); + pipeline.Add(new ColumnConcatenator(outputColumn: "Features", + "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")); + + pipeline.Add(OneVersusAll.With(new StochasticDualCoordinateAscentBinaryClassifier())); + + var model = pipeline.Train(); + + var testData = new TextLoader(dataPath).CreateFrom(useHeader: false); + var evaluator = new ClassificationEvaluator(); + ClassificationMetrics metrics = evaluator.Evaluate(model, testData); + CheckMetrics(metrics); + + var 
trainTest = new TrainTestEvaluator() { Kind = MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer }.TrainTestEvaluate(pipeline, testData); + CheckMetrics(trainTest.ClassificationMetrics); + } + + private void CheckMetrics(ClassificationMetrics metrics) + { + Assert.Equal(.96, metrics.AccuracyMacro, 2); + Assert.Equal(.96, metrics.AccuracyMicro, 2); + Assert.Equal(.19, metrics.LogLoss, 1); + Assert.InRange(metrics.LogLossReduction, 80, 84); + } } }