From 31459710b9589a3a1025b43621ebaf2bb04950b7 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 18 May 2018 11:51:43 -0700 Subject: [PATCH 01/18] initial commit. --- src/Microsoft.ML/CSharpApi.cs | 4 +- .../Models/BinaryCrossValidator.cs | 124 ++++++++++++++++++ .../Scenarios/SentimentPredictionTests.cs | 5 +- 3 files changed, 129 insertions(+), 4 deletions(-) create mode 100644 src/Microsoft.ML/Models/BinaryCrossValidator.cs diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 2ca1af7159..8bcc2d00e4 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -12028,10 +12028,10 @@ public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Exper { if (!(previousStep is ILearningPipelineDataStep dataStep)) { - throw new InvalidOperationException($"{ nameof(TextFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + // throw new InvalidOperationException($"{ nameof(TextFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); } - Data = dataStep.Data; + //Data = dataStep.Data; Output output = experiment.Add(this); return new TextFeaturizerPipelineStep(output); } diff --git a/src/Microsoft.ML/Models/BinaryCrossValidator.cs b/src/Microsoft.ML/Models/BinaryCrossValidator.cs new file mode 100644 index 0000000000..a90fb313cc --- /dev/null +++ b/src/Microsoft.ML/Models/BinaryCrossValidator.cs @@ -0,0 +1,124 @@ +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace Microsoft.ML.Models +{ + public sealed partial class BinaryCrossValidator + { + public void CrossValidate(LearningPipeline pipeline) + where TInput : class + where TOutput : class, new() + { + using (var environment = new TlcEnvironment()) + { + Experiment subGraph = environment.CreateExperiment(); + ILearningPipelineStep step = null; + List loaders = new List(); + List> transformModels = new List>(); + Var lastTransformModel = null; + Var firstInput = null; + Var firstModel = null; + foreach (ILearningPipelineItem currentItem in pipeline) + { + if (currentItem is ILearningPipelineLoader loader) + { + loaders.Add(loader); + continue; + } + + step = currentItem.ApplyStep(step, subGraph); + + if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null) + { + transformModels.Add(dataStep.Model); + if (firstInput == null) + firstInput = dataStep.Data; + } + + else if (step is ILearningPipelinePredictorStep predictorDataStep) + { + if (lastTransformModel != null) + transformModels.Insert(0, lastTransformModel); + + Var predictorModel; + if (transformModels.Count != 0) + { + var localModelInput = new Transforms.ManyHeterogeneousModelCombiner + { + PredictorModel = predictorDataStep.Model, + TransformModels = new ArrayVar(transformModels.ToArray()) + }; + var localModelOutput = subGraph.Add(localModelInput); + predictorModel = localModelOutput.PredictorModel; + } + else + predictorModel = predictorDataStep.Model; + + var scorer = new Transforms.Scorer + { + PredictorModel = predictorModel + }; + firstModel = predictorModel; + var scorerOutput = subGraph.Add(scorer); + lastTransformModel = scorerOutput.ScoringTransform; + step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform); + transformModels.Clear(); + } + } + + if (transformModels.Count > 0) + { + if (lastTransformModel != null) + transformModels.Insert(0, lastTransformModel); + + var modelInput = new Transforms.ModelCombiner + { + Models = new ArrayVar(transformModels.ToArray()) + }; + + var modelOutput = subGraph.Add(modelInput); + lastTransformModel = modelOutput.OutputModel; + } + + var experiment = environment.CreateExperiment(); + + var importTextOutput = loaders[0].ApplyStep(null, experiment); + + var crossValidateBinary = new ML.Models.BinaryCrossValidator + { + Data = (importTextOutput as ILearningPipelineDataStep).Data, + Nodes = subGraph + }; + crossValidateBinary.Inputs.Data = firstInput; + crossValidateBinary.Outputs.Model = firstModel; + var crossValidateOutput = experiment.Add(crossValidateBinary); + + experiment.Compile(); + foreach (ILearningPipelineLoader loader in loaders) + { + loader.SetInput(environment, experiment); + } + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); + ITransformModel model = experiment.GetOutput(lastTransformModel); + BatchPredictionEngine predictor; + using (var memoryStream = new MemoryStream()) + { + model.Save(environment, memoryStream); + + memoryStream.Position = 0; + + predictor = environment.CreateBatchPredictionEngine(memoryStream); + + //return new PredictionModel(predictor, memoryStream); + } + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 608cbef144..af742a3156 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -37,8 +37,9 @@ public void TrainAndPredictSentimentModelTest() }); pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); - pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); - + //pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); + BinaryCrossValidator bcv = new BinaryCrossValidator(); + bcv.CrossValidate(pipeline); PredictionModel model = pipeline.Train(); IEnumerable sentiments = new[] From f3e42ef75f37c474c7326d7529b4d88db3af496c Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 21 May 2018 19:10:17 -0700 Subject: [PATCH 02/18] changes. --- src/Microsoft.ML.Core/Data/ITransformModel.cs | 2 + .../EntryPoints/TransformModel.cs | 5 + src/Microsoft.ML/CSharpApi.cs | 1356 ++++++++++++----- src/Microsoft.ML/Data/CollectionDataSource.cs | 2 + src/Microsoft.ML/ILearningPipelineItem.cs | 1 + ...aryCrossValidator.cs => CrossValidator.cs} | 45 +- .../EntryPoints/CrossValidationBinaryMacro.cs | 23 + .../EntryPoints/CrossValidationMacro.cs | 129 +- .../Runtime/EntryPoints/ModelOperations.cs | 5 +- .../Runtime/EntryPoints/TrainTestMacro.cs | 114 +- .../Runtime/Experiment/Experiment.cs | 1 - .../Internal/Tools/CSharpApiGenerator.cs | 13 + src/Microsoft.ML/TextLoader.cs | 4 +- test/Microsoft.ML.Tests/CSharpCodeGen.cs | 3 +- .../Scenarios/SentimentPredictionTests.cs | 15 +- 15 files changed, 1248 insertions(+), 470 deletions(-) rename src/Microsoft.ML/Models/{BinaryCrossValidator.cs => CrossValidator.cs} (76%) diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs index ccb65d43ab..ec249ce768 100644 --- a/src/Microsoft.ML.Core/Data/ITransformModel.cs +++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs @@ -25,6 +25,8 @@ public interface ITransformModel /// ISchema InputSchema { get; } + IDataView Data { get; } + /// /// Apply the transform(s) in the model to the given input data. /// diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs index b840529e77..acbce34b24 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs @@ -44,6 +44,11 @@ public ISchema InputSchema get { return _schemaRoot; } } + public IDataView Data + { + get { return _chain; } + } + /// /// Create a TransformModel containing the transforms from "result" back to "input". /// diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 8bcc2d00e4..389c9dd6e5 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -22,6 +22,18 @@ namespace Runtime { public sealed partial class Experiment { + public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input) + { + var output = new Microsoft.ML.Data.DataViewReference.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output) + { + _jsonNodes.Add(Serialize("Data.DataViewReference", input, output)); + } + public Microsoft.ML.Data.IDataViewArrayConverter.Output Add(Microsoft.ML.Data.IDataViewArrayConverter input) { var output = new Microsoft.ML.Data.IDataViewArrayConverter.Output(); @@ -53,22 +65,23 @@ public Microsoft.ML.Data.TextLoader.Output Add(Microsoft.ML.Data.TextLoader inpu return output; } - public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input) + public void Add(Microsoft.ML.Data.TextLoader input, Microsoft.ML.Data.TextLoader.Output output) { - var output = new Microsoft.ML.Data.DataViewReference.Output(); - Add(input, output); - return output; + _jsonNodes.Add(Serialize("Data.TextLoader", input, output)); } - public void Add(Microsoft.ML.Data.TextLoader input, Microsoft.ML.Data.TextLoader.Output output) + public Microsoft.ML.Data.TransformModelArrayConverter.Output Add(Microsoft.ML.Data.TransformModelArrayConverter input) { - _jsonNodes.Add(Serialize("Data.TextLoader", input, output)); + var output = new Microsoft.ML.Data.TransformModelArrayConverter.Output(); + Add(input, output); + return output; } - public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output) + public void Add(Microsoft.ML.Data.TransformModelArrayConverter input, Microsoft.ML.Data.TransformModelArrayConverter.Output output) { - _jsonNodes.Add(Serialize("Data.DataViewReference", input, output)); + _jsonNodes.Add(Serialize("Data.TransformModelArrayConverter", input, output)); } + public Microsoft.ML.Models.AnomalyDetectionEvaluator.Output Add(Microsoft.ML.Models.AnomalyDetectionEvaluator input) { var output = new Microsoft.ML.Models.AnomalyDetectionEvaluator.Output(); @@ -453,6 +466,18 @@ public void Add(Microsoft.ML.Trainers.GeneralizedAdditiveModelRegressor input, M _jsonNodes.Add(Serialize("Trainers.GeneralizedAdditiveModelRegressor", input, output)); } + public Microsoft.ML.Trainers.KMeansPlusPlusClusterer.Output Add(Microsoft.ML.Trainers.KMeansPlusPlusClusterer input) + { + var output = new Microsoft.ML.Trainers.KMeansPlusPlusClusterer.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Trainers.KMeansPlusPlusClusterer input, Microsoft.ML.Trainers.KMeansPlusPlusClusterer.Output output) + { + _jsonNodes.Add(Serialize("Trainers.KMeansPlusPlusClusterer", input, output)); + } + public Microsoft.ML.Trainers.LinearSvmBinaryClassifier.Output Add(Microsoft.ML.Trainers.LinearSvmBinaryClassifier input) { var output = new Microsoft.ML.Trainers.LinearSvmBinaryClassifier.Output(); @@ -1271,6 +1296,33 @@ public void Add(Microsoft.ML.Transforms.WordTokenizer input, Microsoft.ML.Transf } } + namespace Data + { + + /// + /// Pass dataview from memory to experiment + /// + public sealed partial class DataViewReference + { + + + /// + /// Pointer to IDataView in memory + /// + public Var Data { get; set; } = new Var(); + + + public sealed class Output + { + /// + /// The resulting data view + /// + public Var Data { get; set; } = new Var(); + + } + } + } + namespace Data { @@ -1355,20 +1407,30 @@ public sealed class Output } } + } - public sealed partial class DataViewReference + namespace Data + { + + /// + /// Create and array variable + /// + public sealed partial class TransformModelArrayConverter { + + /// - /// Location of the input file + /// The models /// - public Var Data { get; set; } = new Var(); + public ArrayVar TransformModel { get; set; } = new ArrayVar(); + public sealed class Output { /// - /// The resulting data view + /// The model array /// - public Var Data { get; set; } = new Var(); + public ArrayVar OutputModel { get; set; } = new ArrayVar(); } } @@ -1842,6 +1904,21 @@ public sealed class CrossValidationMacroSubGraphOutput /// public Var Model { get; set; } = new Var(); + /// + /// The transform model + /// + public Var TransformModel { get; set; } = new Var(); + + /// + /// Transform data + /// + public Var TransformData { get; set; } = new Var(); + + /// + /// Indicates to use transform model instead of predictor model. + /// + public bool UseTransformModel { get; set; } = false; + } /// @@ -1899,6 +1976,11 @@ public sealed class Output /// public ArrayVar PredictorModel { get; set; } = new ArrayVar(); + /// + /// The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel. + /// + public ArrayVar TransformModel { get; set; } = new ArrayVar(); + /// /// Warning dataset /// @@ -1994,14 +2076,19 @@ public sealed class Output public Var OutputData { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(DatasetTransformer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(DatasetTransformer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new DatasetTransformerPipelineStep(output); } @@ -2064,14 +2151,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICal public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FixedPlattCalibrator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FixedPlattCalibrator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new FixedPlattCalibratorPipelineStep(output); } @@ -2196,14 +2288,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICal public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(NaiveCalibrator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(NaiveCalibrator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new NaiveCalibratorPipelineStep(output); } @@ -2309,14 +2406,19 @@ public sealed class Output public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(OneVersusAll)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new OneVersusAllPipelineStep(output); } @@ -2392,14 +2494,19 @@ public sealed class Output public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(OvaModelCombiner)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(OvaModelCombiner)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new OvaModelCombinerPipelineStep(output); } @@ -2451,14 +2558,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICal public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(PAVCalibrator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(PAVCalibrator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new PAVCalibratorPipelineStep(output); } @@ -2568,14 +2680,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICal public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(PlattCalibrator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(PlattCalibrator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new PlattCalibratorPipelineStep(output); } @@ -2978,6 +3095,21 @@ public sealed class TrainTestMacroSubGraphOutput /// public Var Model { get; set; } = new Var(); + /// + /// Transform model + /// + public Var TransformModel { get; set; } = new Var(); + + /// + /// Transform data + /// + public Var TransformData { get; set; } = new Var(); + + /// + /// Indicates to use transform model instead of predictor model. + /// + public bool UseTransformModel { get; set; } = false; + } /// @@ -3040,6 +3172,11 @@ public sealed class Output /// public Var PredictorModel { get; set; } = new Var(); + /// + /// The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel. + /// + public Var TransformModel { get; set; } = new Var(); + /// /// Warning dataset /// @@ -3221,14 +3358,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBin public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(AveragedPerceptronBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(AveragedPerceptronBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new AveragedPerceptronBinaryClassifierPipelineStep(output); } @@ -3516,14 +3658,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBin public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FastForestBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FastForestBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new FastForestBinaryClassifierPipelineStep(output); } @@ -3793,14 +3940,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IReg public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FastForestRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FastForestRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new FastForestRegressorPipelineStep(output); } @@ -4186,14 +4338,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBin public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FastTreeBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FastTreeBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new FastTreeBinaryClassifierPipelineStep(output); } @@ -4607,14 +4764,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IRan public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FastTreeRanker)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FastTreeRanker)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new FastTreeRankerPipelineStep(output); } @@ -4988,14 +5150,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IReg public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FastTreeRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FastTreeRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new FastTreeRegressorPipelineStep(output); } @@ -5374,14 +5541,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IReg public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FastTreeTweedieRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FastTreeTweedieRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new FastTreeTweedieRegressorPipelineStep(output); } @@ -5526,14 +5698,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBin public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(GeneralizedAdditiveModelBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(GeneralizedAdditiveModelBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new GeneralizedAdditiveModelBinaryClassifierPipelineStep(output); } @@ -5662,14 +5839,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IReg public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(GeneralizedAdditiveModelRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(GeneralizedAdditiveModelRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new GeneralizedAdditiveModelRegressorPipelineStep(output); } @@ -5688,80 +5870,51 @@ public GeneralizedAdditiveModelRegressorPipelineStep(Output output) namespace Trainers { + public enum KMeansPlusPlusTrainerInitAlgorithm + { + KMeansPlusPlus = 0, + Random = 1, + KMeansParallel = 2 + } + /// - /// Train a linear SVM. + /// K-means is a popular clustering algorithm. With K-means, the data is clustered into a specified number of clusters in order to minimize the within-cluster sum of squares. K-means++ improves upon K-means by using a better method for choosing the initial cluster centers. /// - public sealed partial class LinearSvmBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + public sealed partial class KMeansPlusPlusClusterer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem { /// - /// Regularizer constant - /// - [TlcModule.SweepableFloatParamAttribute("Lambda", 1E-05f, 0.1f, stepSize:10, isLogScale:true)] - public float Lambda { get; set; } = 0.001f; - - /// - /// Batch size - /// - public int BatchSize { get; set; } = 1; - - /// - /// Perform projection to unit-ball? Typically used with batch size > 1. - /// - [TlcModule.SweepableDiscreteParamAttribute("PerformProjection", new object[]{false, true})] - public bool PerformProjection { get; set; } = false; - - /// - /// No bias - /// - [TlcModule.SweepableDiscreteParamAttribute("NoBias", new object[]{false, true})] - public bool NoBias { get; set; } = false; - - /// - /// The calibrator kind to apply to the predictor. Specify null for no calibration - /// - [JsonConverter(typeof(ComponentSerializer))] - public CalibratorTrainer Calibrator { get; set; } = new PlattCalibratorCalibratorTrainer(); - - /// - /// The maximum number of examples to use when training the calibrator - /// - public int MaxCalibrationExamples { get; set; } = 1000000; - - /// - /// Number of iterations + /// The number of clusters /// - [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize:10, isLogScale:true)] - public int NumIterations { get; set; } = 1; + [TlcModule.SweepableDiscreteParamAttribute("K", new object[]{5, 10, 20, 40})] + public int K { get; set; } = 5; /// - /// Initial Weights and bias, comma-separated + /// Cluster initialization algorithm /// - public string InitialWeights { get; set; } + public Trainers.KMeansPlusPlusTrainerInitAlgorithm InitAlgorithm { get; set; } = Trainers.KMeansPlusPlusTrainerInitAlgorithm.KMeansParallel; /// - /// Init weights diameter + /// Tolerance parameter for trainer convergence. Lower = slower, more accurate /// - [TlcModule.SweepableFloatParamAttribute("InitWtsDiameter", 0f, 1f, numSteps:5)] - public float InitWtsDiameter { get; set; } + public float OptTol { get; set; } = 1E-07f; /// - /// Whether to shuffle for each training iteration + /// Maximum number of iterations. /// - [TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[]{false, true})] - public bool Shuffle { get; set; } = true; + public int MaxIterations { get; set; } = 1000; /// - /// Size of cache when trained in Scope + /// Memory budget (in MBs) to use for KMeans acceleration /// - public int StreamingCacheSize { get; set; } = 1000000; + public int AccelMemBudgetMb { get; set; } = 4096; /// - /// Column to use for labels + /// Degree of lock-free parallelism. Defaults to automatic. Determinism not guaranteed. /// - public string LabelColumn { get; set; } = "Label"; + public int? NumThreads { get; set; } /// /// The data to be used for training @@ -5784,7 +5937,7 @@ public sealed partial class LinearSvmBinaryClassifier : Microsoft.ML.Runtime.Ent public Models.CachingOptions Caching { get; set; } = Models.CachingOptions.Auto; - public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IClusteringOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput { /// /// The trained model @@ -5792,21 +5945,26 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBin public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(LinearSvmBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(KMeansPlusPlusClusterer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); - return new LinearSvmBinaryClassifierPipelineStep(output); + return new KMeansPlusPlusClustererPipelineStep(output); } - private class LinearSvmBinaryClassifierPipelineStep : ILearningPipelinePredictorStep + private class KMeansPlusPlusClustererPipelineStep : ILearningPipelinePredictorStep { - public LinearSvmBinaryClassifierPipelineStep(Output output) + public KMeansPlusPlusClustererPipelineStep(Output output) { Model = output.PredictorModel; } @@ -5820,14 +5978,149 @@ namespace Trainers { /// - /// Train a logistic regression binary model + /// Train a linear SVM. /// - public sealed partial class LogisticRegressionBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + public sealed partial class LinearSvmBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem { /// - /// Show statistics of training examples. + /// Regularizer constant + /// + [TlcModule.SweepableFloatParamAttribute("Lambda", 1E-05f, 0.1f, stepSize:10, isLogScale:true)] + public float Lambda { get; set; } = 0.001f; + + /// + /// Batch size + /// + public int BatchSize { get; set; } = 1; + + /// + /// Perform projection to unit-ball? Typically used with batch size > 1. + /// + [TlcModule.SweepableDiscreteParamAttribute("PerformProjection", new object[]{false, true})] + public bool PerformProjection { get; set; } = false; + + /// + /// No bias + /// + [TlcModule.SweepableDiscreteParamAttribute("NoBias", new object[]{false, true})] + public bool NoBias { get; set; } = false; + + /// + /// The calibrator kind to apply to the predictor. Specify null for no calibration + /// + [JsonConverter(typeof(ComponentSerializer))] + public CalibratorTrainer Calibrator { get; set; } = new PlattCalibratorCalibratorTrainer(); + + /// + /// The maximum number of examples to use when training the calibrator + /// + public int MaxCalibrationExamples { get; set; } = 1000000; + + /// + /// Number of iterations + /// + [TlcModule.SweepableLongParamAttribute("NumIterations", 1, 100, stepSize:10, isLogScale:true)] + public int NumIterations { get; set; } = 1; + + /// + /// Initial Weights and bias, comma-separated + /// + public string InitialWeights { get; set; } + + /// + /// Init weights diameter + /// + [TlcModule.SweepableFloatParamAttribute("InitWtsDiameter", 0f, 1f, numSteps:5)] + public float InitWtsDiameter { get; set; } + + /// + /// Whether to shuffle for each training iteration + /// + [TlcModule.SweepableDiscreteParamAttribute("Shuffle", new object[]{false, true})] + public bool Shuffle { get; set; } = true; + + /// + /// Size of cache when trained in Scope + /// + public int StreamingCacheSize { get; set; } = 1000000; + + /// + /// Column to use for labels + /// + public string LabelColumn { get; set; } = "Label"; + + /// + /// The data to be used for training + /// + public Var TrainingData { get; set; } = new Var(); + + /// + /// Column to use for features + /// + public string FeatureColumn { get; set; } = "Features"; + + /// + /// Normalize option for the feature column + /// + public Models.NormalizeOption NormalizeFeatures { get; set; } = Models.NormalizeOption.Auto; + + /// + /// Whether learner should cache input training data + /// + public Models.CachingOptions Caching { get; set; } = Models.CachingOptions.Auto; + + + public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBinaryClassificationOutput, Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITrainerOutput + { + /// + /// The trained model + /// + public Var PredictorModel { get; set; } = new Var(); + + } + public Var GetInputData() => TrainingData; + + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + if (previousStep != null) + { + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(LinearSvmBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } + + TrainingData = dataStep.Data; + } + Output output = experiment.Add(this); + return new LinearSvmBinaryClassifierPipelineStep(output); + } + + private class LinearSvmBinaryClassifierPipelineStep : ILearningPipelinePredictorStep + { + public LinearSvmBinaryClassifierPipelineStep(Output output) + { + Model = output.PredictorModel; + } + + public Var Model { get; } + } + } + } + + namespace Trainers + { + + /// + /// Train a logistic regression binary model + /// + public sealed partial class LogisticRegressionBinaryClassifier : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithWeight, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInputWithLabel, Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITrainerInput, Microsoft.ML.ILearningPipelineItem + { + + + /// + /// Show statistics of training examples. /// public bool ShowTrainingStats { get; set; } = false; @@ -5937,14 +6230,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBin public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(LogisticRegressionBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(LogisticRegressionBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new LogisticRegressionBinaryClassifierPipelineStep(output); } @@ -6082,14 +6380,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IMul public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(LogisticRegressionClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(LogisticRegressionClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new LogisticRegressionClassifierPipelineStep(output); } @@ -6150,14 +6453,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IMul public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(NaiveBayesClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(NaiveBayesClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new NaiveBayesClassifierPipelineStep(output); } @@ -6300,14 +6608,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IReg public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(OnlineGradientDescentRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(OnlineGradientDescentRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new OnlineGradientDescentRegressorPipelineStep(output); } @@ -6384,14 +6697,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IReg public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(OrdinaryLeastSquaresRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(OrdinaryLeastSquaresRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new OrdinaryLeastSquaresRegressorPipelineStep(output); } @@ -6524,14 +6842,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IReg public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(PoissonRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(PoissonRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new PoissonRegressorPipelineStep(output); } @@ -6660,14 +6983,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBin public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(StochasticDualCoordinateAscentBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(StochasticDualCoordinateAscentBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new StochasticDualCoordinateAscentBinaryClassifierPipelineStep(output); } @@ -6780,14 +7108,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IMul public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(StochasticDualCoordinateAscentClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(StochasticDualCoordinateAscentClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new StochasticDualCoordinateAscentClassifierPipelineStep(output); } @@ -6900,14 +7233,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IReg public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(StochasticDualCoordinateAscentRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(StochasticDualCoordinateAscentRegressor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new StochasticDualCoordinateAscentRegressorPipelineStep(output); } @@ -7034,14 +7372,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IBin public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(StochasticGradientDescentBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(StochasticGradientDescentBinaryClassifier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new StochasticGradientDescentBinaryClassifierPipelineStep(output); } @@ -7107,14 +7450,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(ApproximateBootstrapSampler)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(ApproximateBootstrapSampler)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new ApproximateBootstrapSamplerPipelineStep(output); } @@ -7167,14 +7515,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(BinaryPredictionScoreColumnsRenamer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(BinaryPredictionScoreColumnsRenamer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new BinaryPredictionScoreColumnsRenamerPipelineStep(output); } @@ -7311,14 +7664,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(BinNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(BinNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new BinNormalizerPipelineStep(output); } @@ -7483,14 +7841,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(CategoricalHashOneHotVectorizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(CategoricalHashOneHotVectorizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new CategoricalHashOneHotVectorizerPipelineStep(output); } @@ -7653,14 +8016,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(CategoricalOneHotVectorizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(CategoricalOneHotVectorizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new CategoricalOneHotVectorizerPipelineStep(output); } @@ -7772,14 +8140,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(CharacterTokenizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(CharacterTokenizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new CharacterTokenizerPipelineStep(output); } @@ -7862,14 +8235,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(ColumnConcatenator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(ColumnConcatenator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new ColumnConcatenatorPipelineStep(output); } @@ -7976,14 +8354,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(ColumnCopier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(ColumnCopier)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new ColumnCopierPipelineStep(output); } @@ -8036,14 +8419,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(ColumnDropper)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(ColumnDropper)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new ColumnDropperPipelineStep(output); } @@ -8096,14 +8484,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(ColumnSelector)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(ColumnSelector)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new ColumnSelectorPipelineStep(output); } @@ -8258,14 +8651,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(ColumnTypeConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(ColumnTypeConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new ColumnTypeConverterPipelineStep(output); } @@ -8323,14 +8721,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(CombinerByContiguousGroupId)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(CombinerByContiguousGroupId)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new CombinerByContiguousGroupIdPipelineStep(output); } @@ -8457,14 +8860,19 @@ public sealed class Output public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(ConditionalNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(ConditionalNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new ConditionalNormalizerPipelineStep(output); } @@ -8518,14 +8926,19 @@ public sealed class Output public Var OutputData { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(DataCache)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(DataCache)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new DataCachePipelineStep(output); } @@ -8750,14 +9163,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(Dictionarizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(Dictionarizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new DictionarizerPipelineStep(output); } @@ -8810,14 +9228,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FeatureCombiner)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FeatureCombiner)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new FeatureCombinerPipelineStep(output); } @@ -8875,14 +9298,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FeatureSelectorByCount)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FeatureSelectorByCount)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new FeatureSelectorByCountPipelineStep(output); } @@ -8950,14 +9378,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(FeatureSelectorByMutualInformation)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(FeatureSelectorByMutualInformation)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new FeatureSelectorByMutualInformationPipelineStep(output); } @@ -9094,14 +9527,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(GlobalContrastNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(GlobalContrastNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new GlobalContrastNormalizerPipelineStep(output); } @@ -9253,14 +9691,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(HashConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(HashConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new HashConverterPipelineStep(output); } @@ -9367,14 +9810,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(KeyToTextConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(KeyToTextConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new KeyToTextConverterPipelineStep(output); } @@ -9432,14 +9880,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(LabelColumnKeyBooleanConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(LabelColumnKeyBooleanConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new LabelColumnKeyBooleanConverterPipelineStep(output); } @@ -9556,14 +10009,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(LabelIndicator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(LabelIndicator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new LabelIndicatorPipelineStep(output); } @@ -9616,14 +10074,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(LabelToFloatConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(LabelToFloatConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new LabelToFloatConverterPipelineStep(output); } @@ -9745,14 +10208,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(LogMeanVarianceNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(LogMeanVarianceNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new LogMeanVarianceNormalizerPipelineStep(output); } @@ -9887,14 +10355,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(LpNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(LpNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new LpNormalizerPipelineStep(output); } @@ -10034,14 +10507,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(MeanVarianceNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(MeanVarianceNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new MeanVarianceNormalizerPipelineStep(output); } @@ -10144,14 +10622,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(MinMaxNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(MinMaxNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new MinMaxNormalizerPipelineStep(output); } @@ -10300,14 +10783,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(MissingValueHandler)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(MissingValueHandler)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new MissingValueHandlerPipelineStep(output); } @@ -10414,14 +10902,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(MissingValueIndicator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(MissingValueIndicator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new MissingValueIndicatorPipelineStep(output); } @@ -10528,14 +11021,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(MissingValuesDropper)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(MissingValuesDropper)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new MissingValuesDropperPipelineStep(output); } @@ -10593,14 +11091,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(MissingValuesRowDropper)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(MissingValuesRowDropper)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new MissingValuesRowDropperPipelineStep(output); } @@ -10747,14 +11250,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(MissingValueSubstitutor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(MissingValueSubstitutor)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new MissingValueSubstitutorPipelineStep(output); } @@ -10796,6 +11304,11 @@ public sealed class Output /// public Var OutputModel { get; set; } = new Var(); + /// + /// Data + /// + public Var Data { get; set; } = new Var(); + } } } @@ -10945,14 +11458,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(NGramTranslator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(NGramTranslator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new NGramTranslatorPipelineStep(output); } @@ -11000,14 +11518,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(NoOperation)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(NoOperation)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new NoOperationPipelineStep(output); } @@ -11060,14 +11583,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(OptionalColumnCreator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(OptionalColumnCreator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new OptionalColumnCreatorPipelineStep(output); } @@ -11120,14 +11648,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(PredictedLabelColumnOriginalValueConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(PredictedLabelColumnOriginalValueConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new PredictedLabelColumnOriginalValueConverterPipelineStep(output); } @@ -11209,14 +11742,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(RandomNumberGenerator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(RandomNumberGenerator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new RandomNumberGeneratorPipelineStep(output); } @@ -11294,14 +11832,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(RowRangeFilter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(RowRangeFilter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new RowRangeFilterPipelineStep(output); } @@ -11359,14 +11902,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(RowSkipAndTakeFilter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(RowSkipAndTakeFilter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new RowSkipAndTakeFilterPipelineStep(output); } @@ -11419,14 +11967,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(RowSkipFilter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(RowSkipFilter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new RowSkipFilterPipelineStep(output); } @@ -11479,14 +12032,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(RowTakeFilter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(RowTakeFilter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new RowTakeFilterPipelineStep(output); } @@ -11539,14 +12097,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(ScoreColumnSelector)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(ScoreColumnSelector)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new ScoreColumnSelectorPipelineStep(output); } @@ -11643,14 +12206,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(Segregator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(Segregator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new SegregatorPipelineStep(output); } @@ -11708,14 +12276,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(SentimentAnalyzer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(SentimentAnalyzer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new SentimentAnalyzerPipelineStep(output); } @@ -11833,14 +12406,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(SupervisedBinNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(SupervisedBinNormalizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new SupervisedBinNormalizerPipelineStep(output); } @@ -12024,14 +12602,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - // throw new InvalidOperationException($"{ nameof(TextFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(TextFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - //Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new TextFeaturizerPipelineStep(output); } @@ -12144,14 +12727,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(TextToKeyConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(TextToKeyConverter)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new TextToKeyConverterPipelineStep(output); } @@ -12256,14 +12844,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(TreeLeafFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(TreeLeafFeaturizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new TreeLeafFeaturizerPipelineStep(output); } @@ -12412,14 +13005,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(WordTokenizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(WordTokenizer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new WordTokenizerPipelineStep(output); } diff --git a/src/Microsoft.ML/Data/CollectionDataSource.cs b/src/Microsoft.ML/Data/CollectionDataSource.cs index 56523fc994..8551079d30 100644 --- a/src/Microsoft.ML/Data/CollectionDataSource.cs +++ b/src/Microsoft.ML/Data/CollectionDataSource.cs @@ -52,6 +52,8 @@ public void SetInput(IHostEnvironment environment, Experiment experiment) experiment.SetInput(_dataViewEntryPoint.Data, _dataView); } + public Var GetInputData() => null; + public abstract IDataView GetDataView(IHostEnvironment environment); } diff --git a/src/Microsoft.ML/ILearningPipelineItem.cs b/src/Microsoft.ML/ILearningPipelineItem.cs index d0430b711d..3d29b65a2e 100644 --- a/src/Microsoft.ML/ILearningPipelineItem.cs +++ b/src/Microsoft.ML/ILearningPipelineItem.cs @@ -14,6 +14,7 @@ namespace Microsoft.ML public interface ILearningPipelineItem { ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment); + Var GetInputData(); } /// diff --git a/src/Microsoft.ML/Models/BinaryCrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs similarity index 76% rename from src/Microsoft.ML/Models/BinaryCrossValidator.cs rename to src/Microsoft.ML/Models/CrossValidator.cs index a90fb313cc..6841c04808 100644 --- a/src/Microsoft.ML/Models/BinaryCrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -9,9 +9,9 @@ namespace Microsoft.ML.Models { - public sealed partial class BinaryCrossValidator + public sealed partial class CrossValidator { - public void CrossValidate(LearningPipeline pipeline) + public PredictionModel CrossValidate(LearningPipeline pipeline) where TInput : class where TOutput : class, new() { @@ -22,8 +22,10 @@ public void CrossValidate(LearningPipeline pipeline) List loaders = new List(); List> transformModels = new List>(); Var lastTransformModel = null; - Var firstInput = null; + Var firstPipelineDataStep = null; Var firstModel = null; + Var lastData = null; + ILearningPipelineItem firstTransform = null; foreach (ILearningPipelineItem currentItem in pipeline) { if (currentItem is ILearningPipelineLoader loader) @@ -37,8 +39,11 @@ public void CrossValidate(LearningPipeline pipeline) if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null) { transformModels.Add(dataStep.Model); - if (firstInput == null) - firstInput = dataStep.Data; + if (firstPipelineDataStep == null) + { + firstPipelineDataStep = dataStep.Data; + firstTransform = currentItem; + } } else if (step is ILearningPipelinePredictorStep predictorDataStep) @@ -59,14 +64,16 @@ public void CrossValidate(LearningPipeline pipeline) } else predictorModel = predictorDataStep.Model; + firstModel = predictorModel; var scorer = new Transforms.Scorer { PredictorModel = predictorModel }; - firstModel = predictorModel; + var scorerOutput = subGraph.Add(scorer); lastTransformModel = scorerOutput.ScoringTransform; + lastData = scorerOutput.ScoredData; step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform); transformModels.Clear(); } @@ -84,29 +91,29 @@ public void CrossValidate(LearningPipeline pipeline) var modelOutput = subGraph.Add(modelInput); lastTransformModel = modelOutput.OutputModel; + lastData = modelOutput.Data; } var experiment = environment.CreateExperiment(); - var importTextOutput = loaders[0].ApplyStep(null, experiment); - - var crossValidateBinary = new ML.Models.BinaryCrossValidator - { - Data = (importTextOutput as ILearningPipelineDataStep).Data, - Nodes = subGraph - }; - crossValidateBinary.Inputs.Data = firstInput; - crossValidateBinary.Outputs.Model = firstModel; - var crossValidateOutput = experiment.Add(crossValidateBinary); + Data = (importTextOutput as ILearningPipelineDataStep).Data; + Nodes = subGraph; + TransformModel = null; + Inputs.Data = firstTransform.GetInputData(); + Outputs.Model = null; + Outputs.TransformModel = lastTransformModel; + Outputs.TransformData = lastData; + Outputs.UseTransformModel = true; + var crossValidateOutput = experiment.Add(this); experiment.Compile(); foreach (ILearningPipelineLoader loader in loaders) { loader.SetInput(environment, experiment); } + experiment.Run(); - var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); - ITransformModel model = experiment.GetOutput(lastTransformModel); + ITransformModel model = experiment.GetOutput(crossValidateOutput.TransformModel[0]); BatchPredictionEngine predictor; using (var memoryStream = new MemoryStream()) { @@ -116,7 +123,7 @@ public void CrossValidate(LearningPipeline pipeline) predictor = environment.CreateBatchPredictionEngine(memoryStream); - //return new PredictionModel(predictor, memoryStream); + return new PredictionModel(predictor, memoryStream); } } } diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs index 302a71245c..c6fed29fd3 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs @@ -234,6 +234,29 @@ public static ArrayIPredictorModelOutput MakeArray(IHostEnvironment env, ArrayIP return result; } + public sealed class ArrayITransformModelInput + { + [Argument(ArgumentType.Required, HelpText = "The models", SortOrder = 1)] + public ITransformModel[] TransformModel; + } + + public sealed class ArrayITransformModelOutput + { + [TlcModule.Output(Desc = "The model array", SortOrder = 1)] + public ITransformModel[] OutputModel; + } + + [TlcModule.EntryPoint(Desc = "Create and array variable", Name = "Data.TransformModelArrayConverter")] + public static ArrayITransformModelOutput MakeArray(IHostEnvironment env, ArrayITransformModelInput input) + { + var result = new ArrayITransformModelOutput + { + OutputModel = input.TransformModel + }; + return result; + } + + public sealed class ArrayIDataViewInput { [Argument(ArgumentType.Required, HelpText = "The data sets", SortOrder = 1)] diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 1cb950f939..db086ba014 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -31,6 +31,15 @@ public sealed class SubGraphOutput { [Argument(ArgumentType.Required, HelpText = "The model", SortOrder = 1)] public Var Model; + + [Argument(ArgumentType.Required, HelpText = "The transform model", SortOrder = 2)] + public Var TransformModel; + + [Argument(ArgumentType.Required, HelpText = "Transform data", SortOrder = 3)] + public Var TransformData; + + [Argument(ArgumentType.Required, HelpText = "Indicates to use transform model instead of predictor model.", SortOrder = 4)] + public bool UseTransformModel = false; } public sealed class Arguments @@ -42,7 +51,8 @@ public sealed class Arguments public IDataView Data; [TlcModule.OptionalInput] - [Argument(ArgumentType.AtMostOnce, HelpText = "The transform model from the pipeline before this command. It gets included in the Output.PredictorModel.", SortOrder = 2)] + [Argument(ArgumentType.AtMostOnce, HelpText = "The transform model from the pipeline before this command. " + + "It gets included in the Output.PredictorModel.", SortOrder = 2)] public ITransformModel TransformModel; // This is the subgraph that describes how to train a model for each fold. It should @@ -79,19 +89,24 @@ public sealed class Arguments // but that requires changes in the entry points infrastructure to support structs in the output classes. public sealed class Output { - [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel.", SortOrder = 1)] + [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, " + + "provided as the Input.TransformModel.", SortOrder = 1)] public IPredictorModel[] PredictorModel; - [TlcModule.Output(Desc = "Warning dataset", SortOrder = 2)] + [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, " + + "provided as the Input.TransformModel.", SortOrder = 2)] + public ITransformModel[] TransformModel; + + [TlcModule.Output(Desc = "Warning dataset", SortOrder = 3)] public IDataView[] Warnings; - [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 3)] + [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 4)] public IDataView[] OverallMetrics; - [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 4)] + [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 5)] public IDataView[] PerInstanceMetrics; - [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 5)] + [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 6)] public IDataView[] ConfusionMatrix; } @@ -121,6 +136,7 @@ public static CommonOutputs.MacroOutput CrossValidate( subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); var predModelVars = new Var[input.NumFolds]; + var transformModelVars = new Var[input.NumFolds]; var inputTransformModelVars = new Var[input.NumFolds]; var warningsVars = new Var[input.NumFolds]; var overallMetricsVars = new Var[input.NumFolds]; @@ -152,11 +168,30 @@ public static CommonOutputs.MacroOutput CrossValidate( { VarName = mapping[input.Inputs.Data.VarName] }; - args.Outputs.Model = new Var + + if (input.Outputs.Model != null) + { + args.Outputs.Model = new Var + { + VarName = mapping[input.Outputs.Model.VarName] + }; + } + + if (input.Outputs.TransformModel != null) + { + args.Outputs.TransformModel = new Var + { + VarName = mapping[input.Outputs.TransformModel.VarName] + }; + } + + args.Outputs.TransformData = new Var { - VarName = mapping[input.Outputs.Model.VarName] + VarName = mapping[input.Outputs.TransformData.VarName] }; + args.Outputs.UseTransformModel = input.Outputs.UseTransformModel; + // Set train/test trainer kind to match. args.Kind = input.Kind; @@ -170,23 +205,48 @@ public static CommonOutputs.MacroOutput CrossValidate( inputBindingMap.Add(nameof(args.TestingData), new List { testingData }); inputMap.Add(testingData, new ArrayIndexVariableBinding(cvSplitOutput.TestData.VarName, k)); var outputMap = new Dictionary(); + var transformModelVar = new Var(); var predModelVar = new Var(); - outputMap.Add(nameof(TrainTestMacro.Output.PredictorModel), predModelVar.VarName); - predModelVars[k] = predModelVar; - - ML.Transforms.TwoHeterogeneousModelCombiner.Output modelCombineOutput = null; - if (transformModelVarName != null && transformModelVarName.VariableName != null) + if (input.Outputs.UseTransformModel) { - var modelCombine = new ML.Transforms.TwoHeterogeneousModelCombiner + outputMap.Add(nameof(TrainTestMacro.Output.TransformModel), transformModelVar.VarName); + transformModelVars[k] = transformModelVar; + ML.Transforms.ModelCombiner.Output modelCombineOutput = null; + if (transformModelVarName != null && transformModelVarName.VariableName != null) { - TransformModel = { VarName = transformModelVarName.VariableName }, - PredictorModel = predModelVar - }; - - exp.Reset(); - modelCombineOutput = exp.Add(modelCombine); - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); - predModelVars[k] = modelCombineOutput.PredictorModel; + var modelCombine = new ML.Transforms.ModelCombiner + { + Models = new ArrayVar( + new Var[] { + new Var { VarName = transformModelVarName.VariableName }, + transformModelVar } + ) + }; + + exp.Reset(); + modelCombineOutput = exp.Add(modelCombine); + subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); + transformModelVars[k] = modelCombineOutput.OutputModel; + } + } + else + { + outputMap.Add(nameof(TrainTestMacro.Output.PredictorModel), predModelVar.VarName); + predModelVars[k] = predModelVar; + ML.Transforms.TwoHeterogeneousModelCombiner.Output modelCombineOutput = null; + if (transformModelVarName != null && transformModelVarName.VariableName != null) + { + var modelCombine = new ML.Transforms.TwoHeterogeneousModelCombiner + { + TransformModel = { VarName = transformModelVarName.VariableName }, + PredictorModel = predModelVar + }; + + exp.Reset(); + modelCombineOutput = exp.Add(modelCombine); + subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); + predModelVars[k] = modelCombineOutput.PredictorModel; + } } var warningVar = new Var(); @@ -206,13 +266,26 @@ public static CommonOutputs.MacroOutput CrossValidate( exp.Reset(); - var outModels = new ML.Data.PredictorModelArrayConverter + if (input.Outputs.UseTransformModel) { - Model = new ArrayVar(predModelVars) - }; - var outModelsOutput = new ML.Data.PredictorModelArrayConverter.Output(); - outModelsOutput.OutputModel.VarName = node.GetOutputVariableName(nameof(Output.PredictorModel)); - exp.Add(outModels, outModelsOutput); + var outModels = new ML.Data.TransformModelArrayConverter + { + TransformModel = new ArrayVar(transformModelVars) + }; + var outModelsOutput = new ML.Data.TransformModelArrayConverter.Output(); + outModelsOutput.OutputModel.VarName = node.GetOutputVariableName(nameof(Output.TransformModel)); + exp.Add(outModels, outModelsOutput); + } + else + { + var outModels = new ML.Data.PredictorModelArrayConverter + { + Model = new ArrayVar(predModelVars) + }; + var outModelsOutput = new ML.Data.PredictorModelArrayConverter.Output(); + outModelsOutput.OutputModel.VarName = node.GetOutputVariableName(nameof(Output.PredictorModel)); + exp.Add(outModels, outModelsOutput); + } var warnings = new ML.Data.IDataViewArrayConverter { diff --git a/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs b/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs index fa34cfd7ac..9f7cbb727b 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs @@ -23,6 +23,9 @@ public sealed class CombineTransformModelsOutput { [TlcModule.Output(Desc = "Combined model", SortOrder = 1)] public ITransformModel OutputModel; + + [TlcModule.Output(Desc = "Data", SortOrder = 2)] + public IDataView Data; } public sealed class PredictorModelInput @@ -89,7 +92,7 @@ public static CombineTransformModelsOutput CombineTransformModels(IHostEnvironme for (int i = input.Models.Length - 2; i >= 0; i--) model = model.Apply(env, input.Models[i]); - return new CombineTransformModelsOutput { OutputModel = model }; + return new CombineTransformModelsOutput { OutputModel = model, Data = model.Data }; } [TlcModule.EntryPoint(Name = "Transforms.ManyHeterogeneousModelCombiner", Desc = "Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel.")] diff --git a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs index b05b5e5c69..1eb7d08908 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Transforms; using Newtonsoft.Json.Linq; [assembly: LoadableClass(typeof(void), typeof(TrainTestMacro), null, typeof(SignatureEntryPointModule), "TrainTestMacro")] @@ -25,6 +26,15 @@ public sealed class SubGraphOutput { [Argument(ArgumentType.Required, HelpText = "The model", SortOrder = 1)] public Var Model; + + [Argument(ArgumentType.Required, HelpText = "Transform model", SortOrder = 2)] + public Var TransformModel; + + [Argument(ArgumentType.Required, HelpText = "Transform data", SortOrder = 3)] + public Var TransformData; + + [Argument(ArgumentType.Required, HelpText = "Indicates to use transform model instead of predictor model.", SortOrder = 4)] + public bool UseTransformModel = false; } public sealed class Arguments @@ -62,31 +72,36 @@ public sealed class Arguments public sealed class Output { - [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel.", SortOrder = 1)] + [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, " + + "provided as the Input.TransformModel.", SortOrder = 1)] public IPredictorModel PredictorModel; - [TlcModule.Output(Desc = "Warning dataset", SortOrder = 2)] + [TlcModule.Output(Desc = "The final model including the trained predictor model and the model from the transforms, " + + "provided as the Input.TransformModel.", SortOrder = 2)] + public ITransformModel TransformModel; + + [TlcModule.Output(Desc = "Warning dataset", SortOrder = 3)] public IDataView Warnings; - [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 3)] + [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 4)] public IDataView OverallMetrics; - [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 4)] + [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 5)] public IDataView PerInstanceMetrics; - [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 5)] + [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 6)] public IDataView ConfusionMatrix; - [TlcModule.Output(Desc = "Warning dataset for training", SortOrder = 6)] + [TlcModule.Output(Desc = "Warning dataset for training", SortOrder = 7)] public IDataView TrainingWarnings; - [TlcModule.Output(Desc = "Overall metrics dataset for training", SortOrder = 7)] + [TlcModule.Output(Desc = "Overall metrics dataset for training", SortOrder = 8)] public IDataView TrainingOverallMetrics; - [TlcModule.Output(Desc = "Per instance metrics dataset for training", SortOrder = 8)] + [TlcModule.Output(Desc = "Per instance metrics dataset for training", SortOrder = 9)] public IDataView TrainingPerInstanceMetrics; - [TlcModule.Output(Desc = "Confusion matrix dataset for training", SortOrder = 9)] + [TlcModule.Output(Desc = "Confusion matrix dataset for training", SortOrder = 10)] public IDataView TrainingConfusionMatrix; } @@ -117,10 +132,13 @@ public static CommonOutputs.MacroOutput TrainTest( subGraphRunContext.RemoveVariable(dataVariable); // Change the subgraph to use the model variable as output. - varName = input.Outputs.Model.VarName; + varName = input.Outputs.UseTransformModel ? input.Outputs.TransformModel.VarName : input.Outputs.Model.VarName; if (!subGraphRunContext.TryGetVariable(varName, out dataVariable)) throw env.Except($"Invalid variable name '{varName}'."); - string outputVarName = node.GetOutputVariableName(nameof(Output.PredictorModel)); + + string outputVarName = input.Outputs.UseTransformModel ? node.GetOutputVariableName(nameof(Output.TransformModel)) : + node.GetOutputVariableName(nameof(Output.PredictorModel)); + foreach (var subGraphNode in subGraphNodes) subGraphNode.RenameOutputVariable(dataVariable.Name, outputVarName); subGraphRunContext.RemoveVariable(dataVariable); @@ -136,26 +154,50 @@ public static CommonOutputs.MacroOutput TrainTest( var testingVar = node.GetInputVariable(nameof(input.TestingData)); var exp = new Experiment(env); - //combine the predictor model with any potential transfrom model passed from the outer graph - if (transformModelVarName != null && transformModelVarName.VariableName != null) + DatasetScorer.Output scoreNodeOutput = null; + if (input.Outputs.UseTransformModel) { - var modelCombine = new ML.Transforms.TwoHeterogeneousModelCombiner + //combine the predictor model with any potential transfrom model passed from the outer graph + if (transformModelVarName != null && transformModelVarName.VariableName != null) { - TransformModel = { VarName = transformModelVarName.VariableName }, + var modelCombine = new ML.Transforms.ModelCombiner + { + Models = new ArrayVar( + new Var[] { + new Var { VarName = transformModelVarName.VariableName }, + new Var { VarName = outputVarName} } + ) + }; + + var modelCombineOutput = exp.Add(modelCombine); + outputVarName = modelCombineOutput.OutputModel.VarName; + } + } + else + { + //combine the predictor model with any potential transfrom model passed from the outer graph + if (transformModelVarName != null && transformModelVarName.VariableName != null) + { + var modelCombine = new TwoHeterogeneousModelCombiner + { + TransformModel = { VarName = transformModelVarName.VariableName }, + PredictorModel = { VarName = outputVarName } + }; + + var modelCombineOutput = exp.Add(modelCombine); + outputVarName = modelCombineOutput.PredictorModel.VarName; + } + + // Add the scoring node for testing. + var scoreNode = new DatasetScorer + { + Data = { VarName = testingVar.ToJson() }, PredictorModel = { VarName = outputVarName } }; - var modelCombineOutput = exp.Add(modelCombine); - outputVarName = modelCombineOutput.PredictorModel.VarName; + scoreNodeOutput = exp.Add(scoreNode); } - // Add the scoring node for testing. - var scoreNode = new ML.Transforms.DatasetScorer - { - Data = { VarName = testingVar.ToJson() }, - PredictorModel = { VarName = outputVarName } - }; - var scoreNodeOutput = exp.Add(scoreNode); subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); // Do not double-add previous nodes. @@ -172,23 +214,29 @@ public static CommonOutputs.MacroOutput TrainTest( if (input.IncludeTrainingMetrics) { - // Add the scoring node for training. - var scoreNodeTraining = new ML.Transforms.DatasetScorer + DatasetScorer.Output scoreNodeTrainingOutput = null; + if (!input.Outputs.UseTransformModel) { - Data = { VarName = trainingVar.ToJson() }, - PredictorModel = { VarName = outputVarName } - }; - var scoreNodeTrainingOutput = exp.Add(scoreNodeTraining); + // Add the scoring node for training. + var scoreNodeTraining = new DatasetScorer + { + Data = { VarName = trainingVar.ToJson() }, + PredictorModel = { VarName = outputVarName } + }; + scoreNodeTrainingOutput = exp.Add(scoreNodeTraining); + } + subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); // Do not double-add previous nodes. exp.Reset(); - // Add the evaluator node for training. + // Add the evaluator node for training. var evalInputOutputTraining = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings); var evalNodeTraining = evalInputOutputTraining.Item1; var evalOutputTraining = evalInputOutputTraining.Item2; - evalNodeTraining.Data.VarName = scoreNodeTrainingOutput.ScoredData.VarName; + evalNodeTraining.Data.VarName = input.Outputs.UseTransformModel ? input.Outputs.TransformData.VarName : + scoreNodeTrainingOutput.ScoredData.VarName; if (node.OutputMap.TryGetValue(nameof(Output.TrainingWarnings), out outVariableName)) evalOutputTraining.Warnings.VarName = outVariableName; @@ -211,7 +259,7 @@ public static CommonOutputs.MacroOutput TrainTest( var evalInputOutput = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings); var evalNode = evalInputOutput.Item1; var evalOutput = evalInputOutput.Item2; - evalNode.Data.VarName = scoreNodeOutput.ScoredData.VarName; + evalNode.Data.VarName = input.Outputs.UseTransformModel ? input.Outputs.TransformData.VarName : scoreNodeOutput.ScoredData.VarName; if (node.OutputMap.TryGetValue(nameof(Output.Warnings), out outVariableName)) evalOutput.Warnings.VarName = outVariableName; diff --git a/src/Microsoft.ML/Runtime/Experiment/Experiment.cs b/src/Microsoft.ML/Runtime/Experiment/Experiment.cs index 9fb0560701..108befb74b 100644 --- a/src/Microsoft.ML/Runtime/Experiment/Experiment.cs +++ b/src/Microsoft.ML/Runtime/Experiment/Experiment.cs @@ -34,7 +34,6 @@ private sealed class SerializationHelper private readonly JsonSerializer _serializer; private readonly SerializationHelper _helper; private EntryPointGraph _graph; - public Experiment(Runtime.IHostEnvironment env) { _env = env; diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index f1e45fa446..b05bc4496e 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -882,9 +882,19 @@ private static void GenerateApplyFunction(IndentingTextWriter writer, ModuleCata if (classBase.Contains("ICalibratorInput")) isCalibrator = true; + if (isTransform) + writer.WriteLine("public Var GetInputData() => Data;"); + else + writer.WriteLine("public Var GetInputData() => TrainingData;"); + + writer.WriteLine(""); string className = GeneratorUtils.GetClassAndMethodNames(entryPointInfo).Item2; writer.WriteLine("public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)"); writer.WriteLine("{"); + + writer.Indent(); + writer.WriteLine("if (previousStep != null)"); + writer.WriteLine("{"); writer.Indent(); writer.WriteLine("if (!(previousStep is ILearningPipelineDataStep dataStep))"); writer.WriteLine("{"); @@ -901,6 +911,9 @@ private static void GenerateApplyFunction(IndentingTextWriter writer, ModuleCata else writer.WriteLine("TrainingData = dataStep.Data;"); + writer.Outdent(); + writer.WriteLine("}"); + string pipelineStep = $"{className}PipelineStep"; writer.WriteLine($"Output output = experiment.Add(this);"); writer.WriteLine($"return new {pipelineStep}(output);"); diff --git a/src/Microsoft.ML/TextLoader.cs b/src/Microsoft.ML/TextLoader.cs index 4e3e3fb8e4..32cdc551ad 100644 --- a/src/Microsoft.ML/TextLoader.cs +++ b/src/Microsoft.ML/TextLoader.cs @@ -43,7 +43,7 @@ public TextLoader(string inputFilePath, bool useHeader = false, _inputFilePath = inputFilePath; SetCustomStringFromType(useHeader, separator, allowQuotedStrings, supportSparse, trimWhitespace); } - + private IFileHandle GetTextLoaderFileHandle(IHostEnvironment env, string trainFilePath) => new SimpleFileHandle(env, trainFilePath, false, false); @@ -94,6 +94,8 @@ private string TypeToName(Type type) throw new System.NotSupportedException("Type ${type.FullName} is not implemented or supported."); //Add more types. } + public Var GetInputData() => null; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { Contracts.Assert(previousStep == null); diff --git a/test/Microsoft.ML.Tests/CSharpCodeGen.cs b/test/Microsoft.ML.Tests/CSharpCodeGen.cs index c647110702..316d7eab55 100644 --- a/test/Microsoft.ML.Tests/CSharpCodeGen.cs +++ b/test/Microsoft.ML.Tests/CSharpCodeGen.cs @@ -15,7 +15,8 @@ public CSharpCodeGen(ITestOutputHelper output) : base(output) { } - [Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")] + //[Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")] + [Fact] public void GenerateCSharpAPI() { var cSharpAPIPath = Path.Combine(RootDir, @"src\\Microsoft.ML\\CSharpApi.cs"); diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index af742a3156..d198998399 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -24,6 +24,7 @@ public void TrainAndPredictSentimentModelTest() string dataPath = GetDataPath(SentimentDataPath); var pipeline = new LearningPipeline(); pipeline.Add(new TextLoader(dataPath, useHeader: true, separator: "tab")); + pipeline.Add(new Dictionarizer("Label")); pipeline.Add(new TextFeaturizer("Features", "SentimentText") { KeepDiacritics = false, @@ -37,10 +38,10 @@ public void TrainAndPredictSentimentModelTest() }); pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); - //pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); - BinaryCrossValidator bcv = new BinaryCrossValidator(); - bcv.CrossValidate(pipeline); - PredictionModel model = pipeline.Train(); + pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); + CrossValidator bcv = new CrossValidator(); + PredictionModel model = bcv.CrossValidate(pipeline); + // PredictionModel model = pipeline.Train(); IEnumerable sentiments = new[] { @@ -57,8 +58,8 @@ public void TrainAndPredictSentimentModelTest() IEnumerable predictions = model.Predict(sentiments); Assert.Equal(2, predictions.Count()); - Assert.False(predictions.ElementAt(0).Sentiment); - Assert.True(predictions.ElementAt(1).Sentiment); + //Assert.False(predictions.ElementAt(0).Sentiment); + //Assert.True(predictions.ElementAt(1).Sentiment); string testDataPath = GetDataPath(SentimentTestPath); var testData = new TextLoader(testDataPath, useHeader: true, separator: "tab"); @@ -106,7 +107,7 @@ public class SentimentData public class SentimentPrediction { [ColumnName("PredictedLabel")] - public bool Sentiment; + public float Sentiment; } } } From c39dfc7fbde6b37445e990735deb46960bc8271d Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 21 May 2018 20:39:29 -0700 Subject: [PATCH 03/18] resolve merge conflicts. --- src/Microsoft.ML/CSharpApi.cs | 281 +++++++++++++++--- src/Microsoft.ML/Data/TextLoader.cs | 1 - .../Internal/Tools/CSharpApiGenerator.cs | 18 +- src/Microsoft.ML/TextLoader.cs | 126 -------- .../Microsoft.ML.TestFramework/ModelHelper.cs | 1 - 5 files changed, 258 insertions(+), 169 deletions(-) delete mode 100644 src/Microsoft.ML/TextLoader.cs diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 389c9dd6e5..6865861de7 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -22,6 +22,18 @@ namespace Runtime { public sealed partial class Experiment { + public Microsoft.ML.Data.CustomTextLoader.Output Add(Microsoft.ML.Data.CustomTextLoader input) + { + var output = new Microsoft.ML.Data.CustomTextLoader.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Data.CustomTextLoader input, Microsoft.ML.Data.CustomTextLoader.Output output) + { + _jsonNodes.Add(Serialize("Data.CustomTextLoader", input, output)); + } + public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input) { var output = new Microsoft.ML.Data.DataViewReference.Output(); @@ -1296,6 +1308,39 @@ public void Add(Microsoft.ML.Transforms.WordTokenizer input, Microsoft.ML.Transf } } + namespace Data + { + + /// + /// Import a dataset from a text file + /// + [Obsolete("Use TextLoader instead.")] + public sealed partial class CustomTextLoader + { + + + /// + /// Location of the input file + /// + public Var InputFile { get; set; } = new Var(); + + /// + /// Custom schema to use for parsing + /// + public string CustomSchema { get; set; } + + + public sealed class Output + { + /// + /// The resulting data view + /// + public Var Data { get; set; } = new Var(); + + } + } + } + namespace Data { @@ -1380,12 +1425,176 @@ public sealed class Output namespace Data { + public sealed partial class TextLoaderArguments + { + /// + /// Use separate parsing threads? + /// + public bool UseThreads { get; set; } = true; + + /// + /// File containing a header with feature names. If specified, header defined in the data file (header+) is ignored. + /// + public string HeaderFile { get; set; } + + /// + /// Maximum number of rows to produce + /// + public long? MaxRows { get; set; } + + /// + /// Whether the input may include quoted values, which can contain separator characters, colons, and distinguish empty values from missing values. When true, consecutive separators denote a missing value and an empty value is denoted by "". When false, consecutive separators denote an empty value. + /// + public bool AllowQuoting { get; set; } = true; + + /// + /// Whether the input may include sparse representations + /// + public bool AllowSparse { get; set; } = true; + + /// + /// Number of source columns in the text data. Default is that sparse rows contain their size information. + /// + public int? InputSize { get; set; } + + /// + /// Source column separator. + /// + public char[] Separator { get; set; } = { '\t' }; + + /// + /// Column groups. Each group is specified as name:type:numeric-ranges, eg, col=Features:R4:1-17,26,35-40 + /// + public TextLoaderColumn[] Column { get; set; } + + /// + /// Remove trailing whitespace from lines + /// + public bool TrimWhitespace { get; set; } = false; + + /// + /// Data file has header with feature names. Header is read only if options 'hs' and 'hf' are not specified. + /// + public bool HasHeader { get; set; } = false; + + } + + public sealed partial class TextLoaderColumn + { + /// + /// Name of the column + /// + public string Name { get; set; } + + /// + /// Type of the items in the column + /// + public DataKind? Type { get; set; } + + /// + /// Source index range(s) of the column + /// + public TextLoaderRange[] Source { get; set; } + + /// + /// For a key column, this defines the range of values + /// + public KeyRange KeyRange { get; set; } + + } + + public sealed partial class TextLoaderRange + { + /// + /// First index in the range + /// + public int Min { get; set; } + + /// + /// Last index in the range + /// + public int? Max { get; set; } + + /// + /// This range extends to the end of the line, but should be a fixed number of items + /// + public bool AutoEnd { get; set; } = false; + + /// + /// This range extends to the end of the line, which can vary from line to line + /// + public bool VariableEnd { get; set; } = false; + + /// + /// This range includes only other indices not specified + /// + public bool AllOther { get; set; } = false; + + /// + /// Force scalar columns to be treated as vectors of length one + /// + public bool ForceVector { get; set; } = false; + + } + + public sealed partial class KeyRange + { + /// + /// First index in the range + /// + public ulong Min { get; set; } = 0; + + /// + /// Last index in the range + /// + public ulong? Max { get; set; } + + /// + /// Whether the key is contiguous + /// + public bool Contiguous { get; set; } = true; + + } + /// /// Import a dataset from a text file /// - public sealed partial class TextLoader + public sealed partial class TextLoader : Microsoft.ML.ILearningPipelineLoader { + [JsonIgnore] + private string _inputFilePath = null; + public TextLoader(string filePath) + { + _inputFilePath = filePath; + } + + public void SetInput(IHostEnvironment env, Experiment experiment) + { + IFileHandle inputFile = new SimpleFileHandle(env, _inputFilePath, false, false); + experiment.SetInput(InputFile, inputFile); + } + + public Var GetInputData() => null; + + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + Contracts.Assert(previousStep == null); + + return new TextLoaderPipelineStep(experiment.Add(this)); + } + + private class TextLoaderPipelineStep : ILearningPipelineDataStep + { + public TextLoaderPipelineStep (Output output) + { + Data = output.Data; + Model = null; + } + + public Var Data { get; } + public Var Model { get; } + } /// /// Location of the input file @@ -1393,9 +1602,9 @@ public sealed partial class TextLoader public Var InputFile { get; set; } = new Var(); /// - /// Custom schema to use for parsing + /// Arguments /// - public string CustomSchema { get; set; } + public Data.TextLoaderArguments Arguments { get; set; } = new Data.TextLoaderArguments(); public sealed class Output @@ -1623,7 +1832,7 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ICla namespace Models { - public sealed class CrossValidationBinaryMacroSubGraphInput + public sealed partial class CrossValidationBinaryMacroSubGraphInput { /// /// The data to be used for training @@ -1632,7 +1841,7 @@ public sealed class CrossValidationBinaryMacroSubGraphInput } - public sealed class CrossValidationBinaryMacroSubGraphOutput + public sealed partial class CrossValidationBinaryMacroSubGraphOutput { /// /// The model @@ -1888,7 +2097,7 @@ public enum MacroUtilsTrainerKinds } - public sealed class CrossValidationMacroSubGraphInput + public sealed partial class CrossValidationMacroSubGraphInput { /// /// The data to be used for training @@ -1897,7 +2106,7 @@ public sealed class CrossValidationMacroSubGraphInput } - public sealed class CrossValidationMacroSubGraphOutput + public sealed partial class CrossValidationMacroSubGraphOutput { /// /// The model @@ -2336,7 +2545,7 @@ public enum CachingOptions } - public sealed class OneVersusAllMacroSubGraphOutput + public sealed partial class OneVersusAllMacroSubGraphOutput { /// /// The predictor model for the subgraph exemplar. @@ -2994,7 +3203,7 @@ public sealed class Output namespace Models { - public sealed class TrainTestBinaryMacroSubGraphInput + public sealed partial class TrainTestBinaryMacroSubGraphInput { /// /// The data to be used for training @@ -3003,7 +3212,7 @@ public sealed class TrainTestBinaryMacroSubGraphInput } - public sealed class TrainTestBinaryMacroSubGraphOutput + public sealed partial class TrainTestBinaryMacroSubGraphOutput { /// /// The model @@ -3079,7 +3288,7 @@ public sealed class Output namespace Models { - public sealed class TrainTestMacroSubGraphInput + public sealed partial class TrainTestMacroSubGraphInput { /// /// The data to be used for training @@ -3088,7 +3297,7 @@ public sealed class TrainTestMacroSubGraphInput } - public sealed class TrainTestMacroSubGraphOutput + public sealed partial class TrainTestMacroSubGraphOutput { /// /// The model @@ -7549,7 +7758,7 @@ public BinaryPredictionScoreColumnsRenamerPipelineStep(Output output) namespace Transforms { - public sealed class NormalizeTransformBinColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NormalizeTransformBinColumn : OneToOneColumn, IOneToOneColumn { /// /// Max number of bins, power of 2 recommended @@ -7706,7 +7915,7 @@ public enum CategoricalTransformOutputKind : byte } - public sealed class CategoricalHashTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class CategoricalHashTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The number of bits to hash into. Must be between 1 and 30, inclusive. @@ -7881,7 +8090,7 @@ public enum TermTransformSortOrder : byte } - public sealed class CategoricalTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class CategoricalTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Output kind: Bag (multi-set vector), Ind (indicator vector), Key (index), or Binary encoded indicator vector @@ -8050,7 +8259,7 @@ public CategoricalOneHotVectorizerPipelineStep(Output output) namespace Transforms { - public sealed class CharTokenizeTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class CharTokenizeTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -8174,7 +8383,7 @@ public CharacterTokenizerPipelineStep(Output output) namespace Transforms { - public sealed class ConcatTransformColumn : ManyToOneColumn, IManyToOneColumn + public sealed partial class ConcatTransformColumn : ManyToOneColumn, IManyToOneColumn { /// /// Name of the new column @@ -8269,7 +8478,7 @@ public ColumnConcatenatorPipelineStep(Output output) namespace Transforms { - public sealed class CopyColumnsTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class CopyColumnsTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -8546,7 +8755,7 @@ public enum DataKind : byte } - public sealed class ConvertTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class ConvertTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The result type @@ -8755,7 +8964,7 @@ public CombinerByContiguousGroupIdPipelineStep(Output output) namespace Transforms { - public sealed class NormalizeTransformAffineColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NormalizeTransformAffineColumn : OneToOneColumn, IOneToOneColumn { /// /// Whether to map zero to zero, preserving sparsity @@ -9038,7 +9247,7 @@ public sealed class Output namespace Transforms { - public sealed class TermTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class TermTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Maximum number of terms to keep when auto-training @@ -9412,7 +9621,7 @@ public FeatureSelectorByMutualInformationPipelineStep(Output output) namespace Transforms { - public sealed class LpNormNormalizerTransformGcnColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class LpNormNormalizerTransformGcnColumn : OneToOneColumn, IOneToOneColumn { /// /// Normalize by standard deviation rather than L2 norm @@ -9561,7 +9770,7 @@ public GlobalContrastNormalizerPipelineStep(Output output) namespace Transforms { - public sealed class HashJoinTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class HashJoinTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Whether the values need to be combined for a single hash @@ -9725,7 +9934,7 @@ public HashConverterPipelineStep(Output output) namespace Transforms { - public sealed class KeyToValueTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class KeyToValueTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -9914,7 +10123,7 @@ public LabelColumnKeyBooleanConverterPipelineStep(Output output) namespace Transforms { - public sealed class LabelIndicatorTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class LabelIndicatorTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The positive example class for binary classification. @@ -10108,7 +10317,7 @@ public LabelToFloatConverterPipelineStep(Output output) namespace Transforms { - public sealed class NormalizeTransformLogNormalColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NormalizeTransformLogNormalColumn : OneToOneColumn, IOneToOneColumn { /// /// Max number of examples used to train the normalizer @@ -10250,7 +10459,7 @@ public enum LpNormNormalizerTransformNormalizerKind : byte } - public sealed class LpNormNormalizerTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class LpNormNormalizerTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The norm to use to normalize each sample @@ -10668,7 +10877,7 @@ public enum NAHandleTransformReplacementKind } - public sealed class NAHandleTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NAHandleTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// The replacement method to utilize @@ -10817,7 +11026,7 @@ public MissingValueHandlerPipelineStep(Output output) namespace Transforms { - public sealed class NAIndicatorTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NAIndicatorTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -10936,7 +11145,7 @@ public MissingValueIndicatorPipelineStep(Output output) namespace Transforms { - public sealed class NADropTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NADropTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Name of the new column @@ -11140,7 +11349,7 @@ public enum NAReplaceTransformReplacementKind } - public sealed class NAReplaceTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NAReplaceTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Replacement value for NAs (uses default value if not given) @@ -11323,7 +11532,7 @@ public enum NgramTransformWeightingCriteria } - public sealed class NgramTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class NgramTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Maximum ngram length @@ -11682,7 +11891,7 @@ public PredictedLabelColumnOriginalValueConverterPipelineStep(Output output) namespace Transforms { - public sealed class GenerateNumberTransformColumn + public sealed partial class GenerateNumberTransformColumn { /// /// Name of the new column @@ -12466,7 +12675,7 @@ public enum TextTransformTextNormKind } - public sealed class TextTransformColumn : ManyToOneColumn, IManyToOneColumn + public sealed partial class TextTransformColumn : ManyToOneColumn, IManyToOneColumn { /// /// Name of the new column @@ -12480,7 +12689,7 @@ public sealed class TextTransformColumn : ManyToOneColumn, } - public sealed class TermLoaderArguments + public sealed partial class TermLoaderArguments { /// /// List of terms @@ -12910,7 +13119,7 @@ public sealed class Output namespace Transforms { - public sealed class DelimitedTokenizeTransformColumn : OneToOneColumn, IOneToOneColumn + public sealed partial class DelimitedTokenizeTransformColumn : OneToOneColumn, IOneToOneColumn { /// /// Comma separated set of term separator(s). Commonly: 'space', 'comma', 'semicolon' or other single character. diff --git a/src/Microsoft.ML/Data/TextLoader.cs b/src/Microsoft.ML/Data/TextLoader.cs index 3c8550ef09..1829d6225e 100644 --- a/src/Microsoft.ML/Data/TextLoader.cs +++ b/src/Microsoft.ML/Data/TextLoader.cs @@ -21,7 +21,6 @@ public TextLoaderRange() /// /// Convenience constructor for the scalar case, when a given column /// in the schema spans only a single column in the dataset. - /// and are set to the single value . /// /// Column index in the dataset. public TextLoaderRange(int ordinal) diff --git a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs index 0067e1a133..7f5114b185 100644 --- a/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs +++ b/src/Microsoft.ML/Runtime/Internal/Tools/CSharpApiGenerator.cs @@ -750,6 +750,10 @@ private void GenerateLoaderAddInputMethod(IndentingTextWriter writer, string cla writer.WriteLine("}"); writer.WriteLine(""); + //GetInputData + writer.WriteLine("public Var GetInputData() => null;"); + writer.WriteLine(""); + //Apply. writer.WriteLine($"public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment)"); writer.WriteLine("{"); @@ -955,22 +959,26 @@ private void GenerateInput(IndentingTextWriter writer, writer.WriteLine(); GenerateOutput(writer, entryPointInfo, out HashSet outputVariableNames); - GenerateApplyFunction(writer, entryPointInfo, transformType, classBase, outputVariableNames); + GenerateApplyFunction(writer, entryPointInfo, transformType, outputVariableNames, entryPointInfo.InputKinds); writer.Outdent(); writer.WriteLine("}"); } private static void GenerateApplyFunction(IndentingTextWriter writer, ModuleCatalog.EntryPointInfo entryPointInfo, - Type type, string classBase, HashSet outputVariableNames) + Type type, HashSet outputVariableNames, Type[] inputKinds) { + if (inputKinds == null) + return; + bool isTransform = false; bool isCalibrator = false; - if (classBase.Contains("ITransformInput")) + + if (inputKinds.Any(t => typeof(ITransformInput).IsAssignableFrom(t))) isTransform = true; - else if (!classBase.Contains("ITrainerInput")) + else if (!inputKinds.Any(t => typeof(ITrainerInput).IsAssignableFrom(t))) return; - if (classBase.Contains("ICalibratorInput")) + if (inputKinds.Any(t => typeof(ICalibratorInput).IsAssignableFrom(t))) isCalibrator = true; if (isTransform) diff --git a/src/Microsoft.ML/TextLoader.cs b/src/Microsoft.ML/TextLoader.cs deleted file mode 100644 index 32cdc551ad..0000000000 --- a/src/Microsoft.ML/TextLoader.cs +++ /dev/null @@ -1,126 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; -using System; -using System.Linq; -using System.Reflection; -using System.Text; - -namespace Microsoft.ML -{ - public class TextLoader : ILearningPipelineLoader - { - private string _inputFilePath; - private string CustomSchema; - private Data.TextLoader ImportTextInput; - - /// - /// Construct a TextLoader object - /// - /// Data file path - /// Does the file contains header? - /// How the columns are seperated? - /// Options: separator="tab", separator="space", separator="comma" or separator=[single character]. - /// By default separator=null means "tab" - /// Whether the input may include quoted values, - /// which can contain separator characters, colons, - /// and distinguish empty values from missing values. When true, consecutive separators - /// denote a missing value and an empty value is denoted by \"\". - /// When false, consecutive separators denote an empty value. - /// Whether the input may include sparse representations e.g. - /// if one of the row contains "5 2:6 4:3" that's mean there are 5 columns all zero - /// except for 3rd and 5th columns which have values 6 and 3 - /// Remove trailing whitespace from lines - public TextLoader(string inputFilePath, bool useHeader = false, - string separator = null, bool allowQuotedStrings = true, - bool supportSparse = true, bool trimWhitespace = false) - { - _inputFilePath = inputFilePath; - SetCustomStringFromType(useHeader, separator, allowQuotedStrings, supportSparse, trimWhitespace); - } - - private IFileHandle GetTextLoaderFileHandle(IHostEnvironment env, string trainFilePath) => - new SimpleFileHandle(env, trainFilePath, false, false); - - private void SetCustomStringFromType(bool useHeader, string separator, - bool allowQuotedStrings, bool supportSparse, bool trimWhitespace) - { - StringBuilder schemaBuilder = new StringBuilder(CustomSchema); - foreach (var field in typeof(TInput).GetFields()) - { - var mappingAttr = field.GetCustomAttribute(); - if(mappingAttr == null) - throw Contracts.ExceptParam(field.Name, $"{field.Name} is missing ColumnAttribute"); - - schemaBuilder.AppendFormat("col={0}:{1}:{2} ", - mappingAttr.Name ?? field.Name, - TypeToName(field.FieldType.IsArray ? field.FieldType.GetElementType() : field.FieldType), - mappingAttr.Ordinal); - } - - if (useHeader) - schemaBuilder.Append(nameof(TextLoader.Arguments.HasHeader)).Append("+ "); - - if (separator != null) - schemaBuilder.Append(nameof(TextLoader.Arguments.Separator)).Append("=").Append(separator).Append(" "); - - if (!allowQuotedStrings) - schemaBuilder.Append(nameof(TextLoader.Arguments.AllowQuoting)).Append("- "); - - if (!supportSparse) - schemaBuilder.Append(nameof(TextLoader.Arguments.AllowSparse)).Append("- "); - - if (trimWhitespace) - schemaBuilder.Append(nameof(TextLoader.Arguments.TrimWhitespace)).Append("+ "); - - schemaBuilder.Length--; - CustomSchema = schemaBuilder.ToString(); - } - - private string TypeToName(Type type) - { - if (type == typeof(string)) - return "TX"; - else if (type == typeof(float) || type == typeof(double)) - return "R4"; - else if (type == typeof(bool)) - return "BL"; - else - throw new System.NotSupportedException("Type ${type.FullName} is not implemented or supported."); //Add more types. - } - - public Var GetInputData() => null; - - public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) - { - Contracts.Assert(previousStep == null); - - ImportTextInput = new Data.TextLoader(); - ImportTextInput.CustomSchema = CustomSchema; - var importOutput = experiment.Add(ImportTextInput); - return new TextLoaderPipelineStep(importOutput.Data); - } - - public void SetInput(IHostEnvironment env, Experiment experiment) - { - IFileHandle inputFile = GetTextLoaderFileHandle(env, _inputFilePath); - experiment.SetInput(ImportTextInput.InputFile, inputFile); - } - - private class TextLoaderPipelineStep : ILearningPipelineDataStep - { - public TextLoaderPipelineStep(Var data) - { - Data = data; - } - - public Var Data { get; } - public Var Model => null; - } - } -} diff --git a/test/Microsoft.ML.TestFramework/ModelHelper.cs b/test/Microsoft.ML.TestFramework/ModelHelper.cs index 1b0ab4eb8e..edf4408bcb 100644 --- a/test/Microsoft.ML.TestFramework/ModelHelper.cs +++ b/test/Microsoft.ML.TestFramework/ModelHelper.cs @@ -58,7 +58,6 @@ public static IDataView GetKcHouseDataView(string dataPath) private static ITransformModel CreateKcHousePricePredictorModel(string dataPath) { Experiment experiment = s_environment.CreateExperiment(); - var importData = new Data.TextLoader(dataPath) { Arguments = new TextLoaderArguments From 6f273a3b1a25272a03672b509c6444ca143012f4 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 22 May 2018 14:16:18 -0700 Subject: [PATCH 04/18] Add entry point to combine arrays of metrics data views into one data view --- .../Common/EntryPoints/core_ep-list.tsv | 1 + .../Common/EntryPoints/core_manifest.json | 130 ++- .../Commands/CrossValidationCommand.cs | 221 +--- .../Commands/EvaluateCommand.cs | 6 +- src/Microsoft.ML.Data/Commands/TestCommand.cs | 6 +- .../Commands/TrainTestCommand.cs | 6 +- .../Evaluators/AnomalyDetectionEvaluator.cs | 11 +- .../Evaluators/BinaryClassifierEvaluator.cs | 25 +- .../Evaluators/EvaluatorUtils.cs | 979 ++++++++++++------ .../Evaluators/MamlEvaluator.cs | 79 +- .../MulticlassClassifierEvaluator.cs | 29 +- .../Evaluators/QuantileRegressionEvaluator.cs | 14 +- .../Evaluators/RankerEvaluator.cs | 12 +- src/Microsoft.ML/CSharpApi.cs | 154 ++- .../EntryPoints/CrossValidationMacro.cs | 137 ++- .../UnitTests/TestCSharpApi.cs | 43 +- .../UnitTests/TestEntryPoints.cs | 12 +- 17 files changed, 1190 insertions(+), 675 deletions(-) diff --git a/ZBaselines/Common/EntryPoints/core_ep-list.tsv b/ZBaselines/Common/EntryPoints/core_ep-list.tsv index 47007edaa6..7fff1c646a 100644 --- a/ZBaselines/Common/EntryPoints/core_ep-list.tsv +++ b/ZBaselines/Common/EntryPoints/core_ep-list.tsv @@ -7,6 +7,7 @@ Models.BinaryClassificationEvaluator Evaluates a binary classification scored da Models.BinaryCrossValidator Cross validation for binary classification Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro CrossValidateBinary Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Output] Models.ClassificationEvaluator Evaluates a multi class classification scored dataset. Microsoft.ML.Runtime.Data.Evaluate MultiClass Microsoft.ML.Runtime.Data.MultiClassMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+ClassificationEvaluateOutput Models.ClusterEvaluator Evaluates a clustering scored dataset. Microsoft.ML.Runtime.Data.Evaluate Clustering Microsoft.ML.Runtime.Data.ClusteringMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput +Models.CrossValidationResultsCombiner Combine the metric data views returned from cross validation. Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro CombineMetrics Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro+CombineMetricsInput Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro+CombinedOutput Models.CrossValidator Cross validation for general learning Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro CrossValidate Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationMacro+Output] Models.CrossValidatorDatasetSplitter Split the dataset into the specified number of cross-validation folds (train and test sets) Microsoft.ML.Runtime.EntryPoints.CVSplit Split Microsoft.ML.Runtime.EntryPoints.CVSplit+Input Microsoft.ML.Runtime.EntryPoints.CVSplit+Output Models.DatasetTransformer Applies a TransformModel to a dataset. Microsoft.ML.Runtime.EntryPoints.ModelOperations Apply Microsoft.ML.Runtime.EntryPoints.ModelOperations+ApplyTransformModelInput Microsoft.ML.Runtime.EntryPoints.ModelOperations+ApplyTransformModelOutput diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json index a6309fe36a..224b32c108 100644 --- a/ZBaselines/Common/EntryPoints/core_manifest.json +++ b/ZBaselines/Common/EntryPoints/core_manifest.json @@ -888,6 +888,116 @@ "IEvaluatorOutput" ] }, + { + "Name": "Models.CrossValidationResultsCombiner", + "Desc": "Combine the metric data views returned from cross validation.", + "FriendlyName": null, + "ShortName": null, + "Inputs": [ + { + "Name": "Kind", + "Type": { + "Kind": "Enum", + "Values": [ + "SignatureBinaryClassifierTrainer", + "SignatureMultiClassClassifierTrainer", + "SignatureRankerTrainer", + "SignatureRegressorTrainer", + "SignatureMultiOutputRegressorTrainer", + "SignatureAnomalyDetectorTrainer", + "SignatureClusteringTrainer" + ] + }, + "Desc": "Specifies the trainer kind, which determines the evaluator to be used.", + "Required": true, + "SortOrder": 0.0, + "IsNullable": false, + "Default": "SignatureBinaryClassifierTrainer" + }, + { + "Name": "OverallMetrics", + "Type": { + "Kind": "Array", + "ItemType": "DataView" + }, + "Desc": "Overall metrics datasets", + "Required": false, + "SortOrder": 1.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "PerInstanceMetrics", + "Type": { + "Kind": "Array", + "ItemType": "DataView" + }, + "Desc": "Per instance metrics datasets", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "ConfusionMatrix", + "Type": { + "Kind": "Array", + "ItemType": "DataView" + }, + "Desc": "Confusion matrix datasets", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "Warnings", + "Type": { + "Kind": "Array", + "ItemType": "DataView" + }, + "Desc": "Warning datasets", + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "LabelColumn", + "Type": "String", + "Desc": "The label column name", + "Aliases": [ + "Label" + ], + "Required": false, + "SortOrder": 4.0, + "IsNullable": false, + "Default": "Label" + } + ], + "Outputs": [ + { + "Name": "Warnings", + "Type": "DataView", + "Desc": "Warning dataset" + }, + { + "Name": "OverallMetrics", + "Type": "DataView", + "Desc": "Overall metrics dataset" + }, + { + "Name": "PerInstanceMetrics", + "Type": "DataView", + "Desc": "Per instance metrics dataset" + }, + { + "Name": "ConfusionMatrix", + "Type": "DataView", + "Desc": "Confusion matrix dataset" + } + ] + }, { "Name": "Models.CrossValidator", "Desc": "Cross validation for general learning", @@ -1018,34 +1128,22 @@ }, { "Name": "Warnings", - "Type": { - "Kind": "Array", - "ItemType": "DataView" - }, + "Type": "DataView", "Desc": "Warning dataset" }, { "Name": "OverallMetrics", - "Type": { - "Kind": "Array", - "ItemType": "DataView" - }, + "Type": "DataView", "Desc": "Overall metrics dataset" }, { "Name": "PerInstanceMetrics", - "Type": { - "Kind": "Array", - "ItemType": "DataView" - }, + "Type": "DataView", "Desc": "Per instance metrics dataset" }, { "Name": "ConfusionMatrix", - "Type": { - "Kind": "Array", - "ItemType": "DataView" - }, + "Type": "DataView", "Desc": "Confusion matrix dataset" } ] diff --git a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs index c430abf85a..35e764762e 100644 --- a/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs +++ b/src/Microsoft.ML.Data/Commands/CrossValidationCommand.cs @@ -211,38 +211,34 @@ private void RunCore(IChannel ch, string cmd) } // Print the overall results. - eval.PrintOverallResults(ch, Args.SummaryFilename, tasks.Select(t => t.Result.Metrics).ToArray()); + if (!TryGetOverallMetrics(tasks.Select(t => t.Result.Metrics).ToArray(), out var overallList)) + throw ch.Except("No overall metrics found"); + + var overall = eval.GetOverallResults(overallList.ToArray()); + MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, Args.NumFolds); + eval.PrintAdditionalMetrics(ch, tasks.Select(t => t.Result.Metrics).ToArray()); Dictionary[] metricValues = tasks.Select(t => t.Result.Metrics).ToArray(); SendTelemetryMetric(metricValues); // Save the per-instance results. if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) { - Func, int, IDataView> getPerInstance = - (task, i) => - { - if (!Args.OutputExampleFoldIndex) - return task.Result.PerInstanceResults; - - // If the fold index is requested, add a column containing it. We use the first column in the data view - // as an input column to the LambdaColumnMapper, because it must have an input. - var inputColName = task.Result.PerInstanceResults.Schema.GetColumnName(0); - var inputColType = task.Result.PerInstanceResults.Schema.GetColumnType(0); - return Utils.MarshalInvoke(EvaluateUtils.AddKeyColumn, inputColType.RawType, Host, - task.Result.PerInstanceResults, inputColName, MetricKinds.ColumnNames.FoldIndex, - inputColType, Args.NumFolds, i + 1, "FoldIndex", default(ValueGetter>)); - }; - - var foldDataViews = tasks.Select(getPerInstance).ToArray(); + var perInstance = EvaluateUtils.CombinePerInstanceDataViews(Host, eval, Args.CollateMetrics, + Args.OutputExampleFoldIndex, tasks.Select(t => t.Result.PerInstanceResults).ToArray(), out var variableSizeVectorColumnNames); + if (variableSizeVectorColumnNames.Length > 0) + { + ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.", + string.Join(", ", variableSizeVectorColumnNames)); + } if (Args.CollateMetrics) { - var perInst = AppendPerInstanceDataViews(foldDataViews, ch); - MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInst); + ch.Assert(perInstance.Length == 1); + MetricWriter.SavePerInstance(Host, ch, Args.OutputDataFile, perInstance[0]); } else { int i = 0; - foreach (var idv in foldDataViews) + foreach (var idv in perInstance) { MetricWriter.SavePerInstance(Host, ch, ConstructPerFoldName(Args.OutputDataFile, i), idv); i++; @@ -251,166 +247,6 @@ private void RunCore(IChannel ch, string cmd) } } - private IDataView AppendPerInstanceDataViews(IEnumerable foldDataViews, IChannel ch) - { - // Make sure there are no variable size vector columns. - // This is a dictionary from the column name to its vector size. - var vectorSizes = new Dictionary(); - var firstDvSlotNames = new Dictionary>(); - var firstDvKeyColumns = new List(); - var firstDvVectorKeyColumns = new List(); - var variableSizeVectorColumnNames = new List(); - var list = new List(); - int dvNumber = 0; - foreach (var dv in foldDataViews) - { - var hidden = new List(); - for (int i = 0; i < dv.Schema.ColumnCount; i++) - { - if (dv.Schema.IsHidden(i)) - { - hidden.Add(i); - continue; - } - - var type = dv.Schema.GetColumnType(i); - var name = dv.Schema.GetColumnName(i); - if (type.IsVector) - { - if (dvNumber == 0) - { - if (dv.Schema.HasKeyNames(i, type.ItemType.KeyCount)) - firstDvVectorKeyColumns.Add(name); - // Store the slot names of the 1st idv and use them as baseline. - if (dv.Schema.HasSlotNames(i, type.VectorSize)) - { - VBuffer slotNames = default(VBuffer); - dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref slotNames); - firstDvSlotNames.Add(name, slotNames); - } - } - - int cachedSize; - if (vectorSizes.TryGetValue(name, out cachedSize)) - { - VBuffer slotNames; - // In the event that no slot names were recorded here, then slotNames will be - // the default, length 0 vector. - firstDvSlotNames.TryGetValue(name, out slotNames); - if (!VerifyVectorColumnsMatch(cachedSize, i, dv, type, ref slotNames)) - variableSizeVectorColumnNames.Add(name); - } - else - vectorSizes.Add(name, type.VectorSize); - } - else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount)) - { - // The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform. - firstDvKeyColumns.Add(name); - } - } - var idv = dv; - if (hidden.Count > 0) - { - var args = new ChooseColumnsByIndexTransform.Arguments(); - args.Drop = true; - args.Index = hidden.ToArray(); - idv = new ChooseColumnsByIndexTransform(Host, args, idv); - } - list.Add(idv); - dvNumber++; - } - - if (variableSizeVectorColumnNames.Count == 0 && firstDvKeyColumns.Count == 0) - return AppendRowsDataView.Create(Host, null, list.ToArray()); - - var views = list.ToArray(); - foreach (var keyCol in firstDvKeyColumns) - EvaluateUtils.ReconcileKeyValues(Host, views, keyCol); - foreach (var vectorKeyCol in firstDvVectorKeyColumns) - EvaluateUtils.ReconcileVectorKeyValues(Host, views, vectorKeyCol); - - Func keyToValue = - (idv, i) => - { - foreach (var keyCol in firstDvKeyColumns.Concat(firstDvVectorKeyColumns)) - { - idv = new KeyToValueTransform(Host, new KeyToValueTransform.Arguments() { Column = new[] { new KeyToValueTransform.Column() { Name = keyCol }, } }, idv); - var hidden = FindHiddenColumns(idv.Schema, keyCol); - idv = new ChooseColumnsByIndexTransform(Host, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv); - } - return idv; - }; - - Func selectDropNonVarLenthCol = - (idv) => - { - foreach (var variableSizeVectorColumnName in variableSizeVectorColumnNames) - { - int index; - idv.Schema.TryGetColumnIndex(variableSizeVectorColumnName, out index); - var type = idv.Schema.GetColumnType(index); - - idv = Utils.MarshalInvoke(AddVarLengthColumn, type.ItemType.RawType, Host, idv, - variableSizeVectorColumnName, type); - - // Drop the old column that does not have variable length. - idv = new DropColumnsTransform(Host, new DropColumnsTransform.Arguments() { Column = new[] { variableSizeVectorColumnName } }, idv); - } - return idv; - }; - - if (variableSizeVectorColumnNames.Count > 0) - ch.Warning("Detected columns of variable length: {0}. Consider setting collateMetrics- for meaningful per-Folds results.", string.Join(", ", variableSizeVectorColumnNames)); - return AppendRowsDataView.Create(Host, null, views.Select(keyToValue).Select(selectDropNonVarLenthCol).ToArray()); - } - - private static IEnumerable FindHiddenColumns(ISchema schema, string colName) - { - for (int i = 0; i < schema.ColumnCount; i++) - { - if (schema.IsHidden(i) && schema.GetColumnName(i) == colName) - yield return i; - } - } - - private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView dv, - ColumnType type, ref VBuffer firstDvSlotNames) - { - if (cachedSize != type.VectorSize) - return false; - - // If we detect mismatch it a sign that slots reshuffling has happened. - if (dv.Schema.HasSlotNames(col, type.VectorSize)) - { - // Verify that slots match with slots from 1st idv. - VBuffer currSlotNames = default(VBuffer); - dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, col, ref currSlotNames); - - if (currSlotNames.Length != firstDvSlotNames.Length) - return false; - else - { - var result = true; - VBufferUtils.ForEachEitherDefined(ref currSlotNames, ref firstDvSlotNames, - (slot, val1, val2) => result = result && DvText.Identical(val1, val2)); - return result; - } - } - else - { - // If we don't have slot names, then the first dataview should not have had slot names either. - return firstDvSlotNames.Length == 0; - } - } - - private static IDataView AddVarLengthColumn(IHostEnvironment env, IDataView idv, string variableSizeVectorColumnName, ColumnType typeSrc) - { - return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName, - variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType(typeSrc.ItemType.AsPrimitive), - (ref VBuffer src, ref VBuffer dst) => src.CopyTo(ref dst)); - } - /// /// Callback from the CV method to apply the transforms from the train data to the test and/or validation data. /// @@ -504,16 +340,32 @@ private string GetSplitColumn(IChannel ch, IDataView input, ref IDataView output return stratificationColumn; } + private bool TryGetOverallMetrics(Dictionary[] metrics, out List overallList) + { + Host.AssertNonEmpty(metrics); + + overallList = new List(); + for (int i = 0; i < metrics.Length; i++) + { + var dict = metrics[i]; + IDataView idv; + if (!dict.TryGetValue(MetricKinds.OverallMetrics, out idv)) + return false; + overallList.Add(idv); + } + return true; + } + private sealed class FoldHelper { public struct FoldResult { public readonly Dictionary Metrics; public readonly ISchema ScoreSchema; - public readonly IDataView PerInstanceResults; + public readonly RoleMappedData PerInstanceResults; public readonly RoleMappedSchema TrainSchema; - public FoldResult(Dictionary metrics, ISchema scoreSchema, IDataView perInstance, RoleMappedSchema trainSchema) + public FoldResult(Dictionary metrics, ISchema scoreSchema, RoleMappedData perInstance, RoleMappedSchema trainSchema) { Metrics = metrics; ScoreSchema = scoreSchema; @@ -735,12 +587,11 @@ private FoldResult RunFold(int fold) var dataEval = RoleMappedData.CreateOpt(scorePipe, testData.Schema.GetColumnRoleNames()); var dict = eval.Evaluate(dataEval); - IDataView perInstance = null; + RoleMappedData perInstance = null; if (_savePerInstance) { var perInst = eval.GetPerInstanceMetrics(dataEval); - var perInstData = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames()); - perInstance = eval.GetPerInstanceDataViewToSave(perInstData); + perInstance = RoleMappedData.CreateOpt(perInst, dataEval.Schema.GetColumnRoleNames()); } ch.Done(); return new FoldResult(dict, dataEval.Schema.Schema, perInstance, trainData.Schema); diff --git a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs index 533dfdcc41..d0e066d789 100644 --- a/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs +++ b/src/Microsoft.ML.Data/Commands/EvaluateCommand.cs @@ -240,7 +240,11 @@ private void RunCore(IChannel ch) var metrics = evaluator.Evaluate(data); MetricWriter.PrintWarnings(ch, metrics); evaluator.PrintFoldResults(ch, metrics); - evaluator.PrintOverallResults(ch, Args.SummaryFilename, metrics); + if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall)) + throw ch.Except("No overall metrics found"); + overall = evaluator.GetOverallResults(overall); + MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1); + evaluator.PrintAdditionalMetrics(ch, metrics); if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) { var perInst = evaluator.GetPerInstanceMetrics(data); diff --git a/src/Microsoft.ML.Data/Commands/TestCommand.cs b/src/Microsoft.ML.Data/Commands/TestCommand.cs index 73ca98f66e..79e7bd5458 100644 --- a/src/Microsoft.ML.Data/Commands/TestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TestCommand.cs @@ -118,7 +118,11 @@ private void RunCore(IChannel ch) var metrics = evaluator.Evaluate(data); MetricWriter.PrintWarnings(ch, metrics); evaluator.PrintFoldResults(ch, metrics); - evaluator.PrintOverallResults(ch, Args.SummaryFilename, metrics); + if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall)) + throw ch.Except("No overall metrics found"); + overall = evaluator.GetOverallResults(overall); + MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1); + evaluator.PrintAdditionalMetrics(ch, metrics); Dictionary[] metricValues = { metrics }; SendTelemetryMetric(metricValues); if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) diff --git a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs index dda298e6b2..f6ffa772f9 100644 --- a/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs +++ b/src/Microsoft.ML.Data/Commands/TrainTestCommand.cs @@ -194,7 +194,11 @@ private void RunCore(IChannel ch, string cmd) var metrics = evaluator.Evaluate(dataEval); MetricWriter.PrintWarnings(ch, metrics); evaluator.PrintFoldResults(ch, metrics); - evaluator.PrintOverallResults(ch, Args.SummaryFilename, metrics); + if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out var overall)) + throw ch.Except("No overall metrics found"); + overall = evaluator.GetOverallResults(overall); + MetricWriter.PrintOverallMetrics(Host, ch, Args.SummaryFilename, overall, 1); + evaluator.PrintAdditionalMetrics(ch, metrics); Dictionary[] metricValues = { metrics }; SendTelemetryMetric(metricValues); if (!string.IsNullOrWhiteSpace(Args.OutputDataFile)) diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index adf4ee4fba..39a5f31c38 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -746,14 +746,8 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + protected override IDataView GetOverallResultsCore(IDataView overall) { - ch.AssertNonEmpty(metrics); - - IDataView overall; - if (!TryGetOverallMetrics(metrics, out overall)) - throw ch.Except("No overall metrics found"); - var args = new DropColumnsTransform.Arguments(); args.Column = new[] { @@ -762,8 +756,7 @@ protected override void PrintOverallResultsCore(IChannel ch, string filename, Di AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP, AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos }; - overall = new DropColumnsTransform(Host, args, overall); - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return new DropColumnsTransform(Host, args, overall); } protected override IEnumerable GetPerInstanceColumnsToSave(RoleMappedSchema schema) diff --git a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs index 673d497073..90078da9ee 100644 --- a/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs @@ -1171,18 +1171,16 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + protected override IDataView GetOverallResultsCore(IDataView overall) { - ch.AssertNonEmpty(metrics); - - IDataView overall; - if (!TryGetOverallMetrics(metrics, out overall)) - throw ch.Except("No overall metrics found"); - var args = new DropColumnsTransform.Arguments(); args.Column = new[] { BinaryClassifierEvaluator.Entropy }; - overall = new DropColumnsTransform(Host, args, overall); - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return new DropColumnsTransform(Host, args, overall); + } + + protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) + { + ch.AssertNonEmpty(metrics); if (!string.IsNullOrEmpty(_prFileName)) { @@ -1228,14 +1226,7 @@ private bool TryGetPrMetrics(Dictionary[] metrics, out IDataV if (!dict.TryGetValue(BinaryClassifierEvaluator.PrCurve, out idv)) return false; if (metrics.Length != 1) - { - // We use the first column in the data view as an input column to the LambdaColumnMapper, because it must have an input. - var inputColName = idv.Schema.GetColumnName(0); - var inputColType = idv.Schema.GetColumnType(0); - idv = Utils.MarshalInvoke(EvaluateUtils.AddKeyColumn, inputColType.RawType, Host, idv, - inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, metrics.Length, i + 1, "FoldIndex", - default(ValueGetter>)); - } + idv = EvaluateUtils.AddFoldIndex(Host, idv, i, metrics.Length); else pr = idv; prList.Add(idv); diff --git a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs index 9cc25761f3..dea601e5bc 100644 --- a/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs +++ b/src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs @@ -10,7 +10,6 @@ using System.Text; using System.Threading; using Microsoft.ML.Runtime.CommandLine; -using Microsoft.ML.Runtime.Data.Conversion; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; @@ -200,7 +199,7 @@ public static ColumnInfo GetOptAuxScoreColumnInfo(IExceptionContext ectx, ISchem return null; } - public static bool IsScoreColumnKind(IExceptionContext ectx, ISchema schema, int col, string kind) + private static bool IsScoreColumnKind(IExceptionContext ectx, ISchema schema, int col, string kind) { Contracts.CheckValueOrNull(ectx); ectx.CheckValue(schema, nameof(schema)); @@ -359,7 +358,7 @@ public static IEnumerable> GetMetrics(IDataView met } } - public static IDataView AddTextColumn(IHostEnvironment env, IDataView input, string inputColName, string outputColName, + private static IDataView AddTextColumn(IHostEnvironment env, IDataView input, string inputColName, string outputColName, ColumnType typeSrc, string value, string registrationName) { Contracts.Check(typeSrc.RawType == typeof(TSrc)); @@ -367,7 +366,29 @@ public static IDataView AddTextColumn(IHostEnvironment env, IDataView inpu (ref TSrc src, ref DvText dst) => dst = new DvText(value)); } - public static IDataView AddKeyColumn(IHostEnvironment env, IDataView input, string inputColName, string outputColName, + /// + /// Add a text column containing a fold index to a data view. + /// + /// The host environment. + /// The data view to which we add the column + /// The current fold this data view belongs to. + /// The input data view with an additional text column containing the current fold index. + public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int curFold) + { + // We use the first column in the data view as an input column to the LambdaColumnMapper, + // because it must have an input. + int inputCol = 0; + while (inputCol < input.Schema.ColumnCount && input.Schema.IsHidden(inputCol)) + inputCol++; + env.Assert(inputCol < input.Schema.ColumnCount); + + var inputColName = input.Schema.GetColumnName(0); + var inputColType = input.Schema.GetColumnType(0); + return Utils.MarshalInvoke(AddTextColumn, inputColType.RawType, env, + input, inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, string.Format("Fold {0}", curFold), "FoldName"); + } + + private static IDataView AddKeyColumn(IHostEnvironment env, IDataView input, string inputColName, string outputColName, ColumnType typeSrc, int keyCount, int value, string registrationName, ValueGetter> keyValueGetter) { Contracts.Check(typeSrc.RawType == typeof(TSrc)); @@ -381,6 +402,30 @@ public static IDataView AddKeyColumn(IHostEnvironment env, IDataView input }, keyValueGetter); } + /// + /// Add a key type column containing a fold index to a data view. + /// + /// The host environment. + /// The data view to which we add the column + /// The current fold this data view belongs to. + /// The total number of folds. + /// The input data view with an additional key type column containing the current fold index. + public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int curFold, int numFolds) + { + // We use the first column in the data view as an input column to the LambdaColumnMapper, + // because it must have an input. + int inputCol = 0; + while (inputCol < input.Schema.ColumnCount && input.Schema.IsHidden(inputCol)) + inputCol++; + env.Assert(inputCol < input.Schema.ColumnCount); + + var inputColName = input.Schema.GetColumnName(inputCol); + var inputColType = input.Schema.GetColumnType(inputCol); + return Utils.MarshalInvoke(AddKeyColumn, inputColType.RawType, env, + input, inputColName, MetricKinds.ColumnNames.FoldIndex, + inputColType, numFolds, curFold + 1, "FoldIndex", default(ValueGetter>)); + } + /// /// This method takes an array of data views and a specified input vector column, and adds a new output column to each of the data views. /// First, we find the union set of the slot names in the different data views. Next we define a new vector column for each @@ -639,6 +684,594 @@ public static void ReconcileVectorKeyValues(IHostEnvironment env, IDataView[] vi } } + /// + /// This method gets the per-instance metrics from multiple scored data views and either returns them as an + /// array or combines them into a single data view, based on user specifications. + /// + /// A host environment. + /// The evaluator to use for getting the per-instance metrics. + /// If true, data views are combined into a single data view. Otherwise, data views + /// are returned as an array. + /// If true, a column containing the fold index is added to the returned data views. + /// The array of scored data views to evaluate. These are passed as + /// so that the evaluator can know the role mappings it needs. + /// A list of column names that are not included in the combined data view + /// since their types do not match. + /// + public static IDataView[] CombinePerInstanceDataViews(IHostEnvironment env, IMamlEvaluator eval, bool collate, bool outputFoldIndex, RoleMappedData[] perInstance, out string[] variableSizeVectorColumnNames) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckNonEmpty(perInstance, nameof(perInstance)); + + Func getPerInstance = + (rmd, i) => + { + var perInst = eval.GetPerInstanceDataViewToSave(rmd); + + if (!outputFoldIndex) + return perInst; + + // If the fold index is requested, add a column containing it. We use the first column in the data view + // as an input column to the LambdaColumnMapper, because it must have an input. + return AddFoldIndex(env, perInst, i, perInstance.Length); + }; + + var foldDataViews = perInstance.Select(getPerInstance).ToArray(); + if (collate) + { + var combined = AppendPerInstanceDataViews(env, foldDataViews, out variableSizeVectorColumnNames); + return new[] { combined }; + } + else + { + variableSizeVectorColumnNames = new string[0]; + return foldDataViews.ToArray(); + } + } + + /// + /// Combine an array of metric data views into one data view. + /// + public static IDataView CombineOverallMetrics(IHostEnvironment env, IDataView[] metrics) + { + env.AssertNonEmpty(metrics); + + if (metrics.Length == 1) + return metrics[0]; + + var overallList = new List(); + for (int i = 0; i < metrics.Length; i++) + { + // Add a fold-name column. We add it as a text column, since it is only used for saving the result summary file. + var idv = AddFoldIndex(env, metrics[i], i); + overallList.Add(idv); + } + return AppendRowsDataView.Create(env, overallList[0].Schema, overallList.ToArray()); + } + + private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnumerable foldDataViews, out string[] variableSizeVectorColumnNames) + { + // Make sure there are no variable size vector columns. + // This is a dictionary from the column name to its vector size. + var vectorSizes = new Dictionary(); + var firstDvSlotNames = new Dictionary>(); + var firstDvKeyColumns = new List(); + var firstDvVectorKeyColumns = new List(); + var variableSizeVectorColumnNamesList = new List(); + var list = new List(); + int dvNumber = 0; + foreach (var dv in foldDataViews) + { + var hidden = new List(); + for (int i = 0; i < dv.Schema.ColumnCount; i++) + { + if (dv.Schema.IsHidden(i)) + { + hidden.Add(i); + continue; + } + + var type = dv.Schema.GetColumnType(i); + var name = dv.Schema.GetColumnName(i); + if (type.IsVector) + { + if (dvNumber == 0) + { + if (dv.Schema.HasKeyNames(i, type.ItemType.KeyCount)) + firstDvVectorKeyColumns.Add(name); + // Store the slot names of the 1st idv and use them as baseline. + if (dv.Schema.HasSlotNames(i, type.VectorSize)) + { + VBuffer slotNames = default(VBuffer); + dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref slotNames); + firstDvSlotNames.Add(name, slotNames); + } + } + + int cachedSize; + if (vectorSizes.TryGetValue(name, out cachedSize)) + { + VBuffer slotNames; + // In the event that no slot names were recorded here, then slotNames will be + // the default, length 0 vector. + firstDvSlotNames.TryGetValue(name, out slotNames); + if (!VerifyVectorColumnsMatch(cachedSize, i, dv, type, ref slotNames)) + variableSizeVectorColumnNamesList.Add(name); + } + else + vectorSizes.Add(name, type.VectorSize); + } + else if (dvNumber == 0 && dv.Schema.HasKeyNames(i, type.KeyCount)) + { + // The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform. + firstDvKeyColumns.Add(name); + } + } + var idv = dv; + if (hidden.Count > 0) + { + var args = new ChooseColumnsByIndexTransform.Arguments(); + args.Drop = true; + args.Index = hidden.ToArray(); + idv = new ChooseColumnsByIndexTransform(env, args, idv); + } + list.Add(idv); + dvNumber++; + } + + variableSizeVectorColumnNames = variableSizeVectorColumnNamesList.ToArray(); + if (variableSizeVectorColumnNamesList.Count == 0 && firstDvKeyColumns.Count == 0) + return AppendRowsDataView.Create(env, null, list.ToArray()); + + var views = list.ToArray(); + foreach (var keyCol in firstDvKeyColumns) + ReconcileKeyValues(env, views, keyCol); + foreach (var vectorKeyCol in firstDvVectorKeyColumns) + ReconcileVectorKeyValues(env, views, vectorKeyCol); + + Func keyToValue = + (idv, i) => + { + foreach (var keyCol in firstDvKeyColumns.Concat(firstDvVectorKeyColumns)) + { + idv = new KeyToValueTransform(env, new KeyToValueTransform.Arguments() { Column = new[] { new KeyToValueTransform.Column() { Name = keyCol }, } }, idv); + var hidden = FindHiddenColumns(idv.Schema, keyCol); + idv = new ChooseColumnsByIndexTransform(env, new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = hidden.ToArray() }, idv); + } + return idv; + }; + + Func selectDropNonVarLenthCol = + (idv) => + { + foreach (var variableSizeVectorColumnName in variableSizeVectorColumnNamesList) + { + int index; + idv.Schema.TryGetColumnIndex(variableSizeVectorColumnName, out index); + var type = idv.Schema.GetColumnType(index); + + idv = Utils.MarshalInvoke(AddVarLengthColumn, type.ItemType.RawType, env, idv, + variableSizeVectorColumnName, type); + + // Drop the old column that does not have variable length. + idv = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = new[] { variableSizeVectorColumnName } }, idv); + } + return idv; + }; + + return AppendRowsDataView.Create(env, null, views.Select(keyToValue).Select(selectDropNonVarLenthCol).ToArray()); + } + + private static IEnumerable FindHiddenColumns(ISchema schema, string colName) + { + for (int i = 0; i < schema.ColumnCount; i++) + { + if (schema.IsHidden(i) && schema.GetColumnName(i) == colName) + yield return i; + } + } + + private static bool VerifyVectorColumnsMatch(int cachedSize, int col, IDataView dv, + ColumnType type, ref VBuffer firstDvSlotNames) + { + if (cachedSize != type.VectorSize) + return false; + + // If we detect mismatch it a sign that slots reshuffling has happened. + if (dv.Schema.HasSlotNames(col, type.VectorSize)) + { + // Verify that slots match with slots from 1st idv. + VBuffer currSlotNames = default(VBuffer); + dv.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, col, ref currSlotNames); + + if (currSlotNames.Length != firstDvSlotNames.Length) + return false; + else + { + var result = true; + VBufferUtils.ForEachEitherDefined(ref currSlotNames, ref firstDvSlotNames, + (slot, val1, val2) => result = result && DvText.Identical(val1, val2)); + return result; + } + } + else + { + // If we don't have slot names, then the first dataview should not have had slot names either. + return firstDvSlotNames.Length == 0; + } + } + + private static IDataView AddVarLengthColumn(IHostEnvironment env, IDataView idv, string variableSizeVectorColumnName, ColumnType typeSrc) + { + return LambdaColumnMapper.Create(env, "ChangeToVarLength", idv, variableSizeVectorColumnName, + variableSizeVectorColumnName + "_VarLength", typeSrc, new VectorType(typeSrc.ItemType.AsPrimitive), + (ref VBuffer src, ref VBuffer dst) => src.CopyTo(ref dst)); + } + + private static List GetMetricNames(IChannel ch, ISchema schema, IRow row, Func ignoreCol, + ValueGetter[] getters, ValueGetter>[] vBufferGetters) + { + ch.AssertValue(schema); + ch.AssertValue(row); + ch.Assert(Utils.Size(getters) == schema.ColumnCount); + ch.Assert(Utils.Size(vBufferGetters) == schema.ColumnCount); + + // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns + // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't. + VBuffer names = default(VBuffer); + int metricCount = 0; + var metricNames = new List(); + for (int i = 0; i < schema.ColumnCount; i++) + { + if (schema.IsHidden(i) || ignoreCol(i)) + continue; + + var type = schema.GetColumnType(i); + var metricName = row.Schema.GetColumnName(i); + if (type.IsNumber) + { + getters[i] = RowCursorUtils.GetGetterAs(NumberType.R8, row, i); + metricNames.Add(metricName); + metricCount++; + } + else if (type.IsVector && type.ItemType == NumberType.R8) + { + if (type.VectorSize == 0) + { + ch.Warning("Vector metric '{0}' has different lengths in different folds and will not be averaged for overall results.", metricName); + continue; + } + + vBufferGetters[i] = row.GetGetter>(i); + metricCount += type.VectorSize; + var slotNamesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i); + if (slotNamesType != null && slotNamesType.VectorSize == type.VectorSize && slotNamesType.ItemType.IsText) + schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref names); + else + { + var namesArray = names.Values; + if (Utils.Size(namesArray) < type.VectorSize) + namesArray = new DvText[type.VectorSize]; + for (int j = 0; j < type.VectorSize; j++) + namesArray[j] = new DvText(string.Format("Label_{0}", j)); + names = new VBuffer(type.VectorSize, namesArray); + } + foreach (var name in names.Items(all: true)) + metricNames.Add(string.Format("{0} {1}", metricName, name.Value)); + } + } + ch.Assert(metricNames.Count == metricCount); + return metricNames; + } + + internal static void GetOverallMetricsData(IHostEnvironment env, IDataView data, int numFolds, + out double[] sumMetrics, out double[] sumSqMetrics, out double[] sumWeightedMetrics, out double[] sumSqWeightedMetrics, + out IDataView overall, out List metricNames) + { + ComputeMetricsSum(env, data, numFolds, out int isWeightedCol, out int stratCol, out int stratVal, out int foldCol, + out sumMetrics, out sumSqMetrics, out sumWeightedMetrics, out sumSqWeightedMetrics, out metricNames); + + var nonAveragedCols = new List(); + var avgMetrics = GetAverageToDataView(env, data.Schema, sumMetrics, sumSqMetrics, sumWeightedMetrics, sumSqWeightedMetrics, + numFolds, stratCol, stratVal, isWeightedCol, foldCol, nonAveragedCols); + + var idvList = new List() { avgMetrics }; + + var hasStrat = stratCol >= 0; + if (numFolds > 1 || hasStrat) + { + if (Utils.Size(nonAveragedCols) > 0) + { + var dropArgs = new DropColumnsTransform.Arguments() { Column = nonAveragedCols.ToArray() }; + data = new DropColumnsTransform(env, dropArgs, data); + } + idvList.Add(data); + } + + overall = AppendRowsDataView.Create(env, avgMetrics.Schema, idvList.ToArray()); + + // If there are stratified results, apply a KeyToValue transform to get the stratification column + // names from the key column. + if (hasStrat) + { + var args = new KeyToValueTransform.Arguments(); + args.Column = new[] { new KeyToValueTransform.Column() { Source = MetricKinds.ColumnNames.StratCol }, }; + overall = new KeyToValueTransform(env, args, overall); + } + } + + internal static void ComputeMetricsSum(IHostEnvironment env, IDataView data, int numFolds, out int isWeightedCol, + out int stratCol, out int stratVal, out int foldCol, out double[] sumMetrics, out double[] sumSqMetrics, + out double[] sumWeightedMetrics, out double[] sumSqWeightedMetrics, out List metricNames) + { + var hasWeighted = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out int wcol); + var hasStrats = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out int scol); + var hasStratVals = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out int svalcol); + env.Assert(hasStrats == hasStratVals); + var hasFoldCol = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out int fcol); + + isWeightedCol = hasWeighted ? wcol : -1; + stratCol = hasStrats ? scol : -1; + stratVal = hasStratVals ? svalcol : -1; + foldCol = hasFoldCol ? fcol : -1; + + // We currently have only double valued or vector of double valued metrics. + int colCount = data.Schema.ColumnCount; + var getters = new ValueGetter[colCount]; + var vBufferGetters = new ValueGetter>[colCount]; + int numResults = 0; + int numWeightedResults = 0; + using (var cursor = data.GetRowCursor(col => true)) + { + DvBool isWeighted = DvBool.False; + ValueGetter isWeightedGetter; + if (hasWeighted) + isWeightedGetter = cursor.GetGetter(isWeightedCol); + else + isWeightedGetter = (ref DvBool dst) => dst = DvBool.False; + + ValueGetter stratColGetter; + if (hasStrats) + { + var type = cursor.Schema.GetColumnType(stratCol); + stratColGetter = RowCursorUtils.GetGetterAs(type, cursor, stratCol); + } + else + stratColGetter = (ref uint dst) => dst = 0; + + // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns + // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't. + using (var ch = env.Register("GetMetricsAsString").Start("Get Metric Names")) + { + metricNames = GetMetricNames(ch, data.Schema, cursor, + i => hasWeighted && i == wcol || hasStrats && (i == scol || i == svalcol) || + hasFoldCol && i == fcol, getters, vBufferGetters); + ch.Done(); + } + + Double metricVal = 0; + VBuffer metricVals = default(VBuffer); + sumMetrics = new double[metricNames.Count]; + sumSqMetrics = numFolds > 1 ? new double[metricNames.Count] : null; + if (hasWeighted) + { + sumWeightedMetrics = new double[metricNames.Count]; + sumSqWeightedMetrics = numFolds > 1 ? new double[metricNames.Count] : null; + } + else + { + sumWeightedMetrics = null; + sumSqWeightedMetrics = null; + } + uint strat = 0; + while (cursor.MoveNext()) + { + stratColGetter(ref strat); + // REVIEW: how to print stratified results? + if (strat > 0) + continue; + + isWeightedGetter(ref isWeighted); + if (isWeighted.IsTrue) + { + // If !average, we should have only one relevant row. + if (numWeightedResults > numFolds) + throw Contracts.Except("Multiple weighted rows found in metrics data view."); + + numWeightedResults++; + UpdateSums(isWeightedCol, stratCol, stratVal, sumWeightedMetrics, sumSqWeightedMetrics, metricNames, hasWeighted, hasStrats, colCount, getters, vBufferGetters, ref metricVal, ref metricVals); + } + else + { + // If !average, we should have only one relevant row. + if (numResults > numFolds) + throw Contracts.Except("Multiple unweighted rows found in metrics data view."); + + numResults++; + UpdateSums(isWeightedCol, stratCol, stratVal, sumMetrics, sumSqMetrics, metricNames, hasWeighted, hasStrats, colCount, getters, vBufferGetters, ref metricVal, ref metricVals); + } + + if (numResults == numFolds && (!hasWeighted || numWeightedResults == numFolds)) + break; + } + } + } + + private static void UpdateSums(int isWeightedCol, int stratCol, int stratVal, double[] sumMetrics, double[] sumSqMetrics, List metricNames, bool hasWeighted, bool hasStrats, int colCount, ValueGetter[] getters, ValueGetter>[] vBufferGetters, ref double metricVal, ref VBuffer metricVals) + { + int iMetric = 0; + for (int i = 0; i < colCount; i++) + { + if (hasWeighted && i == isWeightedCol || hasStrats && (i == stratCol || i == stratVal)) + continue; + + // REVIEW: What to do with metrics that are not doubles? + if (getters[i] != null) + { + getters[i](ref metricVal); + sumMetrics[iMetric] += metricVal; + if (sumSqMetrics != null) + sumSqMetrics[iMetric] += metricVal * metricVal; + iMetric++; + } + else if (vBufferGetters[i] != null) + { + vBufferGetters[i](ref metricVals); + foreach (var metric in metricVals.Items(all: true)) + { + sumMetrics[iMetric] += metric.Value; + if (sumSqMetrics != null) + sumSqMetrics[iMetric] += metric.Value * metric.Value; + iMetric++; + } + } + } + Contracts.Assert(iMetric == metricNames.Count); + } + + internal static IDataView GetAverageToDataView(IHostEnvironment env, ISchema schema, double[] sumMetrics, + double[] sumSqMetrics, double[] sumWeightedMetrics, double[] sumSqWeightedMetrics, int numFolds, + int stratCol, int stratVal, int isWeightedCol, int foldCol, List nonAveragedCols = null) + { + Contracts.AssertValue(env); + + int colCount = schema.ColumnCount; + + var dvBldr = new ArrayDataViewBuilder(env); + var weightedDvBldr = isWeightedCol >= 0 ? new ArrayDataViewBuilder(env) : null; + + int iMetric = 0; + for (int i = 0; i < colCount; i++) + { + if (schema.IsHidden(i)) + continue; + + var type = schema.GetColumnType(i); + var name = schema.GetColumnName(i); + if (i == stratCol) + { + var keyValuesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i); + if (keyValuesType == null || !keyValuesType.ItemType.IsText || + keyValuesType.VectorSize != type.KeyCount) + { + throw env.Except("Column '{0}' must have key values metadata", + MetricKinds.ColumnNames.StratCol); + } + + ValueGetter> getKeyValues = + (ref VBuffer dst) => + { + schema.GetMetadata(MetadataUtils.Kinds.KeyValues, stratCol, ref dst); + Contracts.Assert(dst.IsDense); + }; + + var keys = foldCol >= 0 ? new uint[] { 0, 0 } : new uint[] { 0 }; + dvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, getKeyValues, 0, type.KeyCount, keys); + weightedDvBldr?.AddColumn(MetricKinds.ColumnNames.StratCol, getKeyValues, 0, type.KeyCount, keys); + } + else if (i == stratVal) + { + var stratVals = foldCol >= 0 ? new[] { DvText.NA, DvText.NA } : new[] { DvText.NA }; + dvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextType.Instance, stratVals); + weightedDvBldr?.AddColumn(MetricKinds.ColumnNames.StratVal, TextType.Instance, stratVals); + } + else if (i == isWeightedCol) + { + env.AssertValue(weightedDvBldr); + dvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BoolType.Instance, foldCol >= 0 ? new[] { DvBool.False, DvBool.False } : new[] { DvBool.False }); + weightedDvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BoolType.Instance, foldCol >= 0 ? new[] { DvBool.True, DvBool.True } : new[] { DvBool.True }); + } + else if (i == foldCol) + { + var foldVals = new[] { new DvText("Average"), new DvText("Standard Deviation") }; + dvBldr.AddColumn(MetricKinds.ColumnNames.FoldIndex, TextType.Instance, foldVals); + weightedDvBldr?.AddColumn(MetricKinds.ColumnNames.FoldIndex, TextType.Instance, foldVals); + } + else if (type.IsNumber) + { + dvBldr.AddScalarColumn(schema, sumMetrics, sumSqMetrics, numFolds, iMetric, name); + weightedDvBldr?.AddScalarColumn(schema, sumWeightedMetrics, sumSqWeightedMetrics, numFolds, iMetric, name); + iMetric++; + } + else if (type.IsKnownSizeVector && type.ItemType == NumberType.R8) + { + dvBldr.AddVectorColumn(env, schema, sumMetrics, sumSqMetrics, numFolds, iMetric, i, type, name); + weightedDvBldr?.AddVectorColumn(env, schema, sumWeightedMetrics, sumSqWeightedMetrics, numFolds, iMetric, i, type, name); + iMetric += type.VectorSize; + } + else + nonAveragedCols?.Add(name); + } + var idv = dvBldr.GetDataView(); + if (weightedDvBldr != null) + idv = AppendRowsDataView.Create(env, idv.Schema, idv, weightedDvBldr.GetDataView()); + return idv; + } + + private static void AddVectorColumn(this ArrayDataViewBuilder dvBldr, IHostEnvironment env, ISchema schema, double[] sumMetrics, double[] sumSqMetrics, int numFolds, int iMetric, int i, ColumnType type, string name) + { + var vectorMetrics = new double[type.VectorSize]; + env.Assert(vectorMetrics.Length > 0); + for (int j = 0; j < vectorMetrics.Length; j++) + vectorMetrics[j] = sumMetrics[iMetric + j] / numFolds; + double[] vectorStdevMetrics = null; + if (sumSqMetrics != null) + { + vectorStdevMetrics = new double[type.VectorSize]; + for (int j = 0; j < vectorStdevMetrics.Length; j++) + vectorStdevMetrics[j] = Math.Sqrt(sumSqMetrics[iMetric + j] / numFolds - vectorMetrics[j] * vectorMetrics[j]); + } + var slotNamesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i); + var slotNames = default(VBuffer); + if (slotNamesType != null && slotNamesType.ItemType.IsText && + slotNamesType.VectorSize == type.VectorSize) + { + schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref slotNames); + Contracts.Assert(slotNames.IsDense); + var values = slotNames.Values; + for (int j = 0; j < values.Length; j++) + values[j] = new DvText(name + values[j]); + slotNames = new VBuffer(slotNames.Length, values, slotNames.Indices); + } + else + { + var values = slotNames.Values; + if (Utils.Size(values) < type.VectorSize) + values = new DvText[type.VectorSize]; + for (int j = 0; j < type.VectorSize; j++) + values[j] = new DvText(name + j); + slotNames = new VBuffer(type.VectorSize, values, slotNames.Indices); + } + ValueGetter> getSlotNames = (ref VBuffer dst) => dst = slotNames; + if (vectorStdevMetrics != null) + { + env.AssertValue(vectorStdevMetrics); + dvBldr.AddColumn(name, getSlotNames, NumberType.R8, new[] { vectorMetrics, vectorStdevMetrics }); + } + else + dvBldr.AddColumn(name, getSlotNames, NumberType.R8, new[] { vectorMetrics }); + } + + private static void AddScalarColumn(this ArrayDataViewBuilder dvBldr, ISchema schema, double[] sumMetrics, double[] sumSqMetrics, int numFolds, int iMetric, string name) + { + Contracts.AssertValue(dvBldr); + + var avg = sumMetrics[iMetric] / numFolds; + if (sumSqMetrics != null) + dvBldr.AddColumn(name, NumberType.R8, avg, Math.Sqrt(sumSqMetrics[iMetric] / numFolds - avg * avg)); + else + dvBldr.AddColumn(name, NumberType.R8, avg); + } + + /// + /// Takes a data view containing one or more rows of metrics, and returns a data view containing additional + /// rows with the average and the standard deviation of the metrics in the input data view. + /// + public static IDataView CombineFoldMetricsDataViews(IHostEnvironment env, IDataView data, int numFolds) + { + GetOverallMetricsData(env, data, numFolds, out var _, out var _, out var _, out var _, out var overall, out var _); + return overall; + } } public static class MetricWriter @@ -791,286 +1424,57 @@ private static double[][] GetConfusionTableAsArray(IDataView confusionDataView, /// metrics. Otherwise it is assigned null. public static string GetPerFoldResults(IHostEnvironment env, IDataView fold, out string weightedMetrics) { - IDataView avgMetrics; - int isWeightedCol; - if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out isWeightedCol)) - weightedMetrics = GetMetricsAsString(env, fold, true, 1, out avgMetrics); - else - weightedMetrics = null; - return GetMetricsAsString(env, fold, false, 1, out avgMetrics); + return GetFoldMetricsAsString(env, fold, out weightedMetrics); } - // This method returns a string representation of a set of metrics. If there are stratification columns, it looks for columns named - // StratCol and StratVal, and outputs the metrics in the rows with NA in the StratCol column. If weighted is true, it looks - // for a DvBool column named "IsWeighted" and outputs the metrics in the rows with a value of true in that column. - // If nonAveragedCols is non-null, it computes the average and standard deviation over all the relevant rows and populates - // nonAveragedCols with columns that are either hidden, or are not of a type that we can display (i.e., either a numeric column, - // or a known length vector of doubles). - // If average is false, no averaging is done, and instead we check that there is exactly one relevant row. Otherwise, we - // add the vector columns of variable length of the list of non-averagable columns if nonAveragedCols is not null. - private static string GetMetricsAsString(IHostEnvironment env, IDataView data, bool weighted, - int numFolds, out IDataView avgMetricsDataView, bool average = false, List nonAveragedCols = null) + private static string GetOverallMetricsAsString(double[] sumMetrics, double[] sumSqMetrics, int numFolds, bool weighted, bool average, List metricNames) { - int isWeightedCol; - bool hasWeighted = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out isWeightedCol); - // If the IsWeighted column is not present, weighted must be false. - Contracts.Assert(hasWeighted || !weighted); - - int stratCol; - bool hasStrats = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol); - int stratVal; - bool hasStratVals = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal); - Contracts.Assert(hasStrats == hasStratVals); - - int foldCol; - bool hasFoldCol = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out foldCol); - - // We currently have only double valued or vector of double valued metrics. - var colCount = data.Schema.ColumnCount; - var getters = new ValueGetter[colCount]; - var vBufferGetters = new ValueGetter>[colCount]; - - double[] avgMetrics; - double[] sumSqMetrics; - List metricNames; - int numResults = 0; - using (var cursor = data.GetRowCursor(col => true)) - { - DvBool isWeighted = DvBool.False; - ValueGetter isWeightedGetter; - if (hasWeighted) - isWeightedGetter = cursor.GetGetter(isWeightedCol); - else - isWeightedGetter = (ref DvBool dst) => dst = DvBool.False; - - ValueGetter stratColGetter; - if (hasStrats) - { - var type = cursor.Schema.GetColumnType(stratCol); - stratColGetter = RowCursorUtils.GetGetterAs(type, cursor, stratCol); - } - else - stratColGetter = (ref uint dst) => dst = 0; - - // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns - // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't. - using (var ch = env.Register("GetMetricsAsString").Start("Get Metric Names")) - { - metricNames = GetMetricNames(ch, data.Schema, cursor, - i => hasWeighted && i == isWeightedCol || hasStrats && (i == stratCol || i == stratVal) || - hasFoldCol && i == foldCol, getters, vBufferGetters); - ch.Done(); - } - - Double metricVal = 0; - VBuffer metricVals = default(VBuffer); - avgMetrics = new double[metricNames.Count]; - sumSqMetrics = new double[metricNames.Count]; - uint strat = 0; - while (cursor.MoveNext()) - { - isWeightedGetter(ref isWeighted); - if (isWeighted.IsTrue != weighted) - continue; - - stratColGetter(ref strat); - // REVIEW: how to print stratified results? - if (strat > 0) - continue; - - // If !average, we should have only one relevant row. - if (!average && numResults > 0) - throw Contracts.Except("Multiple {0} rows found in metrics data view.", weighted ? "weighted" : "unweighted"); - - numResults++; - int iMetric = 0; - for (int i = 0; i < colCount; i++) - { - if (hasWeighted && i == isWeightedCol || hasStrats && (i == stratCol || i == stratVal)) - continue; - - // REVIEW: What to do with metrics that are not doubles? - if (getters[i] != null) - { - getters[i](ref metricVal); - avgMetrics[iMetric] += metricVal; - if (sumSqMetrics != null) - sumSqMetrics[iMetric] += metricVal * metricVal; - iMetric++; - } - else if (vBufferGetters[i] != null) - { - vBufferGetters[i](ref metricVals); - foreach (var metric in metricVals.Items(all: true)) - { - avgMetrics[iMetric] += metric.Value; - if (sumSqMetrics != null) - sumSqMetrics[iMetric] += metric.Value * metric.Value; - iMetric++; - } - } - } - Contracts.Assert(iMetric == metricNames.Count); - - if (numResults == numFolds) - break; - } - } - var sb = new StringBuilder(); for (int i = 0; i < metricNames.Count; i++) { - avgMetrics[i] /= numResults; + var avg = sumMetrics[i] / numFolds; sb.Append(string.Format("{0}{1}: ", weighted ? "Weighted " : "", metricNames[i]).PadRight(20)); - sb.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", avgMetrics[i])); + sb.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", avg)); if (average) { - Contracts.AssertValue(sumSqMetrics); - sb.AppendLine(string.Format(" ({0:N4})", numResults == 1 ? 0 : - Math.Sqrt(sumSqMetrics[i] / numResults - avgMetrics[i] * avgMetrics[i]))); + Contracts.Assert(sumSqMetrics != null || numFolds == 1); + sb.AppendLine(string.Format(" ({0:N4})", numFolds == 1 ? 0 : + Math.Sqrt(sumSqMetrics[i] / numFolds - avg * avg))); } else sb.AppendLine(); } - - if (average) - { - var dvBldr = new ArrayDataViewBuilder(env); - int iMetric = 0; - for (int i = 0; i < colCount; i++) - { - if (hasStrats && i == stratCol) - { - var type = data.Schema.GetColumnType(i); - var keyValuesType = data.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.KeyValues, i); - if (keyValuesType == null || !keyValuesType.ItemType.IsText || - keyValuesType.VectorSize != type.KeyCount) - { - throw env.Except("Column '{0}' must have key values metadata", - MetricKinds.ColumnNames.StratCol); - } - - ValueGetter> getKeyValues = - (ref VBuffer dst) => - { - data.Schema.GetMetadata(MetadataUtils.Kinds.KeyValues, stratCol, ref dst); - Contracts.Assert(dst.IsDense); - }; - - dvBldr.AddColumn(MetricKinds.ColumnNames.StratCol, getKeyValues, 0, type.KeyCount, (uint)0); - } - else if (hasStratVals && i == stratVal) - dvBldr.AddColumn(MetricKinds.ColumnNames.StratVal, TextType.Instance, DvText.NA); - else if (hasWeighted && i == isWeightedCol) - dvBldr.AddColumn(MetricKinds.ColumnNames.IsWeighted, BoolType.Instance, weighted ? DvBool.True : DvBool.False); - else if (hasFoldCol && i == foldCol) - { - var avg = new DvText("Average"); - dvBldr.AddColumn(MetricKinds.ColumnNames.FoldIndex, TextType.Instance, avg); - } - else if (getters[i] != null) - { - dvBldr.AddColumn(data.Schema.GetColumnName(i), NumberType.R8, avgMetrics[iMetric]); - iMetric++; - } - else if (vBufferGetters[i] != null) - { - var type = data.Schema.GetColumnType(i); - var vectorMetrics = new double[type.VectorSize]; - env.Assert(vectorMetrics.Length > 0); - Array.Copy(avgMetrics, iMetric, vectorMetrics, 0, vectorMetrics.Length); - var slotNamesType = data.Schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i); - var name = data.Schema.GetColumnName(i); - var slotNames = default(VBuffer); - if (slotNamesType != null && slotNamesType.ItemType.IsText && - slotNamesType.VectorSize == type.VectorSize) - { - data.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref slotNames); - Contracts.Assert(slotNames.IsDense); - var values = slotNames.Values; - for (int j = 0; j < values.Length; j++) - values[j] = new DvText(name + values[j]); - slotNames = new VBuffer(slotNames.Length, values, slotNames.Indices); - } - else - { - var values = slotNames.Values; - if (Utils.Size(values) < type.VectorSize) - values = new DvText[type.VectorSize]; - for (int j = 0; j < type.VectorSize; j++) - values[j] = new DvText(name + j); - slotNames = new VBuffer(type.VectorSize, values, slotNames.Indices); - } - ValueGetter> getSlotNames = (ref VBuffer dst) => dst = slotNames; - dvBldr.AddColumn(name, getSlotNames, NumberType.R8, new[] { vectorMetrics }); - iMetric += vectorMetrics.Length; - } - else - nonAveragedCols?.Add(data.Schema.GetColumnName(i)); - } - Contracts.Assert(iMetric == metricNames.Count); - avgMetricsDataView = dvBldr.GetDataView(); - } - else - avgMetricsDataView = null; - return sb.ToString(); } - private static List GetMetricNames(IChannel ch, ISchema schema, IRow row, Func ignoreCol, - ValueGetter[] getters, ValueGetter>[] vBufferGetters) + // This method returns a string representation of a set of metrics. If there are stratification columns, it looks for columns named + // StratCol and StratVal, and outputs the metrics in the rows with NA in the StratCol column. If weighted is true, it looks + // for a DvBool column named "IsWeighted" and outputs the metrics in the rows with a value of true in that column. + // If nonAveragedCols is non-null, it computes the average and standard deviation over all the relevant rows and populates + // nonAveragedCols with columns that are either hidden, or are not of a type that we can display (i.e., either a numeric column, + // or a known length vector of doubles). + // If average is false, no averaging is done, and instead we check that there is exactly one relevant row. Otherwise, we + // add the vector columns of variable length of the list of non-averagable columns if nonAveragedCols is not null. + private static string GetFoldMetricsAsString(IHostEnvironment env, IDataView data, out string weightedMetricsString) { - Contracts.AssertValue(schema); - Contracts.AssertValue(row); - Contracts.Assert(Utils.Size(getters) == schema.ColumnCount); - Contracts.Assert(Utils.Size(vBufferGetters) == schema.ColumnCount); + EvaluateUtils.ComputeMetricsSum(env, data, 1, out int isWeightedCol, out int stratCol, + out int stratVal, out int foldCol, out var metrics, out var _, out var weightedMetrics, + out var _, out var metricNames); - // Get the names of the metrics. For R8 valued columns the metric name is the column name. For R8 vector valued columns - // the names of the metrics are the column name, followed by the slot name if it exists, or "Label_i" if it doesn't. - VBuffer names = default(VBuffer); - int metricCount = 0; - var metricNames = new List(); - for (int i = 0; i < schema.ColumnCount; i++) + var sb = new StringBuilder(); + var weightedSb = isWeightedCol >= 0 ? new StringBuilder() : null; + for (int i = 0; i < metricNames.Count; i++) { - if (schema.IsHidden(i) || ignoreCol(i)) - continue; - - var type = schema.GetColumnType(i); - var metricName = row.Schema.GetColumnName(i); - if (type.IsNumber) - { - getters[i] = RowCursorUtils.GetGetterAs(NumberType.R8, row, i); - metricNames.Add(metricName); - metricCount++; - } - else if (type.IsVector && type.ItemType == NumberType.R8) - { - if (type.VectorSize == 0) - { - ch.Warning("Vector metric '{0}' has different lengths in different folds and will not be averaged for overall results.", metricName); - continue; - } - - vBufferGetters[i] = row.GetGetter>(i); - metricCount += type.VectorSize; - var slotNamesType = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, i); - if (slotNamesType != null && slotNamesType.VectorSize == type.VectorSize && slotNamesType.ItemType.IsText) - schema.GetMetadata(MetadataUtils.Kinds.SlotNames, i, ref names); - else - { - var namesArray = names.Values; - if (Utils.Size(namesArray) < type.VectorSize) - namesArray = new DvText[type.VectorSize]; - for (int j = 0; j < type.VectorSize; j++) - namesArray[j] = new DvText(string.Format("Label_{0}", j)); - names = new VBuffer(type.VectorSize, namesArray); - } - foreach (var name in names.Items(all: true)) - metricNames.Add(string.Format("{0} {1}", metricName, name.Value)); - } + sb.Append($"{metricNames[i]}: ".PadRight(20)); + sb.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", metrics[i])); + weightedSb?.Append($"Weighted {metricNames[i]}: ".PadRight(20)); + weightedSb?.Append(string.Format(CultureInfo.InvariantCulture, "{0,7:N6}", weightedMetrics[i])); + sb.AppendLine(); + weightedSb?.AppendLine(); } - Contracts.Assert(metricNames.Count == metricCount); - return metricNames; + + weightedMetricsString = weightedSb?.ToString(); + return sb.ToString(); } // Get a string representation of a confusion table. @@ -1181,58 +1585,27 @@ private static string GetConfusionTableAsString(double[][] confusionTable, doubl /// public static void PrintOverallMetrics(IHostEnvironment env, IChannel ch, string filename, IDataView overall, int numFolds) { + EvaluateUtils.GetOverallMetricsData(env, overall, numFolds, out var sumMetrics, out var sumSqMetrics, + out var weightedSumMetrics, out var weightedSumSqMetrics, out var overallWithAvg, out var metricNames); + var sb = new StringBuilder(); sb.AppendLine(); sb.AppendLine("OVERALL RESULTS"); sb.AppendLine("---------------------------------------"); - int isWeighted; - IDataView weightedAvgMetrics = null; var nonAveragedCols = new List(); - if (overall.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out isWeighted)) - sb.Append(GetMetricsAsString(env, overall, true, numFolds, out weightedAvgMetrics, true)); - IDataView avgMetrics; - sb.AppendLine(GetMetricsAsString(env, overall, false, numFolds, out avgMetrics, true, nonAveragedCols)); - env.AssertValue(avgMetrics); - sb.AppendLine("---------------------------------------"); + if (weightedSumMetrics != null) + sb.Append(GetOverallMetricsAsString(weightedSumMetrics, weightedSumSqMetrics, numFolds, true, true, metricNames)); + sb.Append(GetOverallMetricsAsString(sumMetrics, sumSqMetrics, numFolds, false, true, metricNames)); + sb.AppendLine("\n---------------------------------------"); ch.Info(sb.ToString()); if (!string.IsNullOrEmpty(filename)) { using (var file = env.CreateOutputFile(filename)) { - // idvList will contain all the dataviews that should be appended with AppendRowsDataView. - // If numResults=1, then we just save the average metrics. Otherwise, we remove all the non-metric columns - // (except for the IsWeighted column and FoldIndex column if present), and append to the average results. - var idvList = new List() { avgMetrics }; - if (weightedAvgMetrics != null) - idvList.Add(weightedAvgMetrics); - - int stratCol; - var hasStrat = overall.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol); - if (numFolds > 1 || hasStrat) - { - if (Utils.Size(nonAveragedCols) > 0) - { - var dropArgs = new DropColumnsTransform.Arguments() { Column = nonAveragedCols.ToArray() }; - overall = new DropColumnsTransform(env, dropArgs, overall); - } - idvList.Add(overall); - } - - var summary = AppendRowsDataView.Create(env, avgMetrics.Schema, idvList.ToArray()); - - // If there are stratified results, apply a KeyToValue transform to get the stratification column - // names from the key column. - if (hasStrat) - { - var args = new KeyToValueTransform.Arguments(); - args.Column = new[] { new KeyToValueTransform.Column() { Source = MetricKinds.ColumnNames.StratCol }, }; - summary = new KeyToValueTransform(env, args, summary); - } - var saverArgs = new TextSaver.Arguments() { Dense = true, Silent = true }; - DataSaverUtils.SaveDataView(ch, new TextSaver(env, saverArgs), summary, file); + DataSaverUtils.SaveDataView(ch, new TextSaver(env, saverArgs), overallWithAvg, file); } } } diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs index ada69d2d47..0547cfbef3 100644 --- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs @@ -6,7 +6,6 @@ using System.Linq; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; namespace Microsoft.ML.Runtime.Data { @@ -25,13 +24,18 @@ public interface IMamlEvaluator : IEvaluator void PrintFoldResults(IChannel ch, Dictionary metrics); /// - /// Combine the aggregate metrics from multiple folds and print them to the console. If filename is not null then - /// also save the results to the specified file. If results are from multiple folds, the file will contain - /// the average results first, and then each fold result. - /// Also handle any custom kinds of custom metrics, such as p/r curves for binary classification, or group summary results - /// for ranking. + /// Combine the overall metrics from multiple folds into a single data view. /// - void PrintOverallResults(IChannel ch, string filename, params Dictionary[] metrics); + /// + /// + IDataView GetOverallResults(params IDataView[] metrics); + + /// + /// Handles custom metrics (such as p/r curves for binary classification, or group summary results for ranking) from one + /// or more folds. Implementations of this method typically creates a single data view for the custom metric and saves it + /// to a user specified file. + /// + void PrintAdditionalMetrics(IChannel ch, params Dictionary[] metrics); /// /// Create a data view containing only the columns that are saved as per-instance results by Maml commands. @@ -162,57 +166,36 @@ protected virtual void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + public IDataView GetOverallResults(params IDataView[] metrics) { - Host.CheckValue(ch, nameof(ch)); Host.CheckNonEmpty(metrics, nameof(metrics)); - PrintOverallResultsCore(ch, filename, metrics); + var overall = CombineOverallMetricsCore(metrics); + return GetOverallResultsCore(overall); } - /// - /// This method simply prints the overall metrics using EvaluateUtils.PrintOverallMetrics. - /// Override if something else is needed. - /// - protected virtual void PrintOverallResultsCore(IChannel ch, string filename, Dictionary[] metrics) + protected virtual IDataView CombineOverallMetricsCore(IDataView[] metrics) { - ch.AssertNonEmpty(metrics); - - IDataView overall; - if (!TryGetOverallMetrics(metrics, out overall)) - throw ch.Except("No overall metrics found"); - - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return EvaluateUtils.CombineOverallMetrics(Host, metrics); } - protected bool TryGetOverallMetrics(Dictionary[] metrics, out IDataView overall) + protected virtual IDataView GetOverallResultsCore(IDataView overall) { - Host.AssertNonEmpty(metrics); + return overall; + } - if (metrics.Length == 1) - return metrics[0].TryGetValue(MetricKinds.OverallMetrics, out overall); + public void PrintAdditionalMetrics(IChannel ch, params Dictionary[] metrics) + { + Host.CheckValue(ch, nameof(ch)); + Host.CheckNonEmpty(metrics, nameof(metrics)); + PrintAdditionalMetricsCore(ch, metrics); + } - overall = null; - var overallList = new List(); - for (int i = 0; i < metrics.Length; i++) - { - var dict = metrics[i]; - IDataView idv; - if (!dict.TryGetValue(MetricKinds.OverallMetrics, out idv)) - return false; - - // Add a fold-name column. We add it as a text column, since it is only used for saving the result summary file. - // We use the first column in the data view as an input column to the LambdaColumnMapper, because it must have an input. - // We use DvText.NA as the value of this column since for any stratified row the value will be non empty, so we can uniquely identify - // the overall row using this column. - var inputColName = idv.Schema.GetColumnName(0); - var inputColType = idv.Schema.GetColumnType(0); - idv = Utils.MarshalInvoke(EvaluateUtils.AddTextColumn, inputColType.RawType, Host, - idv, inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, string.Format("Fold {0}", i), "FoldName"); - - overallList.Add(idv); - } - overall = AppendRowsDataView.Create(Host, overallList[0].Schema, overallList.ToArray()); - return true; + /// + /// This method simply prints the overall metrics using EvaluateUtils.PrintOverallMetrics. + /// Override if something else is needed. + /// + protected virtual void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) + { } public IDataTransform GetPerInstanceMetrics(RoleMappedData scoredData) diff --git a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs index 0251f0a0e7..5507176fb4 100644 --- a/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs @@ -853,7 +853,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + protected override IDataView CombineOverallMetricsCore(IDataView[] metrics) { - ch.AssertNonEmpty(metrics); - var overallList = new List(); for (int i = 0; i < metrics.Length; i++) { - var dict = metrics[i]; - if (!dict.TryGetValue(MetricKinds.OverallMetrics, out IDataView idv)) - throw ch.Except("No overall metrics found"); - - // Add a fold-name column. We add it as a text column, since it is only used for saving the result summary file. - // We use the first column in the data view as an input column to the LambdaColumnMapper, because it must have an input. - // We use DvText.NA as the value of this column since for any stratified row the value will be non empty, so we can uniquely identify - // the overall row using this column. - var inputColName = idv.Schema.GetColumnName(0); - var inputColType = idv.Schema.GetColumnType(0); - idv = Utils.MarshalInvoke(EvaluateUtils.AddTextColumn, inputColType.RawType, Host, - idv, inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, string.Format("Fold {0}", i), "FoldName"); - + var idv = metrics[i]; if (!_outputPerClass) idv = DropPerClassColumn(idv); @@ -925,14 +911,15 @@ protected override void PrintOverallResultsCore(IChannel ch, string filename, Di views[i] = idv; } } + return base.CombineOverallMetricsCore(views); + } - var overall = AppendRowsDataView.Create(Host, views[0].Schema, views.ToArray()); - + protected override IDataView GetOverallResultsCore(IDataView overall) + { // Change the name of the Top-k-accuracy column. if (_outputTopKAcc != null) overall = ChangeTopKAccColumnName(overall); - - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return overall; } private IDataView ChangeTopKAccColumnName(IDataView input) diff --git a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs index fa450286f1..6d61f6b965 100644 --- a/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/QuantileRegressionEvaluator.cs @@ -491,17 +491,9 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary[] metrics) + protected override IDataView GetOverallResultsCore(IDataView overall) { - ch.AssertNonEmpty(metrics); - - IDataView overall; - if (!TryGetOverallMetrics(metrics, out overall)) - throw ch.Except("No overall metrics found"); - - // Show only the metrics for the requested index. - overall = ExtractRelevantIndex(overall); - MetricWriter.PrintOverallMetrics(Host, ch, filename, overall, metrics.Length); + return ExtractRelevantIndex(overall); } private IDataView ExtractRelevantIndex(IDataView data) @@ -516,6 +508,8 @@ private IDataView ExtractRelevantIndex(IDataView data) var index = _index ?? type.VectorSize / 2; output = LambdaColumnMapper.Create(Host, "Quantile Regression", output, name, name, type, NumberType.R8, (ref VBuffer src, ref Double dst) => dst = src.GetItemOrDefault(index)); + output = new ChooseColumnsByIndexTransform(Host, + new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = new[] { i } }, output); } } return output; diff --git a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs index 420f875cc6..a383f835fd 100644 --- a/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs @@ -854,9 +854,10 @@ public RankerMamlEvaluator(IHostEnvironment env, Arguments args) return cols.Prepend(RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Group, groupIdCol)); } - protected override void PrintOverallResultsCore(IChannel ch, string filename, Dictionary[] metrics) + protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary[] metrics) { - base.PrintOverallResultsCore(ch, filename, metrics); + ch.AssertNonEmpty(metrics); + if (!string.IsNullOrEmpty(_groupSummaryFilename)) { IDataView gs; @@ -887,12 +888,7 @@ private bool TryGetGroupSummaryMetrics(Dictionary[] metrics, if (!metrics[i].TryGetValue(RankerEvaluator.GroupSummary, out idv)) return false; - // We use the first column in the data view as an input column to the LambdaColumnMapper, because it must have an input. - var inputColName = idv.Schema.GetColumnName(0); - var inputColType = idv.Schema.GetColumnType(0); - idv = Utils.MarshalInvoke(EvaluateUtils.AddKeyColumn, inputColType.RawType, Host, idv, - inputColName, MetricKinds.ColumnNames.FoldIndex, inputColType, metrics.Length, i + 1, "FoldIndex", - default(ValueGetter>)); + idv = EvaluateUtils.AddFoldIndex(Host, idv, i, metrics.Length); gsList.Add(idv); } gs = AppendRowsDataView.Create(Host, gsList[0].Schema, gsList.ToArray()); diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 2ca1af7159..994b01c772 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -22,6 +22,18 @@ namespace Runtime { public sealed partial class Experiment { + public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input) + { + var output = new Microsoft.ML.Data.DataViewReference.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output) + { + _jsonNodes.Add(Serialize("Data.DataViewReference", input, output)); + } + public Microsoft.ML.Data.IDataViewArrayConverter.Output Add(Microsoft.ML.Data.IDataViewArrayConverter input) { var output = new Microsoft.ML.Data.IDataViewArrayConverter.Output(); @@ -53,22 +65,11 @@ public Microsoft.ML.Data.TextLoader.Output Add(Microsoft.ML.Data.TextLoader inpu return output; } - public Microsoft.ML.Data.DataViewReference.Output Add(Microsoft.ML.Data.DataViewReference input) - { - var output = new Microsoft.ML.Data.DataViewReference.Output(); - Add(input, output); - return output; - } - public void Add(Microsoft.ML.Data.TextLoader input, Microsoft.ML.Data.TextLoader.Output output) { _jsonNodes.Add(Serialize("Data.TextLoader", input, output)); } - public void Add(Microsoft.ML.Data.DataViewReference input, Microsoft.ML.Data.DataViewReference.Output output) - { - _jsonNodes.Add(Serialize("Data.DataViewReference", input, output)); - } public Microsoft.ML.Models.AnomalyDetectionEvaluator.Output Add(Microsoft.ML.Models.AnomalyDetectionEvaluator input) { var output = new Microsoft.ML.Models.AnomalyDetectionEvaluator.Output(); @@ -129,6 +130,18 @@ public void Add(Microsoft.ML.Models.ClusterEvaluator input, Microsoft.ML.Models. _jsonNodes.Add(Serialize("Models.ClusterEvaluator", input, output)); } + public Microsoft.ML.Models.CrossValidationResultsCombiner.Output Add(Microsoft.ML.Models.CrossValidationResultsCombiner input) + { + var output = new Microsoft.ML.Models.CrossValidationResultsCombiner.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Models.CrossValidationResultsCombiner input, Microsoft.ML.Models.CrossValidationResultsCombiner.Output output) + { + _jsonNodes.Add(Serialize("Models.CrossValidationResultsCombiner", input, output)); + } + public Microsoft.ML.Models.CrossValidator.Output Add(Microsoft.ML.Models.CrossValidator input) { var output = new Microsoft.ML.Models.CrossValidator.Output(); @@ -1271,6 +1284,33 @@ public void Add(Microsoft.ML.Transforms.WordTokenizer input, Microsoft.ML.Transf } } + namespace Data + { + + /// + /// Pass dataview from memory to experiment + /// + public sealed partial class DataViewReference + { + + + /// + /// Pointer to IDataView in memory + /// + public Var Data { get; set; } = new Var(); + + + public sealed class Output + { + /// + /// The resulting data view + /// + public Var Data { get; set; } = new Var(); + + } + } + } + namespace Data { @@ -1346,23 +1386,6 @@ public sealed partial class TextLoader public string CustomSchema { get; set; } - public sealed class Output - { - /// - /// The resulting data view - /// - public Var Data { get; set; } = new Var(); - - } - } - - public sealed partial class DataViewReference - { - /// - /// Location of the input file - /// - public Var Data { get; set; } = new Var(); - public sealed class Output { /// @@ -1826,6 +1849,73 @@ public enum MacroUtilsTrainerKinds } + /// + /// Combine the metric data views returned from cross validation. + /// + public sealed partial class CrossValidationResultsCombiner + { + + + /// + /// Overall metrics datasets + /// + public ArrayVar OverallMetrics { get; set; } = new ArrayVar(); + + /// + /// Per instance metrics datasets + /// + public ArrayVar PerInstanceMetrics { get; set; } = new ArrayVar(); + + /// + /// Confusion matrix datasets + /// + public ArrayVar ConfusionMatrix { get; set; } = new ArrayVar(); + + /// + /// Warning datasets + /// + public ArrayVar Warnings { get; set; } = new ArrayVar(); + + /// + /// The label column name + /// + public string LabelColumn { get; set; } = "Label"; + + /// + /// Specifies the trainer kind, which determines the evaluator to be used. + /// + public Models.MacroUtilsTrainerKinds Kind { get; set; } = Models.MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer; + + + public sealed class Output + { + /// + /// Warning dataset + /// + public Var Warnings { get; set; } = new Var(); + + /// + /// Overall metrics dataset + /// + public Var OverallMetrics { get; set; } = new Var(); + + /// + /// Per instance metrics dataset + /// + public Var PerInstanceMetrics { get; set; } = new Var(); + + /// + /// Confusion matrix dataset + /// + public Var ConfusionMatrix { get; set; } = new Var(); + + } + } + } + + namespace Models + { + public sealed class CrossValidationMacroSubGraphInput { /// @@ -1902,22 +1992,22 @@ public sealed class Output /// /// Warning dataset /// - public ArrayVar Warnings { get; set; } = new ArrayVar(); + public Var Warnings { get; set; } = new Var(); /// /// Overall metrics dataset /// - public ArrayVar OverallMetrics { get; set; } = new ArrayVar(); + public Var OverallMetrics { get; set; } = new Var(); /// /// Per instance metrics dataset /// - public ArrayVar PerInstanceMetrics { get; set; } = new ArrayVar(); + public Var PerInstanceMetrics { get; set; } = new Var(); /// /// Confusion matrix dataset /// - public ArrayVar ConfusionMatrix { get; set; } = new ArrayVar(); + public Var ConfusionMatrix { get; set; } = new Var(); } } diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 1cb950f939..f639ebdc58 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -2,13 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Collections.Generic; using System.Linq; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Utilities; using Newtonsoft.Json.Linq; [assembly: LoadableClass(typeof(void), typeof(CrossValidationMacro), null, typeof(SignatureEntryPointModule), "CrossValidationMacro")] @@ -83,16 +83,52 @@ public sealed class Output public IPredictorModel[] PredictorModel; [TlcModule.Output(Desc = "Warning dataset", SortOrder = 2)] - public IDataView[] Warnings; + public IDataView Warnings; [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 3)] - public IDataView[] OverallMetrics; + public IDataView OverallMetrics; [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 4)] - public IDataView[] PerInstanceMetrics; + public IDataView PerInstanceMetrics; [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 5)] + public IDataView ConfusionMatrix; + } + + public sealed class CombineMetricsInput + { + [Argument(ArgumentType.Multiple, HelpText = "Overall metrics datasets", SortOrder = 1)] + public IDataView[] OverallMetrics; + + [Argument(ArgumentType.Multiple, HelpText = "Per instance metrics datasets", SortOrder = 2)] + public IDataView[] PerInstanceMetrics; + + [Argument(ArgumentType.Multiple, HelpText = "Confusion matrix datasets", SortOrder = 3)] public IDataView[] ConfusionMatrix; + + [Argument(ArgumentType.Multiple, HelpText = "Warning datasets", SortOrder = 4)] + public IDataView[] Warnings; + + [Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 4)] + public string LabelColumn = DefaultColumnNames.Label; + + [Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 0)] + public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer; + } + + public sealed class CombinedOutput + { + [TlcModule.Output(Desc = "Warning dataset", SortOrder = 2)] + public IDataView Warnings; + + [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 3)] + public IDataView OverallMetrics; + + [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 4)] + public IDataView PerInstanceMetrics; + + [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 5)] + public IDataView ConfusionMatrix; } [TlcModule.EntryPoint(Desc = "Cross validation for general learning", Name = "Models.CrossValidator")] @@ -206,6 +242,7 @@ public static CommonOutputs.MacroOutput CrossValidate( exp.Reset(); + // Convert predictors from all folds into an array of predictors. var outModels = new ML.Data.PredictorModelArrayConverter { Model = new ArrayVar(predModelVars) @@ -214,45 +251,125 @@ public static CommonOutputs.MacroOutput CrossValidate( outModelsOutput.OutputModel.VarName = node.GetOutputVariableName(nameof(Output.PredictorModel)); exp.Add(outModels, outModelsOutput); + // Convert warnings data views from all folds into an array of data views. var warnings = new ML.Data.IDataViewArrayConverter { Data = new ArrayVar(warningsVars) }; var warningsOutput = new ML.Data.IDataViewArrayConverter.Output(); - warningsOutput.OutputData.VarName = node.GetOutputVariableName(nameof(Output.Warnings)); exp.Add(warnings, warningsOutput); + // Convert overall metrics data views from all folds into an array of data views. var overallMetrics = new ML.Data.IDataViewArrayConverter { Data = new ArrayVar(overallMetricsVars) }; var overallMetricsOutput = new ML.Data.IDataViewArrayConverter.Output(); - overallMetricsOutput.OutputData.VarName = node.GetOutputVariableName(nameof(Output.OverallMetrics)); exp.Add(overallMetrics, overallMetricsOutput); + // Convert per instance data views from all folds into an array of data views. var instanceMetrics = new ML.Data.IDataViewArrayConverter { Data = new ArrayVar(instanceMetricsVars) }; var instanceMetricsOutput = new ML.Data.IDataViewArrayConverter.Output(); - instanceMetricsOutput.OutputData.VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics)); exp.Add(instanceMetrics, instanceMetricsOutput); + ML.Data.IDataViewArrayConverter.Output confusionMatricesOutput = null; if (input.Kind == MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer || input.Kind == MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer) { + // Convert confusion matrix data views from all folds into an array of data views. var confusionMatrices = new ML.Data.IDataViewArrayConverter { Data = new ArrayVar(confusionMatrixVars) }; - var confusionMatricesOutput = new ML.Data.IDataViewArrayConverter.Output(); - confusionMatricesOutput.OutputData.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); + confusionMatricesOutput = new ML.Data.IDataViewArrayConverter.Output(); exp.Add(confusionMatrices, confusionMatricesOutput); } - subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); + var combineArgs = new CombineMetricsInput(); + combineArgs.Kind = input.Kind; + + // Set the input bindings for the CombineMetrics entry point. + var combineInputBindingMap = new Dictionary>(); + var combineInputMap = new Dictionary(); + var overallArray = new SimpleParameterBinding(nameof(combineArgs.OverallMetrics)); + combineInputBindingMap.Add(nameof(combineArgs.OverallMetrics), new List { overallArray }); + combineInputMap.Add(overallArray, new SimpleVariableBinding(overallMetricsOutput.OutputData.VarName)); + var combinePerInstArray = new SimpleParameterBinding(nameof(combineArgs.PerInstanceMetrics)); + combineInputBindingMap.Add(nameof(combineArgs.PerInstanceMetrics), new List { combinePerInstArray }); + combineInputMap.Add(combinePerInstArray, new SimpleVariableBinding(instanceMetricsOutput.OutputData.VarName)); + if (confusionMatricesOutput != null) + { + var combineConfArray = new SimpleParameterBinding(nameof(combineArgs.ConfusionMatrix)); + combineInputBindingMap.Add(nameof(combineArgs.ConfusionMatrix), new List { combineConfArray }); + combineInputMap.Add(combineConfArray, new SimpleVariableBinding(confusionMatricesOutput.OutputData.VarName)); + } + var combineOutputMap = new Dictionary(); + var combineWarningVar = new Var(); + combineWarningVar.VarName = node.GetOutputVariableName(nameof(Output.Warnings)); + combineOutputMap.Add(nameof(Output.Warnings), combineWarningVar.VarName); + var combineOverallMetric = new Var(); + combineOverallMetric.VarName = node.GetOutputVariableName(nameof(Output.OverallMetrics)); + combineOutputMap.Add(nameof(Output.OverallMetrics), combineOverallMetric.VarName); + var combineInstanceMetric = new Var(); + combineInstanceMetric.VarName = node.GetOutputVariableName(nameof(Output.PerInstanceMetrics)); + combineOutputMap.Add(nameof(Output.PerInstanceMetrics), combineInstanceMetric.VarName); + var combineConfusionMatrix = new Var(); + combineConfusionMatrix.VarName = node.GetOutputVariableName(nameof(Output.ConfusionMatrix)); + combineOutputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), combineConfusionMatrix.VarName); + + subGraphNodes.AddRange(EntryPointNode.ValidateNodes(env, node.Context, exp.GetNodes(), node.Catalog)); + subGraphNodes.Add(EntryPointNode.Create(env, "Models.CrossValidationResultsCombiner", combineArgs, node.Catalog, node.Context, combineInputBindingMap, combineInputMap, combineOutputMap)); return new CommonOutputs.MacroOutput() { Nodes = subGraphNodes }; } + + [TlcModule.EntryPoint(Desc = "Combine the metric data views returned from cross validation.", Name = "Models.CrossValidationResultsCombiner")] + public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetricsInput input) + { + var eval = GetEvaluator(env, input.Kind); + var perInst = EvaluateUtils.CombinePerInstanceDataViews(env, eval, true, true, input.PerInstanceMetrics.Select( + idv => RoleMappedData.Create(idv, RoleMappedSchema.CreatePair(RoleMappedSchema.ColumnRole.Label, input.LabelColumn))).ToArray(), + out var variableSizeVectorColumnNames); + + env.Assert(Utils.Size(perInst) == 1); + + var overall = eval.GetOverallResults(input.OverallMetrics); + overall = EvaluateUtils.CombineFoldMetricsDataViews(env, overall, input.OverallMetrics.Length); + + IDataView conf = null; + if (Utils.Size(input.ConfusionMatrix) > 0) + { + EvaluateUtils.ReconcileSlotNames(env, input.ConfusionMatrix, MetricKinds.ColumnNames.Count, NumberType.R8); + conf = AppendRowsDataView.Create(env, input.ConfusionMatrix[0].Schema, input.ConfusionMatrix); + } + + return new CombinedOutput() { PerInstanceMetrics = perInst[0], OverallMetrics = overall, ConfusionMatrix = conf }; + } + + private static IMamlEvaluator GetEvaluator(IHostEnvironment env, MacroUtils.TrainerKinds kind) + { + switch (kind) + { + case MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer: + return new BinaryClassifierMamlEvaluator(env, new BinaryClassifierMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiClassClassifierTrainer: + return new MultiClassMamlEvaluator(env, new MultiClassMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRegressorTrainer: + return new RegressionMamlEvaluator(env, new RegressionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureRankerTrainer: + return new RankerMamlEvaluator(env, new RankerMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureAnomalyDetectorTrainer: + return new AnomalyDetectionMamlEvaluator(env, new AnomalyDetectionMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureClusteringTrainer: + return new ClusteringMamlEvaluator(env, new ClusteringMamlEvaluator.Arguments()); + case MacroUtils.TrainerKinds.SignatureMultiOutputRegressorTrainer: + return new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()); + default: + throw env.Except($"Trainer kind {kind} does not have an evaluator"); + } + } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index ee4f56c260..09b63a1a31 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -269,11 +269,11 @@ public void TestCrossValidationBinaryMacro() } } - [Fact(Skip = "Missing data set. See https://github.com/dotnet/machinelearning/issues/3")] + [Fact] public void TestCrossValidationMacro() { - var dataPath = GetDataPath(@"housing.txt"); - using (var env = new TlcEnvironment()) + var dataPath = GetDataPath(@"Train-Tiny-28x28.txt"); + using (var env = new TlcEnvironment(42)) { var subGraph = env.CreateExperiment(); @@ -312,19 +312,48 @@ public void TestCrossValidationMacro() experiment.Compile(); experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); experiment.Run(); - var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); var schema = data.Schema; var b = schema.TryGetColumnIndex("L1(avg)", out int metricCol); Assert.True(b); - using (var cursor = data.GetRowCursor(col => col == metricCol)) + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) { var getter = cursor.GetGetter(metricCol); + var foldGetter = cursor.GetGetter(foldCol); + DvText fold = default; + + // Get the verage. b = cursor.MoveNext(); Assert.True(b); + double avg = 0; + getter(ref avg); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Average")); + + // Get the standard deviation. + b = cursor.MoveNext(); + Assert.True(b); + double stdev = 0; + getter(ref stdev); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Standard Deviation")); + Assert.Equal(0.096, stdev, 3); + + double sum = 0; double val = 0; - getter(ref val); - Assert.Equal(3.32, val, 1); + for (int f = 0; f < 2; f++) + { + b = cursor.MoveNext(); + Assert.True(b); + getter(ref val); + foldGetter(ref fold); + sum += val; + Assert.True(fold.EqualsStr("Fold " + f)); + } + Assert.Equal(avg, sum / 2); b = cursor.MoveNext(); Assert.False(b); } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index e8be6c0370..8afb8c97a6 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -2182,9 +2182,9 @@ public void EntryPointChainedCrossValMacros() model = runner.GetOutput("model2"); Assert.NotNull(model[0]); - var metrics = runner.GetOutput("OverallMetrics"); - Assert.NotNull(metrics[0]); - using (var cursor = metrics[0].GetRowCursor(col => true)) + var metrics = runner.GetOutput("OverallMetrics"); + Assert.NotNull(metrics); + using (var cursor = metrics.GetRowCursor(col => true)) { Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); var aucGetter = cursor.GetGetter(aucCol); @@ -2194,9 +2194,9 @@ public void EntryPointChainedCrossValMacros() Assert.True(auc > 0.99); } - metrics = runner.GetOutput("OverallMetrics2"); - Assert.NotNull(metrics[0]); - using (var cursor = metrics[0].GetRowCursor(col => true)) + metrics = runner.GetOutput("OverallMetrics2"); + Assert.NotNull(metrics); + using (var cursor = metrics.GetRowCursor(col => true)) { Assert.True(cursor.Schema.TryGetColumnIndex("AUC", out int aucCol)); var aucGetter = cursor.GetGetter(aucCol); From 3543ddb154219ff112061c23a1aeea5e537ba25a Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 22 May 2018 14:59:40 -0700 Subject: [PATCH 05/18] Add unit test for the multi-class case --- .../EntryPoints/CrossValidationMacro.cs | 18 ++- .../UnitTests/TestCSharpApi.cs | 126 ++++++++++++++++++ 2 files changed, 143 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index f639ebdc58..700137f2b5 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -343,7 +343,23 @@ public static CombinedOutput CombineMetrics(IHostEnvironment env, CombineMetrics if (Utils.Size(input.ConfusionMatrix) > 0) { EvaluateUtils.ReconcileSlotNames(env, input.ConfusionMatrix, MetricKinds.ColumnNames.Count, NumberType.R8); - conf = AppendRowsDataView.Create(env, input.ConfusionMatrix[0].Schema, input.ConfusionMatrix); + + for (int i = 0; i < input.ConfusionMatrix.Length; i++) + { + var idv = input.ConfusionMatrix[i]; + // Find the old Count column and drop it. + for (int col = 0; col < idv.Schema.ColumnCount; col++) + { + if (idv.Schema.IsHidden(col) && + idv.Schema.GetColumnName(col).Equals(MetricKinds.ColumnNames.Count)) + { + input.ConfusionMatrix[i] = new ChooseColumnsByIndexTransform(env, + new ChooseColumnsByIndexTransform.Arguments() { Drop = true, Index = new[] { col } }, idv); + break; + } + } + } + conf = EvaluateUtils.CombineOverallMetrics(env, input.ConfusionMatrix); } return new CombinedOutput() { PerInstanceMetrics = perInst[0], OverallMetrics = overall, ConfusionMatrix = conf }; diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 09b63a1a31..e2a6b17cfc 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -9,6 +9,7 @@ using Microsoft.ML.TestFramework; using Xunit; using Xunit.Abstractions; +using System.Linq; /*using Categorical = Microsoft.ML.Transforms; using Commands = Microsoft.ML.Transforms; using Evaluate = Microsoft.ML; @@ -359,5 +360,130 @@ public void TestCrossValidationMacro() } } } + + [Fact] + public void TestCrossValidationMacroWithMultiClass() + { + var dataPath = GetDataPath(@"Train-Tiny-28x28.txt"); + using (var env = new TlcEnvironment(42)) + { + var subGraph = env.CreateExperiment(); + + var nop = new ML.Transforms.NoOperation(); + var nopOutput = subGraph.Add(nop); + + var learnerInput = new ML.Trainers.StochasticDualCoordinateAscentClassifier + { + TrainingData = nopOutput.OutputData, + NumThreads = 1 + }; + var learnerOutput = subGraph.Add(learnerInput); + + var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner + { + TransformModels = new ArrayVar(nopOutput.Model), + PredictorModel = learnerOutput.PredictorModel + }; + var modelCombineOutput = subGraph.Add(modelCombine); + + var experiment = env.CreateExperiment(); + var importInput = new ML.Data.TextLoader(); + var importOutput = experiment.Add(importInput); + + var crossValidate = new ML.Models.CrossValidator + { + Data = importOutput.Data, + Nodes = subGraph, + Kind = ML.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer, + TransformModel = null + }; + crossValidate.Inputs.Data = nop.Data; + crossValidate.Outputs.Model = modelCombineOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("Accuracy(micro-avg)", out int metricCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out int foldCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol || col == foldCol)) + { + var getter = cursor.GetGetter(metricCol); + var foldGetter = cursor.GetGetter(foldCol); + DvText fold = default; + + // Get the verage. + b = cursor.MoveNext(); + Assert.True(b); + double avg = 0; + getter(ref avg); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Average")); + + // Get the standard deviation. + b = cursor.MoveNext(); + Assert.True(b); + double stdev = 0; + getter(ref stdev); + foldGetter(ref fold); + Assert.True(fold.EqualsStr("Standard Deviation")); + Assert.Equal(0.025, stdev, 3); + + double sum = 0; + double val = 0; + for (int f = 0; f < 2; f++) + { + b = cursor.MoveNext(); + Assert.True(b); + getter(ref val); + foldGetter(ref fold); + sum += val; + Assert.True(fold.EqualsStr("Fold " + f)); + } + Assert.Equal(avg, sum / 2); + b = cursor.MoveNext(); + Assert.False(b); + } + + var confusion = experiment.GetOutput(crossValidateOutput.ConfusionMatrix); + schema = confusion.Schema; + b = schema.TryGetColumnIndex("Count", out int countCol); + Assert.True(b); + b = schema.TryGetColumnIndex("Fold Index", out foldCol); + Assert.True(b); + var type = schema.GetMetadataTypeOrNull(MetadataUtils.Kinds.SlotNames, countCol); + Assert.True(type != null && type.ItemType.IsText && type.VectorSize == 10); + var slotNames = default(VBuffer); + schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countCol, ref slotNames); + Assert.True(slotNames.Values.Select((s, i) => s.EqualsStr(i.ToString())).All(x => x)); + using (var curs = confusion.GetRowCursor(col => true)) + { + var countGetter = curs.GetGetter>(countCol); + var foldGetter = curs.GetGetter(foldCol); + var confCount = default(VBuffer); + var foldIndex = default(DvText); + int rowCount = 0; + var foldCur = "Fold 0"; + while (curs.MoveNext()) + { + countGetter(ref confCount); + foldGetter(ref foldIndex); + rowCount++; + Assert.True(foldIndex.EqualsStr(foldCur)); + if (rowCount == 10) + { + rowCount = 0; + foldCur = "Fold 1"; + } + } + Assert.Equal(0, rowCount); + } + } + } } } From 5676a27f797a3a5892b3f318b4104b6878536ca7 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 22 May 2018 15:32:52 -0700 Subject: [PATCH 06/18] Update CSharpApi.cs after merge --- src/Microsoft.ML/CSharpApi.cs | 2 +- test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index b40dfa8afe..e09988782b 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2135,7 +2135,7 @@ public sealed class Output namespace Models { - public sealed class CrossValidationMacroSubGraphInput + public sealed partial class CrossValidationMacroSubGraphInput { /// /// The data to be used for training diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 59c6d8f6c6..c6a94dd679 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -1,5 +1,6 @@  + From e9309040c9872d4daec284bc6f68a10db13a7493 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Tue, 22 May 2018 15:52:36 -0700 Subject: [PATCH 07/18] train-test big fix. --- src/Microsoft.ML/CSharpApi.cs | 10 --- src/Microsoft.ML/Models/CrossValidator.cs | 64 +++++++++++++++---- .../EntryPoints/CrossValidationMacro.cs | 24 +++---- .../Runtime/EntryPoints/TrainTestMacro.cs | 37 ++++++++--- .../Scenarios/SentimentPredictionTests.cs | 5 +- 5 files changed, 90 insertions(+), 50 deletions(-) diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 6865861de7..318d711194 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2118,11 +2118,6 @@ public sealed partial class CrossValidationMacroSubGraphOutput /// public Var TransformModel { get; set; } = new Var(); - /// - /// Transform data - /// - public Var TransformData { get; set; } = new Var(); - /// /// Indicates to use transform model instead of predictor model. /// @@ -3309,11 +3304,6 @@ public sealed partial class TrainTestMacroSubGraphOutput /// public Var TransformModel { get; set; } = new Var(); - /// - /// Transform data - /// - public Var TransformData { get; set; } = new Var(); - /// /// Indicates to use transform model instead of predictor model. /// diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs index 6841c04808..c6aa52fb32 100644 --- a/src/Microsoft.ML/Models/CrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Models { public sealed partial class CrossValidator { - public PredictionModel CrossValidate(LearningPipeline pipeline) + public CrossValidationOutput[] CrossValidate(LearningPipeline pipeline) where TInput : class where TOutput : class, new() { @@ -24,7 +24,6 @@ public PredictionModel CrossValidate(LearningP Var lastTransformModel = null; Var firstPipelineDataStep = null; Var firstModel = null; - Var lastData = null; ILearningPipelineItem firstTransform = null; foreach (ILearningPipelineItem currentItem in pipeline) { @@ -73,7 +72,6 @@ public PredictionModel CrossValidate(LearningP var scorerOutput = subGraph.Add(scorer); lastTransformModel = scorerOutput.ScoringTransform; - lastData = scorerOutput.ScoredData; step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform); transformModels.Clear(); } @@ -91,7 +89,6 @@ public PredictionModel CrossValidate(LearningP var modelOutput = subGraph.Add(modelInput); lastTransformModel = modelOutput.OutputModel; - lastData = modelOutput.Data; } var experiment = environment.CreateExperiment(); @@ -101,9 +98,7 @@ public PredictionModel CrossValidate(LearningP Nodes = subGraph; TransformModel = null; Inputs.Data = firstTransform.GetInputData(); - Outputs.Model = null; Outputs.TransformModel = lastTransformModel; - Outputs.TransformData = lastData; Outputs.UseTransformModel = true; var crossValidateOutput = experiment.Add(this); experiment.Compile(); @@ -113,19 +108,62 @@ public PredictionModel CrossValidate(LearningP } experiment.Run(); - ITransformModel model = experiment.GetOutput(crossValidateOutput.TransformModel[0]); - BatchPredictionEngine predictor; - using (var memoryStream = new MemoryStream()) + + CrossValidationOutput[] cvo = new CrossValidationOutput[NumFolds]; + + for (int Index = 0; Index < NumFolds; Index++) { - model.Save(environment, memoryStream); + cvo[Index] = new CrossValidationOutput(); + + if (Kind == MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer) + { + cvo[Index].BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( + environment, + experiment.GetOutput(crossValidateOutput.OverallMetrics[Index]), + experiment.GetOutput(crossValidateOutput.ConfusionMatrix[Index])); + } + else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer) + { + cvo[Index].ClassificationMetrics = ClassificationMetrics.FromMetrics( + environment, + experiment.GetOutput(crossValidateOutput.OverallMetrics[Index]), + experiment.GetOutput(crossValidateOutput.ConfusionMatrix[Index])); + } + else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer) + { + cvo[Index].RegressionMetrics = RegressionMetrics.FromOverallMetrics( + environment, + experiment.GetOutput(crossValidateOutput.OverallMetrics[Index])); + } - memoryStream.Position = 0; + ITransformModel model = experiment.GetOutput(crossValidateOutput.TransformModel[Index]); + BatchPredictionEngine predictor; + using (var memoryStream = new MemoryStream()) + { + model.Save(environment, memoryStream); + + memoryStream.Position = 0; - predictor = environment.CreateBatchPredictionEngine(memoryStream); + predictor = environment.CreateBatchPredictionEngine(memoryStream); - return new PredictionModel(predictor, memoryStream); + cvo[Index].PredictorModel = new PredictionModel(predictor, memoryStream); + } } + + return cvo; } } } + + public class CrossValidationOutput + where TInput : class + where TOutput : class, new() + { + public BinaryClassificationMetrics BinaryClassificationMetrics; + public ClassificationMetrics ClassificationMetrics; + public RegressionMetrics RegressionMetrics; + public PredictionModel PredictorModel; + + //REVIEW: Add warnings and per instance results. + } } diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index db086ba014..af6bd3756b 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -29,16 +29,13 @@ public sealed class SubGraphInput public sealed class SubGraphOutput { - [Argument(ArgumentType.Required, HelpText = "The model", SortOrder = 1)] + [Argument(ArgumentType.AtMostOnce, HelpText = "The model", SortOrder = 1)] public Var Model; - - [Argument(ArgumentType.Required, HelpText = "The transform model", SortOrder = 2)] + + [Argument(ArgumentType.AtMostOnce, HelpText = "The transform model", SortOrder = 2)] public Var TransformModel; - [Argument(ArgumentType.Required, HelpText = "Transform data", SortOrder = 3)] - public Var TransformData; - - [Argument(ArgumentType.Required, HelpText = "Indicates to use transform model instead of predictor model.", SortOrder = 4)] + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to use transform model instead of predictor model.", SortOrder = 3)] public bool UseTransformModel = false; } @@ -169,26 +166,25 @@ public static CommonOutputs.MacroOutput CrossValidate( VarName = mapping[input.Inputs.Data.VarName] }; - if (input.Outputs.Model != null) + if (input.Outputs.Model != null && mapping.ContainsKey(input.Outputs.Model.VarName)) { args.Outputs.Model = new Var { VarName = mapping[input.Outputs.Model.VarName] }; } + else + args.Outputs.Model = null; - if (input.Outputs.TransformModel != null) + if (input.Outputs.TransformModel != null && mapping.ContainsKey(input.Outputs.TransformModel.VarName)) { args.Outputs.TransformModel = new Var { VarName = mapping[input.Outputs.TransformModel.VarName] }; } - - args.Outputs.TransformData = new Var - { - VarName = mapping[input.Outputs.TransformData.VarName] - }; + else + args.Outputs.TransformModel = null; args.Outputs.UseTransformModel = input.Outputs.UseTransformModel; diff --git a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs index 1eb7d08908..61b376e7f9 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs @@ -24,16 +24,13 @@ public sealed class SubGraphInput public sealed class SubGraphOutput { - [Argument(ArgumentType.Required, HelpText = "The model", SortOrder = 1)] + [Argument(ArgumentType.AtMostOnce, HelpText = "The model", SortOrder = 1)] public Var Model; - - [Argument(ArgumentType.Required, HelpText = "Transform model", SortOrder = 2)] + + [Argument(ArgumentType.AtMostOnce, HelpText = "Transform model", SortOrder = 2)] public Var TransformModel; - [Argument(ArgumentType.Required, HelpText = "Transform data", SortOrder = 3)] - public Var TransformData; - - [Argument(ArgumentType.Required, HelpText = "Indicates to use transform model instead of predictor model.", SortOrder = 4)] + [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to use transform model instead of predictor model.", SortOrder = 3)] public bool UseTransformModel = false; } @@ -155,6 +152,7 @@ public static CommonOutputs.MacroOutput TrainTest( var exp = new Experiment(env); DatasetScorer.Output scoreNodeOutput = null; + ML.Models.DatasetTransformer.Output datasetTransformNodeOutput = null; if (input.Outputs.UseTransformModel) { //combine the predictor model with any potential transfrom model passed from the outer graph @@ -172,6 +170,14 @@ public static CommonOutputs.MacroOutput TrainTest( var modelCombineOutput = exp.Add(modelCombine); outputVarName = modelCombineOutput.OutputModel.VarName; } + + var datasetTransformerNode = new Models.DatasetTransformer + { + Data = { VarName = testingVar.ToJson() }, + TransformModel = { VarName = outputVarName } + }; + + datasetTransformNodeOutput = exp.Add(datasetTransformerNode); } else { @@ -215,7 +221,18 @@ public static CommonOutputs.MacroOutput TrainTest( if (input.IncludeTrainingMetrics) { DatasetScorer.Output scoreNodeTrainingOutput = null; - if (!input.Outputs.UseTransformModel) + ML.Models.DatasetTransformer.Output datasetTransformNodeTrainingOutput = null; + if (input.Outputs.UseTransformModel) + { + var datasetTransformerNode = new Models.DatasetTransformer + { + Data = { VarName = testingVar.ToJson() }, + TransformModel = { VarName = outputVarName } + }; + + datasetTransformNodeTrainingOutput = exp.Add(datasetTransformerNode); + } + else { // Add the scoring node for training. var scoreNodeTraining = new DatasetScorer @@ -235,7 +252,7 @@ public static CommonOutputs.MacroOutput TrainTest( var evalInputOutputTraining = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings); var evalNodeTraining = evalInputOutputTraining.Item1; var evalOutputTraining = evalInputOutputTraining.Item2; - evalNodeTraining.Data.VarName = input.Outputs.UseTransformModel ? input.Outputs.TransformData.VarName : + evalNodeTraining.Data.VarName = input.Outputs.UseTransformModel ? datasetTransformNodeTrainingOutput.OutputData.VarName : scoreNodeTrainingOutput.ScoredData.VarName; if (node.OutputMap.TryGetValue(nameof(Output.TrainingWarnings), out outVariableName)) @@ -259,7 +276,7 @@ public static CommonOutputs.MacroOutput TrainTest( var evalInputOutput = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings); var evalNode = evalInputOutput.Item1; var evalOutput = evalInputOutput.Item2; - evalNode.Data.VarName = input.Outputs.UseTransformModel ? input.Outputs.TransformData.VarName : scoreNodeOutput.ScoredData.VarName; + evalNode.Data.VarName = input.Outputs.UseTransformModel ? datasetTransformNodeOutput.OutputData.VarName : scoreNodeOutput.ScoredData.VarName; if (node.OutputMap.TryGetValue(nameof(Output.Warnings), out outVariableName)) evalOutput.Warnings.VarName = outVariableName; diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 80947644e9..a27a85488a 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -50,7 +50,7 @@ public void TrainAndPredictSentimentModelTest() } } }); - + pipeline.Add(new TextFeaturizer("Features", "SentimentText") { KeepDiacritics = false, @@ -67,9 +67,8 @@ public void TrainAndPredictSentimentModelTest() pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); PredictionModel model = pipeline.Train(); - IEnumerable sentiments = new[] - { + { new SentimentData { SentimentText = "Please refrain from adding nonsense to Wikipedia." From ae376a038e60a76d5c96d58633bff1f4f5bc9a53 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Tue, 22 May 2018 15:58:16 -0700 Subject: [PATCH 08/18] Fix unit test --- test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 42998b8cc0..1e7807404e 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -249,7 +249,7 @@ public void TestCrossValidationBinaryMacro() var crossValidateOutput = experiment.Add(crossValidateBinary); experiment.Compile(); - experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + importInput.SetInput(env, experiment); experiment.Run(); var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); @@ -311,7 +311,7 @@ public void TestCrossValidationMacro() var crossValidateOutput = experiment.Add(crossValidate); experiment.Compile(); - experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + importInput.SetInput(env, experiment); experiment.Run(); var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); @@ -387,7 +387,7 @@ public void TestCrossValidationMacroWithMultiClass() var modelCombineOutput = subGraph.Add(modelCombine); var experiment = env.CreateExperiment(); - var importInput = new ML.Data.TextLoader(); + var importInput = new ML.Data.TextLoader(dataPath); var importOutput = experiment.Add(importInput); var crossValidate = new ML.Models.CrossValidator @@ -402,7 +402,7 @@ public void TestCrossValidationMacroWithMultiClass() var crossValidateOutput = experiment.Add(crossValidate); experiment.Compile(); - experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + importInput.SetInput(env, experiment); experiment.Run(); var data = experiment.GetOutput(crossValidateOutput.OverallMetrics); From 0ef9ce16f0e250eac6ffc6470e5722cf1b6fd202 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 23 May 2018 02:06:24 -0700 Subject: [PATCH 09/18] Merge PR# 207. --- .../Common/EntryPoints/core_ep-list.tsv | 1 + .../Common/EntryPoints/core_manifest.json | 88 ++++++++++++++++++- src/Microsoft.ML.Core/Data/ITransformModel.cs | 2 - .../EntryPoints/TransformModel.cs | 5 -- src/Microsoft.ML/CSharpApi.cs | 5 -- .../Models/BinaryClassificationEvaluator.cs | 3 +- .../Models/BinaryClassificationMetrics.cs | 54 +++++++----- .../Models/ClassificationEvaluator.cs | 3 +- .../Models/ClassificationMetrics.cs | 45 ++++++---- src/Microsoft.ML/Models/ConfusionMatrix.cs | 36 +++++--- src/Microsoft.ML/Models/CrossValidator.cs | 41 +++++---- .../Models/RegressionEvaluator.cs | 3 +- src/Microsoft.ML/Models/RegressionMetrics.cs | 32 +++---- .../Runtime/EntryPoints/ModelOperations.cs | 5 +- ...sticDualCoordinateAscentClassifierBench.cs | 2 +- test/Microsoft.ML.Tests/CSharpCodeGen.cs | 3 +- .../HousePriceTrainAndPredictionTests.cs | 3 +- .../Scenarios/IrisPlantClassificationTests.cs | 3 +- ...PlantClassificationWithStringLabelTests.cs | 3 +- .../Scenarios/SentimentPredictionTests.cs | 4 +- 20 files changed, 223 insertions(+), 118 deletions(-) diff --git a/ZBaselines/Common/EntryPoints/core_ep-list.tsv b/ZBaselines/Common/EntryPoints/core_ep-list.tsv index 7fc82434b4..da227b16ff 100644 --- a/ZBaselines/Common/EntryPoints/core_ep-list.tsv +++ b/ZBaselines/Common/EntryPoints/core_ep-list.tsv @@ -3,6 +3,7 @@ Data.DataViewReference Pass dataview from memory to experiment Microsoft.ML.Runt Data.IDataViewArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput Data.PredictorModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData TextLoader Microsoft.ML.Runtime.EntryPoints.ImportTextData+LoaderInput Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output +Data.TransformModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelOutput Models.AnomalyDetectionEvaluator Evaluates an anomaly detection scored dataset. Microsoft.ML.Runtime.Data.Evaluate AnomalyDetection Microsoft.ML.Runtime.Data.AnomalyDetectionMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput Models.BinaryClassificationEvaluator Evaluates a binary classification scored dataset. Microsoft.ML.Runtime.Data.Evaluate Binary Microsoft.ML.Runtime.Data.BinaryClassifierMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+ClassificationEvaluateOutput Models.BinaryCrossValidator Cross validation for binary classification Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro CrossValidateBinary Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Output] diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json index d8b3fdccec..ab2771065c 100644 --- a/ZBaselines/Common/EntryPoints/core_manifest.json +++ b/ZBaselines/Common/EntryPoints/core_manifest.json @@ -469,6 +469,35 @@ "ILearningPipelineLoader" ] }, + { + "Name": "Data.TransformModelArrayConverter", + "Desc": "Create and array variable", + "FriendlyName": null, + "ShortName": null, + "Inputs": [ + { + "Name": "TransformModel", + "Type": { + "Kind": "Array", + "ItemType": "TransformModel" + }, + "Desc": "The models", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + } + ], + "Outputs": [ + { + "Name": "OutputModel", + "Type": { + "Kind": "Array", + "ItemType": "TransformModel" + }, + "Desc": "The model array" + } + ] + }, { "Name": "Models.AnomalyDetectionEvaluator", "Desc": "Evaluates an anomaly detection scored dataset.", @@ -1411,9 +1440,28 @@ "Name": "Model", "Type": "PredictorModel", "Desc": "The model", - "Required": true, + "Required": false, "SortOrder": 1.0, - "IsNullable": false + "IsNullable": false, + "Default": null + }, + { + "Name": "TransformModel", + "Type": "TransformModel", + "Desc": "The transform model", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "UseTransformModel", + "Type": "Bool", + "Desc": "Indicates to use transform model instead of predictor model.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": false } ] }, @@ -1476,6 +1524,14 @@ }, "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel." }, + { + "Name": "TransformModel", + "Type": { + "Kind": "Array", + "ItemType": "TransformModel" + }, + "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel." + }, { "Name": "Warnings", "Type": "DataView", @@ -3002,9 +3058,28 @@ "Name": "Model", "Type": "PredictorModel", "Desc": "The model", - "Required": true, + "Required": false, "SortOrder": 1.0, - "IsNullable": false + "IsNullable": false, + "Default": null + }, + { + "Name": "TransformModel", + "Type": "TransformModel", + "Desc": "Transform model", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "UseTransformModel", + "Type": "Bool", + "Desc": "Indicates to use transform model instead of predictor model.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": false } ] }, @@ -3058,6 +3133,11 @@ "Type": "PredictorModel", "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel." }, + { + "Name": "TransformModel", + "Type": "TransformModel", + "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel." + }, { "Name": "Warnings", "Type": "DataView", diff --git a/src/Microsoft.ML.Core/Data/ITransformModel.cs b/src/Microsoft.ML.Core/Data/ITransformModel.cs index ec249ce768..ccb65d43ab 100644 --- a/src/Microsoft.ML.Core/Data/ITransformModel.cs +++ b/src/Microsoft.ML.Core/Data/ITransformModel.cs @@ -25,8 +25,6 @@ public interface ITransformModel /// ISchema InputSchema { get; } - IDataView Data { get; } - /// /// Apply the transform(s) in the model to the given input data. /// diff --git a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs index acbce34b24..b840529e77 100644 --- a/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs +++ b/src/Microsoft.ML.Data/EntryPoints/TransformModel.cs @@ -44,11 +44,6 @@ public ISchema InputSchema get { return _schemaRoot; } } - public IDataView Data - { - get { return _chain; } - } - /// /// Create a TransformModel containing the transforms from "result" back to "input". /// diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 8ef1ae4206..10c078a891 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -11582,11 +11582,6 @@ public sealed class Output /// public Var OutputModel { get; set; } = new Var(); - /// - /// Data - /// - public Var Data { get; set; } = new Var(); - } } } diff --git a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs index e0a4eae826..b916a9ab57 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Transforms; +using System.Collections.Generic; namespace Microsoft.ML.Models { @@ -23,7 +24,7 @@ public sealed partial class BinaryClassificationEvaluator /// /// A BinaryClassificationMetrics instance that describes how well the model performed against the test data. /// - public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) + public List Evaluate(PredictionModel model, ILearningPipelineLoader testData) { using (var environment = new TlcEnvironment()) { diff --git a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs index aa3a94f3a9..fe95eac312 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using System; +using System.Collections.Generic; namespace Microsoft.ML.Models { @@ -18,7 +19,7 @@ private BinaryClassificationMetrics() { } - internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix) + internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix) { Contracts.AssertValue(env); env.AssertValue(overallMetrics); @@ -31,28 +32,37 @@ internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, ID throw env.Except("The overall RegressionMetrics didn't have any rows."); } - SerializationClass metrics = enumerator.Current; - - if (enumerator.MoveNext()) - { - throw env.Except("The overall RegressionMetrics contained more than 1 row."); - } - - return new BinaryClassificationMetrics() + List metrics = new List(); + var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator(); + do { - Auc = metrics.Auc, - Accuracy = metrics.Accuracy, - PositivePrecision = metrics.PositivePrecision, - PositiveRecall = metrics.PositiveRecall, - NegativePrecision = metrics.NegativePrecision, - NegativeRecall = metrics.NegativeRecall, - LogLoss = metrics.LogLoss, - LogLossReduction = metrics.LogLossReduction, - Entropy = metrics.Entropy, - F1Score = metrics.F1Score, - Auprc = metrics.Auprc, - ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix), - }; + SerializationClass metric = enumerator.Current; + + if (!confusionMatrices.MoveNext()) + { + throw env.Except("Confusion matrices didn't have enough matrices."); + } + + metrics.Add( + new BinaryClassificationMetrics() + { + Auc = metric.Auc, + Accuracy = metric.Accuracy, + PositivePrecision = metric.PositivePrecision, + PositiveRecall = metric.PositiveRecall, + NegativePrecision = metric.NegativePrecision, + NegativeRecall = metric.NegativeRecall, + LogLoss = metric.LogLoss, + LogLossReduction = metric.LogLossReduction, + Entropy = metric.Entropy, + F1Score = metric.F1Score, + Auprc = metric.Auprc, + ConfusionMatrix = confusionMatrices.Current, + }); + + } while (enumerator.MoveNext()); + + return metrics; } /// diff --git a/src/Microsoft.ML/Models/ClassificationEvaluator.cs b/src/Microsoft.ML/Models/ClassificationEvaluator.cs index c8bec8642f..33799c0c78 100644 --- a/src/Microsoft.ML/Models/ClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/ClassificationEvaluator.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Transforms; +using System.Collections.Generic; namespace Microsoft.ML.Models { @@ -23,7 +24,7 @@ public sealed partial class ClassificationEvaluator /// /// A ClassificationMetrics instance that describes how well the model performed against the test data. /// - public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) + public List Evaluate(PredictionModel model, ILearningPipelineLoader testData) { using (var environment = new TlcEnvironment()) { diff --git a/src/Microsoft.ML/Models/ClassificationMetrics.cs b/src/Microsoft.ML/Models/ClassificationMetrics.cs index 81c0f91d7b..0fbbba602e 100644 --- a/src/Microsoft.ML/Models/ClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/ClassificationMetrics.cs @@ -5,6 +5,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; +using System.Collections.Generic; namespace Microsoft.ML.Models { @@ -17,7 +18,7 @@ private ClassificationMetrics() { } - internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix) + internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix) { Contracts.AssertValue(env); env.AssertValue(overallMetrics); @@ -29,24 +30,32 @@ internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataVie { throw env.Except("The overall RegressionMetrics didn't have any rows."); } - - SerializationClass metrics = enumerator.Current; - - if (enumerator.MoveNext()) - { - throw env.Except("The overall RegressionMetrics contained more than 1 row."); - } - - return new ClassificationMetrics() + + List metrics = new List(); + var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator(); + do { - AccuracyMicro = metrics.AccuracyMicro, - AccuracyMacro = metrics.AccuracyMacro, - LogLoss = metrics.LogLoss, - LogLossReduction = metrics.LogLossReduction, - TopKAccuracy = metrics.TopKAccuracy, - PerClassLogLoss = metrics.PerClassLogLoss, - ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix) - }; + if (!confusionMatrices.MoveNext()) + { + throw env.Except("Confusion matrices didn't have enough matrices."); + } + + SerializationClass metric = enumerator.Current; + metrics.Add( + new ClassificationMetrics() + { + AccuracyMicro = metric.AccuracyMicro, + AccuracyMacro = metric.AccuracyMacro, + LogLoss = metric.LogLoss, + LogLossReduction = metric.LogLossReduction, + TopKAccuracy = metric.TopKAccuracy, + PerClassLogLoss = metric.PerClassLogLoss, + ConfusionMatrix = confusionMatrices.Current + }); + + } while (enumerator.MoveNext()); + + return metrics; } /// diff --git a/src/Microsoft.ML/Models/ConfusionMatrix.cs b/src/Microsoft.ML/Models/ConfusionMatrix.cs index 2040fc8331..72aa5061dc 100644 --- a/src/Microsoft.ML/Models/ConfusionMatrix.cs +++ b/src/Microsoft.ML/Models/ConfusionMatrix.cs @@ -41,7 +41,7 @@ private ConfusionMatrix(double[,] elements, string[] classNames) }); } - internal static ConfusionMatrix Create(IHostEnvironment env, IDataView confusionMatrix) + internal static List Create(IHostEnvironment env, IDataView confusionMatrix) { Contracts.AssertValue(env); env.AssertValue(confusionMatrix); @@ -51,18 +51,28 @@ internal static ConfusionMatrix Create(IHostEnvironment env, IDataView confusion env.Except($"ConfusionMatrix data view did not contain a {nameof(MetricKinds.ColumnNames.Count)} column."); } + IRowCursor cursor = confusionMatrix.GetRowCursor(col => col == countColumn); + var slots = default(VBuffer); + confusionMatrix.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countColumn, ref slots); + string[] classNames = new string[slots.Count]; + for (int i = 0; i < slots.Count; i++) + { + classNames[i] = slots.Values[i].ToString(); + } + ColumnType type = confusionMatrix.Schema.GetColumnType(countColumn); env.Assert(type.IsVector); - - double[,] elements = new double[type.VectorSize, type.VectorSize]; - - IRowCursor cursor = confusionMatrix.GetRowCursor(col => col == countColumn); ValueGetter> countGetter = cursor.GetGetter>(countColumn); VBuffer countValues = default; - + List confusionMatrices = new List(); + int valuesRowIndex = 0; + double[,] elements = null; while (cursor.MoveNext()) { + if(valuesRowIndex == 0) + elements = new double[type.VectorSize, type.VectorSize]; + countGetter(ref countValues); for (int i = 0; i < countValues.Length; i++) { @@ -70,17 +80,15 @@ internal static ConfusionMatrix Create(IHostEnvironment env, IDataView confusion } valuesRowIndex++; - } - var slots = default(VBuffer); - confusionMatrix.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countColumn, ref slots); - string[] classNames = new string[slots.Count]; - for (int i = 0; i < slots.Count; i++) - { - classNames[i] = slots.Values[i].ToString(); + if(valuesRowIndex == type.VectorSize) + { + valuesRowIndex = 0; + confusionMatrices.Add(new ConfusionMatrix(elements, classNames)); + } } - return new ConfusionMatrix(elements, classNames); + return confusionMatrices; } /// diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs index c6aa52fb32..2ef99993cb 100644 --- a/src/Microsoft.ML/Models/CrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -11,7 +11,7 @@ namespace Microsoft.ML.Models { public sealed partial class CrossValidator { - public CrossValidationOutput[] CrossValidate(LearningPipeline pipeline) + public CrossValidationOutput CrossValidate(LearningPipeline pipeline) where TInput : class where TOutput : class, new() { @@ -109,31 +109,35 @@ public CrossValidationOutput[] CrossValidate(L experiment.Run(); - CrossValidationOutput[] cvo = new CrossValidationOutput[NumFolds]; + CrossValidationOutput cvo = new CrossValidationOutput(); + cvo.PredictorModels = new PredictionModel[NumFolds]; for (int Index = 0; Index < NumFolds; Index++) { - cvo[Index] = new CrossValidationOutput(); if (Kind == MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer) { - cvo[Index].BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( + cvo.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( environment, - experiment.GetOutput(crossValidateOutput.OverallMetrics[Index]), - experiment.GetOutput(crossValidateOutput.ConfusionMatrix[Index])); + experiment.GetOutput(crossValidateOutput.OverallMetrics), + experiment.GetOutput(crossValidateOutput.ConfusionMatrix)); } else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer) { - cvo[Index].ClassificationMetrics = ClassificationMetrics.FromMetrics( + cvo.ClassificationMetrics = ClassificationMetrics.FromMetrics( environment, - experiment.GetOutput(crossValidateOutput.OverallMetrics[Index]), - experiment.GetOutput(crossValidateOutput.ConfusionMatrix[Index])); + experiment.GetOutput(crossValidateOutput.OverallMetrics), + experiment.GetOutput(crossValidateOutput.ConfusionMatrix)); } else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer) { - cvo[Index].RegressionMetrics = RegressionMetrics.FromOverallMetrics( + cvo.RegressionMetrics = RegressionMetrics.FromOverallMetrics( environment, - experiment.GetOutput(crossValidateOutput.OverallMetrics[Index])); + experiment.GetOutput(crossValidateOutput.OverallMetrics)); + } + else + { + //Implement metrics for ranking, clustering and anomaly detection. } ITransformModel model = experiment.GetOutput(crossValidateOutput.TransformModel[Index]); @@ -146,7 +150,7 @@ public CrossValidationOutput[] CrossValidate(L predictor = environment.CreateBatchPredictionEngine(memoryStream); - cvo[Index].PredictorModel = new PredictionModel(predictor, memoryStream); + cvo.PredictorModels[Index] = new PredictionModel(predictor, memoryStream); } } @@ -159,11 +163,12 @@ public class CrossValidationOutput where TInput : class where TOutput : class, new() { - public BinaryClassificationMetrics BinaryClassificationMetrics; - public ClassificationMetrics ClassificationMetrics; - public RegressionMetrics RegressionMetrics; - public PredictionModel PredictorModel; - - //REVIEW: Add warnings and per instance results. + public List BinaryClassificationMetrics; + public List ClassificationMetrics; + public List RegressionMetrics; + public PredictionModel[] PredictorModels; + + //REVIEW: Add warnings and per instance results and implement + //metrics for ranking, clustering and anomaly detection. } } diff --git a/src/Microsoft.ML/Models/RegressionEvaluator.cs b/src/Microsoft.ML/Models/RegressionEvaluator.cs index 8c2daa53f0..c55f4f3335 100644 --- a/src/Microsoft.ML/Models/RegressionEvaluator.cs +++ b/src/Microsoft.ML/Models/RegressionEvaluator.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Transforms; +using System.Collections.Generic; namespace Microsoft.ML.Models { @@ -23,7 +24,7 @@ public sealed partial class RegressionEvaluator /// /// A RegressionMetrics instance that describes how well the model performed against the test data. /// - public RegressionMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) + public List Evaluate(PredictionModel model, ILearningPipelineLoader testData) { using (var environment = new TlcEnvironment()) { diff --git a/src/Microsoft.ML/Models/RegressionMetrics.cs b/src/Microsoft.ML/Models/RegressionMetrics.cs index f5a5122242..43eabc3bec 100644 --- a/src/Microsoft.ML/Models/RegressionMetrics.cs +++ b/src/Microsoft.ML/Models/RegressionMetrics.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using System; +using System.Collections.Generic; namespace Microsoft.ML.Models { @@ -18,7 +19,7 @@ private RegressionMetrics() { } - internal static RegressionMetrics FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics) + internal static List FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics) { Contracts.AssertValue(env); env.AssertValue(overallMetrics); @@ -30,21 +31,22 @@ internal static RegressionMetrics FromOverallMetrics(IHostEnvironment env, IData throw env.Except("The overall RegressionMetrics didn't have any rows."); } - SerializationClass metrics = enumerator.Current; - - if (enumerator.MoveNext()) - { - throw env.Except("The overall RegressionMetrics contained more than 1 row."); - } - - return new RegressionMetrics() + List metrics = new List(); + do { - L1 = metrics.L1, - L2 = metrics.L2, - Rms = metrics.Rms, - LossFn = metrics.LossFn, - RSquared = metrics.RSquared, - }; + SerializationClass metric = enumerator.Current; + metrics.Add(new RegressionMetrics() + { + L1 = metric.L1, + L2 = metric.L2, + Rms = metric.Rms, + LossFn = metric.LossFn, + RSquared = metric.RSquared, + }); + + } while (enumerator.MoveNext()); + + return metrics; } /// diff --git a/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs b/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs index 9f7cbb727b..fa34cfd7ac 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/ModelOperations.cs @@ -23,9 +23,6 @@ public sealed class CombineTransformModelsOutput { [TlcModule.Output(Desc = "Combined model", SortOrder = 1)] public ITransformModel OutputModel; - - [TlcModule.Output(Desc = "Data", SortOrder = 2)] - public IDataView Data; } public sealed class PredictorModelInput @@ -92,7 +89,7 @@ public static CombineTransformModelsOutput CombineTransformModels(IHostEnvironme for (int i = input.Models.Length - 2; i >= 0; i--) model = model.Apply(env, input.Models[i]); - return new CombineTransformModelsOutput { OutputModel = model, Data = model.Data }; + return new CombineTransformModelsOutput { OutputModel = model }; } [TlcModule.EntryPoint(Name = "Transforms.ManyHeterogeneousModelCombiner", Desc = "Combines a sequence of TransformModels and a PredictorModel into a single PredictorModel.")] diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index adfa42e50d..e31f6311cd 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -53,7 +53,7 @@ public void Setup() var testData = new TextLoader(s_dataPath).CreateFrom(useHeader: true); var evaluator = new ClassificationEvaluator(); - s_metrics = evaluator.Evaluate(s_trainedModel, testData); + s_metrics = evaluator.Evaluate(s_trainedModel, testData).FirstOrDefault(); s_batches = new IrisData[s_batchSizes.Length][]; for (int i = 0; i < s_batches.Length; i++) diff --git a/test/Microsoft.ML.Tests/CSharpCodeGen.cs b/test/Microsoft.ML.Tests/CSharpCodeGen.cs index 316d7eab55..c647110702 100644 --- a/test/Microsoft.ML.Tests/CSharpCodeGen.cs +++ b/test/Microsoft.ML.Tests/CSharpCodeGen.cs @@ -15,8 +15,7 @@ public CSharpCodeGen(ITestOutputHelper output) : base(output) { } - //[Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")] - [Fact] + [Fact(Skip = "Temporary solution(Windows ONLY) to regenerate codegenerated CSharpAPI.cs")] public void GenerateCSharpAPI() { var cSharpAPIPath = Path.Combine(RootDir, @"src\\Microsoft.ML\\CSharpApi.cs"); diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs index 31fc4fdd6d..85955b1c06 100644 --- a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs @@ -8,6 +8,7 @@ using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; +using System.Linq; using Xunit; using Xunit.Abstractions; @@ -65,7 +66,7 @@ public void TrainAndPredictHousePriceModelTest() var testData = new TextLoader(testDataPath).CreateFrom(useHeader: true, separator: ','); var evaluator = new RegressionEvaluator(); - RegressionMetrics metrics = evaluator.Evaluate(model, testData); + RegressionMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault(); Assert.InRange(metrics.L1, 85_000, 89_000); Assert.InRange(metrics.L2, 17_000_000_000, 19_000_000_000); Assert.InRange(metrics.Rms, 130_500, 135_000); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index 5dcbf3a588..cb6cad9548 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; +using System.Linq; using Xunit; namespace Microsoft.ML.Scenarios @@ -71,7 +72,7 @@ public void TrainAndPredictIrisModelTest() var evaluator = new ClassificationEvaluator(); evaluator.OutputTopKAcc = 3; - ClassificationMetrics metrics = evaluator.Evaluate(model, testData); + ClassificationMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault(); Assert.Equal(.98, metrics.AccuracyMacro); Assert.Equal(.98, metrics.AccuracyMicro, 2); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index ebddc33b03..10d23b62b7 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; +using System.Linq; using Xunit; namespace Microsoft.ML.Scenarios @@ -74,7 +75,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest() var evaluator = new ClassificationEvaluator(); evaluator.OutputTopKAcc = 3; - ClassificationMetrics metrics = evaluator.Evaluate(model, testData); + ClassificationMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault(); ; Assert.Equal(.98, metrics.AccuracyMacro); Assert.Equal(.98, metrics.AccuracyMicro, 2); diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index a27a85488a..165d89f372 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -65,7 +65,7 @@ public void TrainAndPredictSentimentModelTest() pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); - + //var cv = new CrossValidator().CrossValidate(pipeline); PredictionModel model = pipeline.Train(); IEnumerable sentiments = new[] { @@ -111,7 +111,7 @@ public void TrainAndPredictSentimentModelTest() } }; var evaluator = new BinaryClassificationEvaluator(); - BinaryClassificationMetrics metrics = evaluator.Evaluate(model, testData); + BinaryClassificationMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault(); Assert.Equal(.5556, metrics.Accuracy, 4); Assert.Equal(.8, metrics.Auc, 1); From 001fa536c4d13dee92007b1d75f4d208906d9083 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Wed, 23 May 2018 23:18:44 -0700 Subject: [PATCH 10/18] Unit-test. --- .../Models/BinaryClassificationMetrics.cs | 10 ++- .../Models/ClassificationMetrics.cs | 9 ++- src/Microsoft.ML/Models/CrossValidator.cs | 10 ++- src/Microsoft.ML/Models/RegressionMetrics.cs | 9 ++- .../EntryPoints/CrossValidationMacro.cs | 8 +- .../Scenarios/SentimentPredictionTests.cs | 77 ++++++++++++++++++- 6 files changed, 106 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs index fe95eac312..a33475bd2a 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs @@ -19,7 +19,7 @@ private BinaryClassificationMetrics() { } - internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix) + internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int skipRows = 0) { Contracts.AssertValue(env); env.AssertValue(overallMetrics); @@ -27,9 +27,13 @@ internal static List FromMetrics(IHostEnvironment e var metricsEnumerable = overallMetrics.AsEnumerable(env, true, ignoreMissingColumns: true); var enumerator = metricsEnumerable.GetEnumerator(); - if (!enumerator.MoveNext()) + + while (skipRows-- >= 0) { - throw env.Except("The overall RegressionMetrics didn't have any rows."); + if (!enumerator.MoveNext()) + { + throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); + } } List metrics = new List(); diff --git a/src/Microsoft.ML/Models/ClassificationMetrics.cs b/src/Microsoft.ML/Models/ClassificationMetrics.cs index 0fbbba602e..2d036f6b6a 100644 --- a/src/Microsoft.ML/Models/ClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/ClassificationMetrics.cs @@ -18,7 +18,7 @@ private ClassificationMetrics() { } - internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix) + internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int skipRows = 0) { Contracts.AssertValue(env); env.AssertValue(overallMetrics); @@ -26,9 +26,12 @@ internal static List FromMetrics(IHostEnvironment env, ID var metricsEnumerable = overallMetrics.AsEnumerable(env, true, ignoreMissingColumns: true); var enumerator = metricsEnumerable.GetEnumerator(); - if (!enumerator.MoveNext()) + while (skipRows-- >= 0) { - throw env.Except("The overall RegressionMetrics didn't have any rows."); + if (!enumerator.MoveNext()) + { + throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); + } } List metrics = new List(); diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs index 2ef99993cb..16f4df665b 100644 --- a/src/Microsoft.ML/Models/CrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -120,24 +120,28 @@ public CrossValidationOutput CrossValidate(Lea cvo.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics), - experiment.GetOutput(crossValidateOutput.ConfusionMatrix)); + experiment.GetOutput(crossValidateOutput.ConfusionMatrix), + 2); } else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer) { cvo.ClassificationMetrics = ClassificationMetrics.FromMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics), - experiment.GetOutput(crossValidateOutput.ConfusionMatrix)); + experiment.GetOutput(crossValidateOutput.ConfusionMatrix), + 2); } else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer) { cvo.RegressionMetrics = RegressionMetrics.FromOverallMetrics( environment, - experiment.GetOutput(crossValidateOutput.OverallMetrics)); + experiment.GetOutput(crossValidateOutput.OverallMetrics), + 2); } else { //Implement metrics for ranking, clustering and anomaly detection. + throw Contracts.Except($"{Kind.ToString()} is not supported at the moment."); } ITransformModel model = experiment.GetOutput(crossValidateOutput.TransformModel[Index]); diff --git a/src/Microsoft.ML/Models/RegressionMetrics.cs b/src/Microsoft.ML/Models/RegressionMetrics.cs index 43eabc3bec..00bb74a9ef 100644 --- a/src/Microsoft.ML/Models/RegressionMetrics.cs +++ b/src/Microsoft.ML/Models/RegressionMetrics.cs @@ -19,16 +19,19 @@ private RegressionMetrics() { } - internal static List FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics) + internal static List FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics, int skipRows = 0) { Contracts.AssertValue(env); env.AssertValue(overallMetrics); var metricsEnumerable = overallMetrics.AsEnumerable(env, true, ignoreMissingColumns: true); var enumerator = metricsEnumerable.GetEnumerator(); - if (!enumerator.MoveNext()) + while (skipRows-- >= 0) { - throw env.Except("The overall RegressionMetrics didn't have any rows."); + if (!enumerator.MoveNext()) + { + throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); + } } List metrics = new List(); diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 11bea3970c..d7d4729344 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -94,16 +94,16 @@ public sealed class Output "provided as the Input.TransformModel.", SortOrder = 2)] public ITransformModel[] TransformModel; - [TlcModule.Output(Desc = "Warning dataset", SortOrder = 2)] + [TlcModule.Output(Desc = "Warning dataset", SortOrder = 3)] public IDataView Warnings; - [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 3)] + [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 4)] public IDataView OverallMetrics; - [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 4)] + [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 5)] public IDataView PerInstanceMetrics; - [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 5)] + [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 6)] public IDataView ConfusionMatrix; } diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 165d89f372..6725178326 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -65,7 +65,6 @@ public void TrainAndPredictSentimentModelTest() pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); - //var cv = new CrossValidator().CrossValidate(pipeline); PredictionModel model = pipeline.Train(); IEnumerable sentiments = new[] { @@ -140,6 +139,82 @@ public void TrainAndPredictSentimentModelTest() Assert.Equal(8, matrix["negative", "positive"]); Assert.Equal(1, matrix[1, 1]); Assert.Equal(1, matrix["negative", "negative"]); + + var cv = new CrossValidator().CrossValidate(pipeline); + + Assert.Equal(2, cv.PredictorModels.Count()); + Assert.Null(cv.ClassificationMetrics); + Assert.Null(cv.RegressionMetrics); + Assert.NotNull(cv.BinaryClassificationMetrics); + Assert.Equal(2, cv.BinaryClassificationMetrics.Count()); + + metrics = cv.BinaryClassificationMetrics[0]; + Assert.Equal(0.53030303030303028, metrics.Accuracy, 4); + Assert.Equal(0.52854072128015284, metrics.Auc, 1); + Assert.Equal(0.62464073827546951, metrics.Auprc, 2); + Assert.Equal(0, metrics.Entropy, 3); + Assert.Equal(0.65934065934065933, metrics.F1Score, 4); + Assert.Equal(1.0098658732948276, metrics.LogLoss, 3); + Assert.Equal(-3.9138397565662424, metrics.LogLossReduction, 3); + Assert.Equal(0.34482758620689657, metrics.NegativePrecision, 3); + Assert.Equal(0.18867924528301888, metrics.NegativeRecall, 3); + Assert.Equal(0.58252427184466016, metrics.PositivePrecision, 3); + Assert.Equal(0.759493670886076, metrics.PositiveRecall); + + matrix = metrics.ConfusionMatrix; + Assert.Equal(2, matrix.Order); + Assert.Equal(2, matrix.ClassNames.Count); + Assert.Equal("positive", matrix.ClassNames[0]); + Assert.Equal("negative", matrix.ClassNames[1]); + + Assert.Equal(60, matrix[0, 0]); + Assert.Equal(60, matrix["positive", "positive"]); + Assert.Equal(19, matrix[0, 1]); + Assert.Equal(19, matrix["positive", "negative"]); + + Assert.Equal(43, matrix[1, 0]); + Assert.Equal(43, matrix["negative", "positive"]); + Assert.Equal(10, matrix[1, 1]); + Assert.Equal(10, matrix["negative", "negative"]); + + metrics = cv.BinaryClassificationMetrics[1]; + Assert.Equal(0.61016949152542377, metrics.Accuracy, 4); + Assert.Equal(0.57067307692307689, metrics.Auc, 1); + Assert.Equal(0.71632480611861549, metrics.Auprc, 2); + Assert.Equal(0, metrics.Entropy, 3); + Assert.Equal(0.71951219512195119, metrics.F1Score, 4); + Assert.Equal(0.94405231894454111, metrics.LogLoss, 3); + Assert.Equal(-2.1876127616628396, metrics.LogLossReduction, 3); + Assert.Equal(0.40625, metrics.NegativePrecision, 3); + Assert.Equal(0.325, metrics.NegativeRecall, 3); + Assert.Equal(0.686046511627907, metrics.PositivePrecision, 3); + Assert.Equal(0.75641025641025639, metrics.PositiveRecall); + + matrix = metrics.ConfusionMatrix; + Assert.Equal(2, matrix.Order); + Assert.Equal(2, matrix.ClassNames.Count); + Assert.Equal("positive", matrix.ClassNames[0]); + Assert.Equal("negative", matrix.ClassNames[1]); + + Assert.Equal(59, matrix[0, 0]); + Assert.Equal(59, matrix["positive", "positive"]); + Assert.Equal(19, matrix[0, 1]); + Assert.Equal(19, matrix["positive", "negative"]); + + Assert.Equal(27, matrix[1, 0]); + Assert.Equal(27, matrix["negative", "positive"]); + Assert.Equal(13, matrix[1, 1]); + Assert.Equal(13, matrix["negative", "negative"]); + + predictions = cv.PredictorModels[0].Predict(sentiments); + Assert.Equal(2, predictions.Count()); + Assert.True(predictions.ElementAt(0).Sentiment.IsTrue); + Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); + + predictions = cv.PredictorModels[1].Predict(sentiments); + Assert.Equal(2, predictions.Count()); + Assert.True(predictions.ElementAt(0).Sentiment.IsTrue); + Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); } public class SentimentData From 372e764413f99e48a85ad7373e529bbf1b057fca Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Thu, 24 May 2018 10:34:11 -0700 Subject: [PATCH 11/18] resolve merge conflicts. --- test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 4607bf629c..8e9b07d255 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -392,7 +392,7 @@ public void TestCrossValidationMacro() [Fact] public void TestCrossValidationMacroWithStratification() { - var dataPath = GetDataPath(@"breast-cancer.txt"); + /*var dataPath = GetDataPath(@"breast-cancer.txt"); using (var env = new TlcEnvironment()) { var subGraph = env.CreateExperiment(); @@ -454,7 +454,7 @@ public void TestCrossValidationMacroWithStratification() b = cursor.MoveNext(); Assert.False(b); } - } + }*/ } [Fact] From 2769d9bb95643ca61aec6e5cae52414aaeda4354 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 25 May 2018 12:51:25 -0700 Subject: [PATCH 12/18] Add Train-Test and address PR feedback. --- .../Common/EntryPoints/core_manifest.json | 88 ++++++++- src/Microsoft.ML/CSharpApi.cs | 26 ++- src/Microsoft.ML/ILearningPipelineItem.cs | 5 + .../Models/BinaryClassificationEvaluator.cs | 9 +- .../Models/ClassificationEvaluator.cs | 9 +- src/Microsoft.ML/Models/CrossValidator.cs | 13 +- src/Microsoft.ML/Models/TrainTestEvaluator.cs | 181 ++++++++++++++++++ ...sticDualCoordinateAscentClassifierBench.cs | 2 +- .../Scenarios/IrisPlantClassificationTests.cs | 2 +- ...PlantClassificationWithStringLabelTests.cs | 2 +- .../Scenarios/SentimentPredictionTests.cs | 68 ++++++- 11 files changed, 380 insertions(+), 25 deletions(-) create mode 100644 src/Microsoft.ML/Models/TrainTestEvaluator.cs diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json index 05b7965842..9a7306d87f 100644 --- a/ZBaselines/Common/EntryPoints/core_manifest.json +++ b/ZBaselines/Common/EntryPoints/core_manifest.json @@ -469,6 +469,35 @@ "ILearningPipelineLoader" ] }, + { + "Name": "Data.TransformModelArrayConverter", + "Desc": "Create and array variable", + "FriendlyName": null, + "ShortName": null, + "Inputs": [ + { + "Name": "TransformModel", + "Type": { + "Kind": "Array", + "ItemType": "TransformModel" + }, + "Desc": "The models", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + } + ], + "Outputs": [ + { + "Name": "OutputModel", + "Type": { + "Kind": "Array", + "ItemType": "TransformModel" + }, + "Desc": "The model array" + } + ] + }, { "Name": "Models.AnomalyDetectionEvaluator", "Desc": "Evaluates an anomaly detection scored dataset.", @@ -1411,9 +1440,28 @@ "Name": "Model", "Type": "PredictorModel", "Desc": "The model", - "Required": true, + "Required": false, "SortOrder": 1.0, - "IsNullable": false + "IsNullable": false, + "Default": null + }, + { + "Name": "TransformModel", + "Type": "TransformModel", + "Desc": "The transform model", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "UseTransformModel", + "Type": "Bool", + "Desc": "Indicates to use transform model instead of predictor model.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": false } ] }, @@ -1476,6 +1524,14 @@ }, "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel." }, + { + "Name": "TransformModel", + "Type": { + "Kind": "Array", + "ItemType": "TransformModel" + }, + "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel." + }, { "Name": "Warnings", "Type": "DataView", @@ -3002,9 +3058,28 @@ "Name": "Model", "Type": "PredictorModel", "Desc": "The model", - "Required": true, + "Required": false, "SortOrder": 1.0, - "IsNullable": false + "IsNullable": false, + "Default": null + }, + { + "Name": "TransformModel", + "Type": "TransformModel", + "Desc": "Transform model", + "Required": false, + "SortOrder": 2.0, + "IsNullable": false, + "Default": null + }, + { + "Name": "UseTransformModel", + "Type": "Bool", + "Desc": "Indicates to use transform model instead of predictor model.", + "Required": false, + "SortOrder": 3.0, + "IsNullable": false, + "Default": false } ] }, @@ -3058,6 +3133,11 @@ "Type": "PredictorModel", "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel." }, + { + "Name": "TransformModel", + "Type": "TransformModel", + "Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel." + }, { "Name": "Warnings", "Type": "DataView", diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 6de2ec729c..283655d059 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -7095,14 +7095,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.IAno public Var PredictorModel { get; set; } = new Var(); } + public Var GetInputData() => TrainingData; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(PcaAnomalyDetector)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(PcaAnomalyDetector)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - TrainingData = dataStep.Data; + TrainingData = dataStep.Data; + } Output output = experiment.Add(this); return new PcaAnomalyDetectorPipelineStep(output); } @@ -12140,14 +12145,19 @@ public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITra public Var Model { get; set; } = new Var(); } + public Var GetInputData() => Data; + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) { - if (!(previousStep is ILearningPipelineDataStep dataStep)) + if (previousStep != null) { - throw new InvalidOperationException($"{ nameof(PcaCalculator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); - } + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(PcaCalculator)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } - Data = dataStep.Data; + Data = dataStep.Data; + } Output output = experiment.Add(this); return new PcaCalculatorPipelineStep(output); } diff --git a/src/Microsoft.ML/ILearningPipelineItem.cs b/src/Microsoft.ML/ILearningPipelineItem.cs index 3d29b65a2e..c36f890c57 100644 --- a/src/Microsoft.ML/ILearningPipelineItem.cs +++ b/src/Microsoft.ML/ILearningPipelineItem.cs @@ -14,6 +14,11 @@ namespace Microsoft.ML public interface ILearningPipelineItem { ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment); + + /// + /// Returns the place holder for input IDataView object for the node in the execution graph. + /// + /// Var GetInputData(); } diff --git a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs index b916a9ab57..9733946dd0 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Transforms; using System.Collections.Generic; +using System.Linq; namespace Microsoft.ML.Models { @@ -24,7 +25,7 @@ public sealed partial class BinaryClassificationEvaluator /// /// A BinaryClassificationMetrics instance that describes how well the model performed against the test data. /// - public List Evaluate(PredictionModel model, ILearningPipelineLoader testData) + public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) { using (var environment = new TlcEnvironment()) { @@ -67,7 +68,11 @@ public List Evaluate(PredictionModel model, ILearni throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate."); } - return BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); + var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); + + Contracts.Assert(metric.Count == 1); + + return metric.First(); } } } diff --git a/src/Microsoft.ML/Models/ClassificationEvaluator.cs b/src/Microsoft.ML/Models/ClassificationEvaluator.cs index 33799c0c78..e68ba68d25 100644 --- a/src/Microsoft.ML/Models/ClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/ClassificationEvaluator.cs @@ -6,6 +6,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Transforms; using System.Collections.Generic; +using System.Linq; namespace Microsoft.ML.Models { @@ -24,7 +25,7 @@ public sealed partial class ClassificationEvaluator /// /// A ClassificationMetrics instance that describes how well the model performed against the test data. /// - public List Evaluate(PredictionModel model, ILearningPipelineLoader testData) + public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) { using (var environment = new TlcEnvironment()) { @@ -67,7 +68,11 @@ public List Evaluate(PredictionModel model, ILearningPipe throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate."); } - return ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); + var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); + + Contracts.Assert(metric.Count == 1); + + return metric.First(); } } } diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs index 16f4df665b..8b50cfb5b7 100644 --- a/src/Microsoft.ML/Models/CrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -2,15 +2,23 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; -using System; using System.Collections.Generic; using System.IO; -using System.Text; namespace Microsoft.ML.Models { + /// + /// Performs cross-validation on a pipeline. + /// public sealed partial class CrossValidator { + /// + /// Performs cross validation on a pipeline. + /// + /// Class type that represents input schema. + /// Class type that represents prediction schema. + /// Machine learning pipeline that contain may contain loader, transforms and at least one trainer. + /// List containning metrics and predictor model for each fold public CrossValidationOutput CrossValidate(LearningPipeline pipeline) where TInput : class where TOutput : class, new() @@ -44,7 +52,6 @@ public CrossValidationOutput CrossValidate(Lea firstTransform = currentItem; } } - else if (step is ILearningPipelinePredictorStep predictorDataStep) { if (lastTransformModel != null) diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs new file mode 100644 index 0000000000..af20ae69fb --- /dev/null +++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs @@ -0,0 +1,181 @@ +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; +using System.Collections.Generic; +using System.IO; +using System.Linq; + +namespace Microsoft.ML.Models +{ + /// + /// Performs Train-Test on a pipeline. + /// + public sealed partial class TrainTestEvaluator + { + /// + /// Performs train-test on a pipeline. + /// + /// Class type that represents input schema. + /// Class type that represents prediction schema. + /// Machine learning pipeline that contains , + /// transforms and at least one trainer. + /// that represents the test dataset. + /// Metrics and predictor model. + public TrainTestEvaluatorOutput TrainTestEvaluate(LearningPipeline pipeline, ILearningPipelineLoader testData) + where TInput : class + where TOutput : class, new() + { + using (var environment = new TlcEnvironment()) + { + Experiment subGraph = environment.CreateExperiment(); + ILearningPipelineStep step = null; + List loaders = new List(); + List> transformModels = new List>(); + Var lastTransformModel = null; + Var firstPipelineDataStep = null; + Var firstModel = null; + ILearningPipelineItem firstTransform = null; + foreach (ILearningPipelineItem currentItem in pipeline) + { + if (currentItem is ILearningPipelineLoader loader) + { + loaders.Add(loader); + continue; + } + + step = currentItem.ApplyStep(step, subGraph); + + if (step is ILearningPipelineDataStep dataStep && dataStep.Model != null) + { + transformModels.Add(dataStep.Model); + if (firstPipelineDataStep == null) + { + firstPipelineDataStep = dataStep.Data; + firstTransform = currentItem; + } + } + else if (step is ILearningPipelinePredictorStep predictorDataStep) + { + if (lastTransformModel != null) + transformModels.Insert(0, lastTransformModel); + + Var predictorModel; + if (transformModels.Count != 0) + { + var localModelInput = new Transforms.ManyHeterogeneousModelCombiner + { + PredictorModel = predictorDataStep.Model, + TransformModels = new ArrayVar(transformModels.ToArray()) + }; + var localModelOutput = subGraph.Add(localModelInput); + predictorModel = localModelOutput.PredictorModel; + } + else + predictorModel = predictorDataStep.Model; + firstModel = predictorModel; + + var scorer = new Transforms.Scorer + { + PredictorModel = predictorModel + }; + + var scorerOutput = subGraph.Add(scorer); + lastTransformModel = scorerOutput.ScoringTransform; + step = new ScorerPipelineStep(scorerOutput.ScoredData, scorerOutput.ScoringTransform); + transformModels.Clear(); + } + } + + if (transformModels.Count > 0) + { + if (lastTransformModel != null) + transformModels.Insert(0, lastTransformModel); + + var modelInput = new Transforms.ModelCombiner + { + Models = new ArrayVar(transformModels.ToArray()) + }; + + var modelOutput = subGraph.Add(modelInput); + lastTransformModel = modelOutput.OutputModel; + } + + var experiment = environment.CreateExperiment(); + + TrainingData = (loaders[0].ApplyStep(null, experiment) as ILearningPipelineDataStep).Data; + TestingData = (testData.ApplyStep(null, experiment) as ILearningPipelineDataStep).Data; + Nodes = subGraph; + TransformModel = null; + Inputs.Data = firstTransform.GetInputData(); + Outputs.TransformModel = lastTransformModel; + Outputs.UseTransformModel = true; + var crossValidateOutput = experiment.Add(this); + experiment.Compile(); + foreach (ILearningPipelineLoader loader in loaders) + { + loader.SetInput(environment, experiment); + } + testData.SetInput(environment, experiment); + + experiment.Run(); + + TrainTestEvaluatorOutput tteo = new TrainTestEvaluatorOutput(); + + if (Kind == MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer) + { + tteo.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( + environment, + experiment.GetOutput(crossValidateOutput.OverallMetrics), + experiment.GetOutput(crossValidateOutput.ConfusionMatrix)).FirstOrDefault(); + } + else if (Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer) + { + tteo.ClassificationMetrics = ClassificationMetrics.FromMetrics( + environment, + experiment.GetOutput(crossValidateOutput.OverallMetrics), + experiment.GetOutput(crossValidateOutput.ConfusionMatrix)).FirstOrDefault(); + } + else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer) + { + tteo.RegressionMetrics = RegressionMetrics.FromOverallMetrics( + environment, + experiment.GetOutput(crossValidateOutput.OverallMetrics)).FirstOrDefault(); + } + else + { + //Implement metrics for ranking, clustering and anomaly detection. + throw Contracts.Except($"{Kind.ToString()} is not supported at the moment."); + } + + ITransformModel model = experiment.GetOutput(crossValidateOutput.TransformModel); + BatchPredictionEngine predictor; + using (var memoryStream = new MemoryStream()) + { + model.Save(environment, memoryStream); + + memoryStream.Position = 0; + + predictor = environment.CreateBatchPredictionEngine(memoryStream); + + tteo.PredictorModels = new PredictionModel(predictor, memoryStream); + } + + return tteo; + } + } + } + + public class TrainTestEvaluatorOutput + where TInput : class + where TOutput : class, new() + { + public BinaryClassificationMetrics BinaryClassificationMetrics; + public ClassificationMetrics ClassificationMetrics; + public RegressionMetrics RegressionMetrics; + public PredictionModel PredictorModels; + + //REVIEW: Add warnings and per instance results and implement + //metrics for ranking, clustering and anomaly detection. + } +} diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index e31f6311cd..adfa42e50d 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -53,7 +53,7 @@ public void Setup() var testData = new TextLoader(s_dataPath).CreateFrom(useHeader: true); var evaluator = new ClassificationEvaluator(); - s_metrics = evaluator.Evaluate(s_trainedModel, testData).FirstOrDefault(); + s_metrics = evaluator.Evaluate(s_trainedModel, testData); s_batches = new IrisData[s_batchSizes.Length][]; for (int i = 0; i < s_batches.Length; i++) diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index cb6cad9548..4c4e7114cb 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -72,7 +72,7 @@ public void TrainAndPredictIrisModelTest() var evaluator = new ClassificationEvaluator(); evaluator.OutputTopKAcc = 3; - ClassificationMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault(); + ClassificationMetrics metrics = evaluator.Evaluate(model, testData); Assert.Equal(.98, metrics.AccuracyMacro); Assert.Equal(.98, metrics.AccuracyMicro, 2); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index 10d23b62b7..bd434757d4 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -75,7 +75,7 @@ public void TrainAndPredictIrisModelWithStringLabelTest() var evaluator = new ClassificationEvaluator(); evaluator.OutputTopKAcc = 3; - ClassificationMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault(); ; + ClassificationMetrics metrics = evaluator.Evaluate(model, testData); Assert.Equal(.98, metrics.AccuracyMacro); Assert.Equal(.98, metrics.AccuracyMicro, 2); diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 6725178326..0140eb233e 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -23,6 +23,9 @@ public partial class ScenariosTests [Fact] public void TrainAndPredictSentimentModelTest() { + + //1. Construct the ML pipeline. + string dataPath = GetDataPath(SentimentDataPath); var pipeline = new LearningPipeline(); @@ -50,7 +53,7 @@ public void TrainAndPredictSentimentModelTest() } } }); - + pipeline.Add(new TextFeaturizer("Features", "SentimentText") { KeepDiacritics = false, @@ -65,9 +68,12 @@ public void TrainAndPredictSentimentModelTest() pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); + + //2.1 Train. + PredictionModel model = pipeline.Train(); IEnumerable sentiments = new[] - { + { new SentimentData { SentimentText = "Please refrain from adding nonsense to Wikipedia." @@ -78,6 +84,8 @@ public void TrainAndPredictSentimentModelTest() } }; + //2.2 Predict. + IEnumerable predictions = model.Predict(sentiments); Assert.Equal(2, predictions.Count()); @@ -109,8 +117,11 @@ public void TrainAndPredictSentimentModelTest() } } }; + + //2.3 Evaluate the predictor model. + var evaluator = new BinaryClassificationEvaluator(); - BinaryClassificationMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault(); + BinaryClassificationMetrics metrics = evaluator.Evaluate(model, testData); Assert.Equal(.5556, metrics.Accuracy, 4); Assert.Equal(.8, metrics.Auc, 1); @@ -140,6 +151,57 @@ public void TrainAndPredictSentimentModelTest() Assert.Equal(1, matrix[1, 1]); Assert.Equal(1, matrix["negative", "negative"]); + //3. Lets do 2.1, 2.2 and 2.3 using Train-Test that does training, evaluation and gives a predictor. + + var tt = new TrainTestEvaluator().TrainTestEvaluate(pipeline, testData); + + Assert.Null(tt.ClassificationMetrics); + Assert.Null(tt.RegressionMetrics); + Assert.NotNull(tt.BinaryClassificationMetrics); + Assert.NotNull(tt.PredictorModels); + + //These results should match that of 2.3 + metrics = tt.BinaryClassificationMetrics; + Assert.Equal(.5556, metrics.Accuracy, 4); + Assert.Equal(.8, metrics.Auc, 1); + Assert.Equal(.87, metrics.Auprc, 2); + Assert.Equal(1, metrics.Entropy, 3); + Assert.Equal(.6923, metrics.F1Score, 4); + Assert.Equal(.969, metrics.LogLoss, 3); + Assert.Equal(3.083, metrics.LogLossReduction, 3); + Assert.Equal(1, metrics.NegativePrecision, 3); + Assert.Equal(.111, metrics.NegativeRecall, 3); + Assert.Equal(.529, metrics.PositivePrecision, 3); + Assert.Equal(1, metrics.PositiveRecall); + + matrix = metrics.ConfusionMatrix; + Assert.Equal(2, matrix.Order); + Assert.Equal(2, matrix.ClassNames.Count); + Assert.Equal("positive", matrix.ClassNames[0]); + Assert.Equal("negative", matrix.ClassNames[1]); + + Assert.Equal(9, matrix[0, 0]); + Assert.Equal(9, matrix["positive", "positive"]); + Assert.Equal(0, matrix[0, 1]); + Assert.Equal(0, matrix["positive", "negative"]); + + Assert.Equal(8, matrix[1, 0]); + Assert.Equal(8, matrix["negative", "positive"]); + Assert.Equal(1, matrix[1, 1]); + Assert.Equal(1, matrix["negative", "negative"]); + + predictions = tt.PredictorModels.Predict(sentiments); + Assert.Equal(2, predictions.Count()); + Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); + Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); + + predictions = tt.PredictorModels.Predict(sentiments); + Assert.Equal(2, predictions.Count()); + Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); + Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); + + //4. Cross Validation on training pipeline. + var cv = new CrossValidator().CrossValidate(pipeline); Assert.Equal(2, cv.PredictorModels.Count()); From 1f520a2f47de41b175cf04bd218846030311af44 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 25 May 2018 13:08:01 -0700 Subject: [PATCH 13/18] clean up. --- .../Common/EntryPoints/core_manifest.json | 10 +++---- src/Microsoft.ML/Data/TextLoader.cs | 1 + .../Models/RegressionEvaluator.cs | 9 +++++-- .../EntryPoints/CrossValidationMacro.cs | 26 +++++++++---------- .../HousePriceTrainAndPredictionTests.cs | 2 +- 5 files changed, 27 insertions(+), 21 deletions(-) diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json index 9a7306d87f..07d7ec79a1 100644 --- a/ZBaselines/Common/EntryPoints/core_manifest.json +++ b/ZBaselines/Common/EntryPoints/core_manifest.json @@ -1329,7 +1329,7 @@ "Label" ], "Required": false, - "SortOrder": 6.0, + "SortOrder": 5.0, "IsNullable": false, "Default": "Label" }, @@ -1349,7 +1349,7 @@ }, "Desc": "Specifies the trainer kind, which determines the evaluator to be used.", "Required": true, - "SortOrder": 7.0, + "SortOrder": 6.0, "IsNullable": false, "Default": "SignatureBinaryClassifierTrainer" } @@ -1478,7 +1478,7 @@ "strat" ], "Required": false, - "SortOrder": 7.0, + "SortOrder": 6.0, "IsNullable": false, "Default": null }, @@ -1490,7 +1490,7 @@ "k" ], "Required": false, - "SortOrder": 8.0, + "SortOrder": 7.0, "IsNullable": false, "Default": 2 }, @@ -1510,7 +1510,7 @@ }, "Desc": "Specifies the trainer kind, which determines the evaluator to be used.", "Required": true, - "SortOrder": 9.0, + "SortOrder": 8.0, "IsNullable": false, "Default": "SignatureBinaryClassifierTrainer" } diff --git a/src/Microsoft.ML/Data/TextLoader.cs b/src/Microsoft.ML/Data/TextLoader.cs index 1829d6225e..3c8550ef09 100644 --- a/src/Microsoft.ML/Data/TextLoader.cs +++ b/src/Microsoft.ML/Data/TextLoader.cs @@ -21,6 +21,7 @@ public TextLoaderRange() /// /// Convenience constructor for the scalar case, when a given column /// in the schema spans only a single column in the dataset. + /// and are set to the single value . /// /// Column index in the dataset. public TextLoaderRange(int ordinal) diff --git a/src/Microsoft.ML/Models/RegressionEvaluator.cs b/src/Microsoft.ML/Models/RegressionEvaluator.cs index c55f4f3335..3b5e1bda13 100644 --- a/src/Microsoft.ML/Models/RegressionEvaluator.cs +++ b/src/Microsoft.ML/Models/RegressionEvaluator.cs @@ -7,6 +7,7 @@ using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Transforms; using System.Collections.Generic; +using System.Linq; namespace Microsoft.ML.Models { @@ -24,7 +25,7 @@ public sealed partial class RegressionEvaluator /// /// A RegressionMetrics instance that describes how well the model performed against the test data. /// - public List Evaluate(PredictionModel model, ILearningPipelineLoader testData) + public RegressionMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData) { using (var environment = new TlcEnvironment()) { @@ -61,8 +62,12 @@ public List Evaluate(PredictionModel model, ILearningPipeline { throw environment.Except($"Could not find OverallMetrics in the results returned in {nameof(RegressionEvaluator)} Evaluate."); } + + var metric = RegressionMetrics.FromOverallMetrics(environment, overallMetrics); - return RegressionMetrics.FromOverallMetrics(environment, overallMetrics); + Contracts.Assert(metric.Count == 1); + + return metric.First(); } } } diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index f832637632..9133aef258 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -69,16 +69,16 @@ public sealed class Arguments // For splitting the data into folds, this column is used for grouping rows and makes sure // that a group of rows is not split among folds. - [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for stratification", ShortName = "strat", SortOrder = 7)] + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Column to use for stratification", ShortName = "strat", SortOrder = 6)] public string StratificationColumn; // The number of folds to generate. - [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of folds in k-fold cross-validation", ShortName = "k", SortOrder = 8)] + [Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of folds in k-fold cross-validation", ShortName = "k", SortOrder = 7)] public int NumFolds = 2; // REVIEW: suggest moving to subcomponents for evaluators, to allow for different parameters on the evaluators // (and the same for the TrainTest macro). I currently do not know how to do this, so this should be revisited in the future. - [Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 9)] + [Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 8)] public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer; } @@ -94,16 +94,16 @@ public sealed class Output "provided as the Input.TransformModel.", SortOrder = 2)] public ITransformModel[] TransformModel; - [TlcModule.Output(Desc = "Warning dataset", SortOrder = 2)] + [TlcModule.Output(Desc = "Warning dataset", SortOrder = 3)] public IDataView Warnings; - [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 3)] + [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 4)] public IDataView OverallMetrics; - [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 4)] + [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 5)] public IDataView PerInstanceMetrics; - [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 5)] + [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 6)] public IDataView ConfusionMatrix; } @@ -122,25 +122,25 @@ public sealed class CombineMetricsInput [Argument(ArgumentType.Multiple, HelpText = "Warning datasets", SortOrder = 4)] public IDataView[] Warnings; - [Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 6)] + [Argument(ArgumentType.AtMostOnce, HelpText = "The label column name", ShortName = "Label", SortOrder = 5)] public string LabelColumn = DefaultColumnNames.Label; - [Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 7)] + [Argument(ArgumentType.Required, HelpText = "Specifies the trainer kind, which determines the evaluator to be used.", SortOrder = 6)] public MacroUtils.TrainerKinds Kind = MacroUtils.TrainerKinds.SignatureBinaryClassifierTrainer; } public sealed class CombinedOutput { - [TlcModule.Output(Desc = "Warning dataset", SortOrder = 2)] + [TlcModule.Output(Desc = "Warning dataset", SortOrder = 1)] public IDataView Warnings; - [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 3)] + [TlcModule.Output(Desc = "Overall metrics dataset", SortOrder = 2)] public IDataView OverallMetrics; - [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 4)] + [TlcModule.Output(Desc = "Per instance metrics dataset", SortOrder = 3)] public IDataView PerInstanceMetrics; - [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 5)] + [TlcModule.Output(Desc = "Confusion matrix dataset", SortOrder = 4)] public IDataView ConfusionMatrix; } diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs index 85955b1c06..da7288645c 100644 --- a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs @@ -66,7 +66,7 @@ public void TrainAndPredictHousePriceModelTest() var testData = new TextLoader(testDataPath).CreateFrom(useHeader: true, separator: ','); var evaluator = new RegressionEvaluator(); - RegressionMetrics metrics = evaluator.Evaluate(model, testData).FirstOrDefault(); + RegressionMetrics metrics = evaluator.Evaluate(model, testData); Assert.InRange(metrics.L1, 85_000, 89_000); Assert.InRange(metrics.L2, 17_000_000_000, 19_000_000_000); Assert.InRange(metrics.Rms, 130_500, 135_000); From 3991971130a0c88099becca33cb7ba1b0cd29e84 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 25 May 2018 13:12:58 -0700 Subject: [PATCH 14/18] more cleanup. --- .../Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs | 3 --- .../Scenarios/HousePriceTrainAndPredictionTests.cs | 4 ---- .../Scenarios/IrisPlantClassificationTests.cs | 1 - .../Scenarios/IrisPlantClassificationWithStringLabelTests.cs | 1 - 4 files changed, 9 deletions(-) diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs index 392462a0eb..81a2d950b5 100644 --- a/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/HousePricePredictionTests.cs @@ -2,11 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.Models; using Microsoft.ML.Runtime.Api; using Microsoft.ML.TestFramework; -using Microsoft.ML.Trainers; -using Microsoft.ML.Transforms; using Xunit; using Xunit.Abstractions; diff --git a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs index da7288645c..df529f04a7 100644 --- a/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/HousePriceTrainAndPredictionTests.cs @@ -4,13 +4,9 @@ using Microsoft.ML.Data; using Microsoft.ML.Models; -using Microsoft.ML.Runtime.Api; -using Microsoft.ML.TestFramework; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; -using System.Linq; using Xunit; -using Xunit.Abstractions; namespace Microsoft.ML.Scenarios { diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs index 4c4e7114cb..5dcbf3a588 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationTests.cs @@ -7,7 +7,6 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; -using System.Linq; using Xunit; namespace Microsoft.ML.Scenarios diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index bd434757d4..ebddc33b03 100644 --- a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -7,7 +7,6 @@ using Microsoft.ML.Runtime.Api; using Microsoft.ML.Trainers; using Microsoft.ML.Transforms; -using System.Linq; using Xunit; namespace Microsoft.ML.Scenarios From 34247751645f9ddaab6dab812ec44a4c54c59754 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 25 May 2018 15:43:26 -0700 Subject: [PATCH 15/18] PR feedback. --- .../Common/EntryPoints/core_ep-list.tsv | 6 +- .../Common/EntryPoints/core_manifest.json | 10 +- src/Microsoft.ML/CSharpApi.cs | 10 +- src/Microsoft.ML/Models/TrainTestEvaluator.cs | 15 +- .../EntryPoints/CrossValidationBinaryMacro.cs | 6 +- .../EntryPoints/CrossValidationMacro.cs | 3 +- .../Runtime/EntryPoints/TrainTestMacro.cs | 2 +- .../Scenarios/SentimentPredictionTests.cs | 167 ++++++++++++++++-- 8 files changed, 174 insertions(+), 45 deletions(-) diff --git a/ZBaselines/Common/EntryPoints/core_ep-list.tsv b/ZBaselines/Common/EntryPoints/core_ep-list.tsv index da227b16ff..61f2604a8d 100644 --- a/ZBaselines/Common/EntryPoints/core_ep-list.tsv +++ b/ZBaselines/Common/EntryPoints/core_ep-list.tsv @@ -1,9 +1,9 @@ Data.CustomTextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData ImportText Microsoft.ML.Runtime.EntryPoints.ImportTextData+Input Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output Data.DataViewReference Pass dataview from memory to experiment Microsoft.ML.Runtime.EntryPoints.DataViewReference ImportData Microsoft.ML.Runtime.EntryPoints.DataViewReference+Input Microsoft.ML.Runtime.EntryPoints.DataViewReference+Output -Data.IDataViewArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput -Data.PredictorModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput +Data.IDataViewArrayConverter Create an array variable of IDataView Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput +Data.PredictorModelArrayConverter Create an array variable of IPredictorModel Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData TextLoader Microsoft.ML.Runtime.EntryPoints.ImportTextData+LoaderInput Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output -Data.TransformModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelOutput +Data.TransformModelArrayConverter Create an array variable of ITransformModel Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelOutput Models.AnomalyDetectionEvaluator Evaluates an anomaly detection scored dataset. Microsoft.ML.Runtime.Data.Evaluate AnomalyDetection Microsoft.ML.Runtime.Data.AnomalyDetectionMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput Models.BinaryClassificationEvaluator Evaluates a binary classification scored dataset. Microsoft.ML.Runtime.Data.Evaluate Binary Microsoft.ML.Runtime.Data.BinaryClassifierMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+ClassificationEvaluateOutput Models.BinaryCrossValidator Cross validation for binary classification Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro CrossValidateBinary Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Output] diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json index 07d7ec79a1..85be4d0e0e 100644 --- a/ZBaselines/Common/EntryPoints/core_manifest.json +++ b/ZBaselines/Common/EntryPoints/core_manifest.json @@ -63,7 +63,7 @@ }, { "Name": "Data.IDataViewArrayConverter", - "Desc": "Create and array variable", + "Desc": "Create an array variable of IDataView", "FriendlyName": null, "ShortName": null, "Inputs": [ @@ -92,7 +92,7 @@ }, { "Name": "Data.PredictorModelArrayConverter", - "Desc": "Create and array variable", + "Desc": "Create an array variable of IPredictorModel", "FriendlyName": null, "ShortName": null, "Inputs": [ @@ -471,7 +471,7 @@ }, { "Name": "Data.TransformModelArrayConverter", - "Desc": "Create and array variable", + "Desc": "Create an array variable of ITransformModel", "FriendlyName": null, "ShortName": null, "Inputs": [ @@ -1439,7 +1439,7 @@ { "Name": "Model", "Type": "PredictorModel", - "Desc": "The model", + "Desc": "The predictor model", "Required": false, "SortOrder": 1.0, "IsNullable": false, @@ -3057,7 +3057,7 @@ { "Name": "Model", "Type": "PredictorModel", - "Desc": "The model", + "Desc": "The predictor model", "Required": false, "SortOrder": 1.0, "IsNullable": false, diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 283655d059..cc50050dd3 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -1408,7 +1408,7 @@ namespace Data { /// - /// Create and array variable + /// Create an array variable of IDataView /// public sealed partial class IDataViewArrayConverter { @@ -1435,7 +1435,7 @@ namespace Data { /// - /// Create and array variable + /// Create an array variable of IPredictorModel /// public sealed partial class PredictorModelArrayConverter { @@ -1658,7 +1658,7 @@ namespace Data { /// - /// Create and array variable + /// Create an array variable of ITransformModel /// public sealed partial class TransformModelArrayConverter { @@ -2212,7 +2212,7 @@ public sealed partial class CrossValidationMacroSubGraphInput public sealed partial class CrossValidationMacroSubGraphOutput { /// - /// The model + /// The predictor model /// public Var Model { get; set; } = new Var(); @@ -3398,7 +3398,7 @@ public sealed partial class TrainTestMacroSubGraphInput public sealed partial class TrainTestMacroSubGraphOutput { /// - /// The model + /// The predictor model /// public Var Model { get; set; } = new Var(); diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs index af20ae69fb..c18c8f5c23 100644 --- a/src/Microsoft.ML/Models/TrainTestEvaluator.cs +++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs @@ -113,32 +113,31 @@ public TrainTestEvaluatorOutput TrainTestEvaluate tteo = new TrainTestEvaluatorOutput(); + TrainTestEvaluatorOutput trainTestOutput = new TrainTestEvaluatorOutput(); if (Kind == MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer) { - tteo.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( + trainTestOutput.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics), experiment.GetOutput(crossValidateOutput.ConfusionMatrix)).FirstOrDefault(); } else if (Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer) { - tteo.ClassificationMetrics = ClassificationMetrics.FromMetrics( + trainTestOutput.ClassificationMetrics = ClassificationMetrics.FromMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics), experiment.GetOutput(crossValidateOutput.ConfusionMatrix)).FirstOrDefault(); } else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer) { - tteo.RegressionMetrics = RegressionMetrics.FromOverallMetrics( + trainTestOutput.RegressionMetrics = RegressionMetrics.FromOverallMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics)).FirstOrDefault(); } @@ -158,10 +157,10 @@ public TrainTestEvaluatorOutput TrainTestEvaluate(memoryStream); - tteo.PredictorModels = new PredictionModel(predictor, memoryStream); + trainTestOutput.PredictorModels = new PredictionModel(predictor, memoryStream); } - return tteo; + return trainTestOutput; } } } diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs index c6fed29fd3..fca8d3ac5b 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationBinaryMacro.cs @@ -224,7 +224,7 @@ public sealed class ArrayIPredictorModelOutput public IPredictorModel[] OutputModel; } - [TlcModule.EntryPoint(Desc = "Create and array variable", Name = "Data.PredictorModelArrayConverter")] + [TlcModule.EntryPoint(Desc = "Create an array variable of IPredictorModel", Name = "Data.PredictorModelArrayConverter")] public static ArrayIPredictorModelOutput MakeArray(IHostEnvironment env, ArrayIPredictorModelInput input) { var result = new ArrayIPredictorModelOutput @@ -246,7 +246,7 @@ public sealed class ArrayITransformModelOutput public ITransformModel[] OutputModel; } - [TlcModule.EntryPoint(Desc = "Create and array variable", Name = "Data.TransformModelArrayConverter")] + [TlcModule.EntryPoint(Desc = "Create an array variable of ITransformModel", Name = "Data.TransformModelArrayConverter")] public static ArrayITransformModelOutput MakeArray(IHostEnvironment env, ArrayITransformModelInput input) { var result = new ArrayITransformModelOutput @@ -269,7 +269,7 @@ public sealed class ArrayIDataViewOutput public IDataView[] OutputData; } - [TlcModule.EntryPoint(Desc = "Create and array variable", Name = "Data.IDataViewArrayConverter")] + [TlcModule.EntryPoint(Desc = "Create an array variable of IDataView", Name = "Data.IDataViewArrayConverter")] public static ArrayIDataViewOutput MakeArray(IHostEnvironment env, ArrayIDataViewInput input) { var result = new ArrayIDataViewOutput diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 9133aef258..3d3dab9ff8 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -29,7 +29,7 @@ public sealed class SubGraphInput public sealed class SubGraphOutput { - [Argument(ArgumentType.AtMostOnce, HelpText = "The model", SortOrder = 1)] + [Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)] public Var Model; [Argument(ArgumentType.AtMostOnce, HelpText = "The transform model", SortOrder = 2)] @@ -476,7 +476,6 @@ private static IMamlEvaluator GetEvaluator(IHostEnvironment env, MacroUtils.Trai return new MultiOutputRegressionMamlEvaluator(env, new MultiOutputRegressionMamlEvaluator.Arguments()); default: throw env.ExceptParam(nameof(kind), $"Trainer kind {kind} does not have an evaluator"); - } } } diff --git a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs index 61b376e7f9..869fcba3e2 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs @@ -24,7 +24,7 @@ public sealed class SubGraphInput public sealed class SubGraphOutput { - [Argument(ArgumentType.AtMostOnce, HelpText = "The model", SortOrder = 1)] + [Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)] public Var Model; [Argument(ArgumentType.AtMostOnce, HelpText = "Transform model", SortOrder = 2)] diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 0140eb233e..ed9e04a404 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -23,9 +23,6 @@ public partial class ScenariosTests [Fact] public void TrainAndPredictSentimentModelTest() { - - //1. Construct the ML pipeline. - string dataPath = GetDataPath(SentimentDataPath); var pipeline = new LearningPipeline(); @@ -69,8 +66,6 @@ public void TrainAndPredictSentimentModelTest() pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); - //2.1 Train. - PredictionModel model = pipeline.Train(); IEnumerable sentiments = new[] { @@ -84,8 +79,6 @@ public void TrainAndPredictSentimentModelTest() } }; - //2.2 Predict. - IEnumerable predictions = model.Predict(sentiments); Assert.Equal(2, predictions.Count()); @@ -118,8 +111,6 @@ public void TrainAndPredictSentimentModelTest() } }; - //2.3 Evaluate the predictor model. - var evaluator = new BinaryClassificationEvaluator(); BinaryClassificationMetrics metrics = evaluator.Evaluate(model, testData); @@ -150,8 +141,92 @@ public void TrainAndPredictSentimentModelTest() Assert.Equal(8, matrix["negative", "positive"]); Assert.Equal(1, matrix[1, 1]); Assert.Equal(1, matrix["negative", "negative"]); + } + + [Fact] + public void TrainTestPredictSentimentModelTest() + { + string dataPath = GetDataPath(SentimentDataPath); + var pipeline = new LearningPipeline(); + + pipeline.Add(new Data.TextLoader(dataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { '\t' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(0) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SentimentText", + Source = new [] { new TextLoaderRange(1) }, + Type = Runtime.Data.DataKind.Text + } + } + } + }); + + pipeline.Add(new TextFeaturizer("Features", "SentimentText") + { + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + StopWordsRemover = new PredefinedStopWordsRemover(), + VectorNormalizer = TextTransformTextNormKind.L2, + CharFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false }, + WordFeatureExtractor = new NGramNgramExtractor() { NgramLength = 2, AllLengths = true } + }); - //3. Lets do 2.1, 2.2 and 2.3 using Train-Test that does training, evaluation and gives a predictor. + pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); + pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); + + PredictionModel model = pipeline.Train(); + IEnumerable sentiments = new[] + { + new SentimentData + { + SentimentText = "Please refrain from adding nonsense to Wikipedia." + }, + new SentimentData + { + SentimentText = "He is a CHEATER, and the article should say that." + } + }; + + string testDataPath = GetDataPath(SentimentTestPath); + var testData = new Data.TextLoader(testDataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { '\t' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(0) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SentimentText", + Source = new [] { new TextLoaderRange(1) }, + Type = Runtime.Data.DataKind.Text + } + } + } + }; var tt = new TrainTestEvaluator().TrainTestEvaluate(pipeline, testData); @@ -160,8 +235,7 @@ public void TrainAndPredictSentimentModelTest() Assert.NotNull(tt.BinaryClassificationMetrics); Assert.NotNull(tt.PredictorModels); - //These results should match that of 2.3 - metrics = tt.BinaryClassificationMetrics; + BinaryClassificationMetrics metrics = tt.BinaryClassificationMetrics; Assert.Equal(.5556, metrics.Accuracy, 4); Assert.Equal(.8, metrics.Auc, 1); Assert.Equal(.87, metrics.Auprc, 2); @@ -174,7 +248,7 @@ public void TrainAndPredictSentimentModelTest() Assert.Equal(.529, metrics.PositivePrecision, 3); Assert.Equal(1, metrics.PositiveRecall); - matrix = metrics.ConfusionMatrix; + ConfusionMatrix matrix = metrics.ConfusionMatrix; Assert.Equal(2, matrix.Order); Assert.Equal(2, matrix.ClassNames.Count); Assert.Equal("positive", matrix.ClassNames[0]); @@ -190,7 +264,7 @@ public void TrainAndPredictSentimentModelTest() Assert.Equal(1, matrix[1, 1]); Assert.Equal(1, matrix["negative", "negative"]); - predictions = tt.PredictorModels.Predict(sentiments); + IEnumerable predictions = tt.PredictorModels.Predict(sentiments); Assert.Equal(2, predictions.Count()); Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); @@ -199,8 +273,65 @@ public void TrainAndPredictSentimentModelTest() Assert.Equal(2, predictions.Count()); Assert.True(predictions.ElementAt(0).Sentiment.IsFalse); Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); + } + + [Fact] + public void CrossValidateSentimentModelTest() + { + string dataPath = GetDataPath(SentimentDataPath); + var pipeline = new LearningPipeline(); - //4. Cross Validation on training pipeline. + pipeline.Add(new Data.TextLoader(dataPath) + { + Arguments = new TextLoaderArguments + { + Separator = new[] { '\t' }, + HasHeader = true, + Column = new[] + { + new TextLoaderColumn() + { + Name = "Label", + Source = new [] { new TextLoaderRange(0) }, + Type = Runtime.Data.DataKind.Num + }, + + new TextLoaderColumn() + { + Name = "SentimentText", + Source = new [] { new TextLoaderRange(1) }, + Type = Runtime.Data.DataKind.Text + } + } + } + }); + + pipeline.Add(new TextFeaturizer("Features", "SentimentText") + { + KeepDiacritics = false, + KeepPunctuations = false, + TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, + OutputTokens = true, + StopWordsRemover = new PredefinedStopWordsRemover(), + VectorNormalizer = TextTransformTextNormKind.L2, + CharFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false }, + WordFeatureExtractor = new NGramNgramExtractor() { NgramLength = 2, AllLengths = true } + }); + + pipeline.Add(new FastTreeBinaryClassifier() { NumLeaves = 5, NumTrees = 5, MinDocumentsInLeafs = 2 }); + pipeline.Add(new PredictedLabelColumnOriginalValueConverter() { PredictedLabelColumn = "PredictedLabel" }); + + IEnumerable sentiments = new[] + { + new SentimentData + { + SentimentText = "Please refrain from adding nonsense to Wikipedia." + }, + new SentimentData + { + SentimentText = "He is a CHEATER, and the article should say that." + } + }; var cv = new CrossValidator().CrossValidate(pipeline); @@ -210,7 +341,7 @@ public void TrainAndPredictSentimentModelTest() Assert.NotNull(cv.BinaryClassificationMetrics); Assert.Equal(2, cv.BinaryClassificationMetrics.Count()); - metrics = cv.BinaryClassificationMetrics[0]; + BinaryClassificationMetrics metrics = cv.BinaryClassificationMetrics[0]; Assert.Equal(0.53030303030303028, metrics.Accuracy, 4); Assert.Equal(0.52854072128015284, metrics.Auc, 1); Assert.Equal(0.62464073827546951, metrics.Auprc, 2); @@ -223,7 +354,7 @@ public void TrainAndPredictSentimentModelTest() Assert.Equal(0.58252427184466016, metrics.PositivePrecision, 3); Assert.Equal(0.759493670886076, metrics.PositiveRecall); - matrix = metrics.ConfusionMatrix; + ConfusionMatrix matrix = metrics.ConfusionMatrix; Assert.Equal(2, matrix.Order); Assert.Equal(2, matrix.ClassNames.Count); Assert.Equal("positive", matrix.ClassNames[0]); @@ -268,7 +399,7 @@ public void TrainAndPredictSentimentModelTest() Assert.Equal(13, matrix[1, 1]); Assert.Equal(13, matrix["negative", "negative"]); - predictions = cv.PredictorModels[0].Predict(sentiments); + IEnumerable predictions = cv.PredictorModels[0].Predict(sentiments); Assert.Equal(2, predictions.Count()); Assert.True(predictions.ElementAt(0).Sentiment.IsTrue); Assert.True(predictions.ElementAt(1).Sentiment.IsTrue); From 542b845e08586f0091fa487164e329bf9822d029 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 25 May 2018 16:03:25 -0700 Subject: [PATCH 16/18] PR feedback. --- .../Models/BinaryClassificationEvaluator.cs | 2 +- src/Microsoft.ML/Models/ClassificationEvaluator.cs | 2 +- src/Microsoft.ML/Models/CrossValidator.cs | 14 +++++++------- src/Microsoft.ML/Models/RegressionEvaluator.cs | 2 +- src/Microsoft.ML/Models/TrainTestEvaluator.cs | 3 +-- .../Runtime/EntryPoints/CrossValidationMacro.cs | 3 ++- 6 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs index 9733946dd0..daed482a2a 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs @@ -72,7 +72,7 @@ public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipe Contracts.Assert(metric.Count == 1); - return metric.First(); + return metric[0]; } } } diff --git a/src/Microsoft.ML/Models/ClassificationEvaluator.cs b/src/Microsoft.ML/Models/ClassificationEvaluator.cs index e68ba68d25..dcfe92608b 100644 --- a/src/Microsoft.ML/Models/ClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/ClassificationEvaluator.cs @@ -72,7 +72,7 @@ public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLo Contracts.Assert(metric.Count == 1); - return metric.First(); + return metric[0]; } } } diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs index 8b50cfb5b7..d97669620f 100644 --- a/src/Microsoft.ML/Models/CrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -116,15 +116,15 @@ public CrossValidationOutput CrossValidate(Lea experiment.Run(); - CrossValidationOutput cvo = new CrossValidationOutput(); - cvo.PredictorModels = new PredictionModel[NumFolds]; + var cvOutput = new CrossValidationOutput(); + cvOutput.PredictorModels = new PredictionModel[NumFolds]; for (int Index = 0; Index < NumFolds; Index++) { if (Kind == MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer) { - cvo.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( + cvOutput.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics), experiment.GetOutput(crossValidateOutput.ConfusionMatrix), @@ -132,7 +132,7 @@ public CrossValidationOutput CrossValidate(Lea } else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer) { - cvo.ClassificationMetrics = ClassificationMetrics.FromMetrics( + cvOutput.ClassificationMetrics = ClassificationMetrics.FromMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics), experiment.GetOutput(crossValidateOutput.ConfusionMatrix), @@ -140,7 +140,7 @@ public CrossValidationOutput CrossValidate(Lea } else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer) { - cvo.RegressionMetrics = RegressionMetrics.FromOverallMetrics( + cvOutput.RegressionMetrics = RegressionMetrics.FromOverallMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics), 2); @@ -161,11 +161,11 @@ public CrossValidationOutput CrossValidate(Lea predictor = environment.CreateBatchPredictionEngine(memoryStream); - cvo.PredictorModels[Index] = new PredictionModel(predictor, memoryStream); + cvOutput.PredictorModels[Index] = new PredictionModel(predictor, memoryStream); } } - return cvo; + return cvOutput; } } } diff --git a/src/Microsoft.ML/Models/RegressionEvaluator.cs b/src/Microsoft.ML/Models/RegressionEvaluator.cs index 3b5e1bda13..b707e87c6b 100644 --- a/src/Microsoft.ML/Models/RegressionEvaluator.cs +++ b/src/Microsoft.ML/Models/RegressionEvaluator.cs @@ -67,7 +67,7 @@ public RegressionMetrics Evaluate(PredictionModel model, ILearningPipelineLoader Contracts.Assert(metric.Count == 1); - return metric.First(); + return metric[0]; } } } diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs index c18c8f5c23..657b27e0f3 100644 --- a/src/Microsoft.ML/Models/TrainTestEvaluator.cs +++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs @@ -119,8 +119,7 @@ public TrainTestEvaluatorOutput TrainTestEvaluate trainTestOutput = new TrainTestEvaluatorOutput(); - + var trainTestOutput = new TrainTestEvaluatorOutput(); if (Kind == MacroUtilsTrainerKinds.SignatureBinaryClassifierTrainer) { trainTestOutput.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 3d3dab9ff8..19f41e2483 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -294,7 +294,8 @@ public static CommonOutputs.MacroOutput CrossValidate( var confusionMatrix = new Var(); outputMap.Add(nameof(TrainTestMacro.Output.ConfusionMatrix), confusionMatrix.VarName); confusionMatrixVars[k] = confusionMatrix; - subGraphNodes.Add(EntryPointNode.Create(env, "Models.TrainTestEvaluator", args, node.Catalog, node.Context, inputBindingMap, inputMap, outputMap)); + const string trainTestEvaluatorMacroEntryPoint = "Models.TrainTestEvaluator"; + subGraphNodes.Add(EntryPointNode.Create(env, trainTestEvaluatorMacroEntryPoint, args, node.Catalog, node.Context, inputBindingMap, inputMap, outputMap)); } exp.Reset(); From 072e61d01fe84d7081da8cadd0ae0175a5e70dcb Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Fri, 25 May 2018 16:57:32 -0700 Subject: [PATCH 17/18] PR feedback. --- .../Common/EntryPoints/core_manifest.json | 4 +- .../PipelinePattern.cs | 4 +- src/Microsoft.ML/CSharpApi.cs | 4 +- .../Models/BinaryClassificationEvaluator.cs | 4 +- .../Models/BinaryClassificationMetrics.cs | 43 +++++++++---------- .../Models/ClassificationEvaluator.cs | 3 +- .../Models/ClassificationMetrics.cs | 16 +++---- src/Microsoft.ML/Models/CrossValidator.cs | 13 +++--- src/Microsoft.ML/Models/RegressionMetrics.cs | 9 ++-- .../EntryPoints/CrossValidationMacro.cs | 10 ++--- .../Runtime/EntryPoints/TrainTestMacro.cs | 4 +- .../UnitTests/TestCSharpApi.cs | 6 +-- .../UnitTests/TestEntryPoints.cs | 12 +++--- .../Scenarios/SentimentPredictionTests.cs | 37 +++++++++++++++- 14 files changed, 96 insertions(+), 73 deletions(-) diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json index 85be4d0e0e..7e7139ef3f 100644 --- a/ZBaselines/Common/EntryPoints/core_manifest.json +++ b/ZBaselines/Common/EntryPoints/core_manifest.json @@ -1437,7 +1437,7 @@ "Kind": "Struct", "Fields": [ { - "Name": "Model", + "Name": "PredictorModel", "Type": "PredictorModel", "Desc": "The predictor model", "Required": false, @@ -3055,7 +3055,7 @@ "Kind": "Struct", "Fields": [ { - "Name": "Model", + "Name": "PredictorModel", "Type": "PredictorModel", "Desc": "The predictor model", "Required": false, diff --git a/src/Microsoft.ML.PipelineInference/PipelinePattern.cs b/src/Microsoft.ML.PipelineInference/PipelinePattern.cs index 662a16798f..c6b4de44fa 100644 --- a/src/Microsoft.ML.PipelineInference/PipelinePattern.cs +++ b/src/Microsoft.ML.PipelineInference/PipelinePattern.cs @@ -152,7 +152,7 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD }, Outputs = { - Model = finalOutput + PredictorModel = finalOutput }, PipelineId = UniqueId.ToString("N"), Kind = MacroUtils.TrainerKindApiValue(trainerKind), @@ -189,7 +189,7 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var trainData, }, Outputs = { - Model = finalOutput + PredictorModel = finalOutput }, TrainingData = trainData, TestingData = testData, diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index cc50050dd3..fa23bda6fa 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2214,7 +2214,7 @@ public sealed partial class CrossValidationMacroSubGraphOutput /// /// The predictor model /// - public Var Model { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); /// /// The transform model @@ -3400,7 +3400,7 @@ public sealed partial class TrainTestMacroSubGraphOutput /// /// The predictor model /// - public Var Model { get; set; } = new Var(); + public Var PredictorModel { get; set; } = new Var(); /// /// Transform model diff --git a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs index daed482a2a..87b9f31f4a 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs @@ -4,9 +4,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Transforms; -using System.Collections.Generic; using System.Linq; namespace Microsoft.ML.Models @@ -70,7 +68,7 @@ public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipe var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); - Contracts.Assert(metric.Count == 1); + Contracts.Check(metric.Count == 1); return metric[0]; } diff --git a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs index a33475bd2a..49f1df5c57 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs @@ -19,7 +19,7 @@ private BinaryClassificationMetrics() { } - internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int skipRows = 0) + internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int confusionMatriceStartIndex = 0) { Contracts.AssertValue(env); env.AssertValue(overallMetrics); @@ -28,41 +28,40 @@ internal static List FromMetrics(IHostEnvironment e var metricsEnumerable = overallMetrics.AsEnumerable(env, true, ignoreMissingColumns: true); var enumerator = metricsEnumerable.GetEnumerator(); - while (skipRows-- >= 0) + if (!enumerator.MoveNext()) { - if (!enumerator.MoveNext()) - { - throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); - } + throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); } List metrics = new List(); var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator(); + + int Index = 0; do { SerializationClass metric = enumerator.Current; - if (!confusionMatrices.MoveNext()) + if (Index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext()) { throw env.Except("Confusion matrices didn't have enough matrices."); } - metrics.Add( + metrics.Add( new BinaryClassificationMetrics() - { - Auc = metric.Auc, - Accuracy = metric.Accuracy, - PositivePrecision = metric.PositivePrecision, - PositiveRecall = metric.PositiveRecall, - NegativePrecision = metric.NegativePrecision, - NegativeRecall = metric.NegativeRecall, - LogLoss = metric.LogLoss, - LogLossReduction = metric.LogLossReduction, - Entropy = metric.Entropy, - F1Score = metric.F1Score, - Auprc = metric.Auprc, - ConfusionMatrix = confusionMatrices.Current, - }); + { + Auc = metric.Auc, + Accuracy = metric.Accuracy, + PositivePrecision = metric.PositivePrecision, + PositiveRecall = metric.PositiveRecall, + NegativePrecision = metric.NegativePrecision, + NegativeRecall = metric.NegativeRecall, + LogLoss = metric.LogLoss, + LogLossReduction = metric.LogLossReduction, + Entropy = metric.Entropy, + F1Score = metric.F1Score, + Auprc = metric.Auprc, + ConfusionMatrix = confusionMatrices.Current, + }); } while (enumerator.MoveNext()); diff --git a/src/Microsoft.ML/Models/ClassificationEvaluator.cs b/src/Microsoft.ML/Models/ClassificationEvaluator.cs index dcfe92608b..2933f3a8b6 100644 --- a/src/Microsoft.ML/Models/ClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/ClassificationEvaluator.cs @@ -5,7 +5,6 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Transforms; -using System.Collections.Generic; using System.Linq; namespace Microsoft.ML.Models @@ -70,7 +69,7 @@ public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLo var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); - Contracts.Assert(metric.Count == 1); + Contracts.Check(metric.Count == 1); return metric[0]; } diff --git a/src/Microsoft.ML/Models/ClassificationMetrics.cs b/src/Microsoft.ML/Models/ClassificationMetrics.cs index 2d036f6b6a..8c6efe139e 100644 --- a/src/Microsoft.ML/Models/ClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/ClassificationMetrics.cs @@ -18,7 +18,8 @@ private ClassificationMetrics() { } - internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int skipRows = 0) + internal static List FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, + int confusionMatriceStartIndex = 0) { Contracts.AssertValue(env); env.AssertValue(overallMetrics); @@ -26,19 +27,18 @@ internal static List FromMetrics(IHostEnvironment env, ID var metricsEnumerable = overallMetrics.AsEnumerable(env, true, ignoreMissingColumns: true); var enumerator = metricsEnumerable.GetEnumerator(); - while (skipRows-- >= 0) + if (!enumerator.MoveNext()) { - if (!enumerator.MoveNext()) - { - throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); - } + throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); } - + List metrics = new List(); var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator(); + + int Index = 0; do { - if (!confusionMatrices.MoveNext()) + if (Index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext()) { throw env.Except("Confusion matrices didn't have enough matrices."); } diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs index d97669620f..5dcce75a00 100644 --- a/src/Microsoft.ML/Models/CrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -17,8 +17,8 @@ public sealed partial class CrossValidator /// /// Class type that represents input schema. /// Class type that represents prediction schema. - /// Machine learning pipeline that contain may contain loader, transforms and at least one trainer. - /// List containning metrics and predictor model for each fold + /// Machine learning pipeline may contain loader, transforms and at least one trainer. + /// List containing metrics and predictor model for each fold public CrossValidationOutput CrossValidate(LearningPipeline pipeline) where TInput : class where TOutput : class, new() @@ -127,23 +127,20 @@ public CrossValidationOutput CrossValidate(Lea cvOutput.BinaryClassificationMetrics = BinaryClassificationMetrics.FromMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics), - experiment.GetOutput(crossValidateOutput.ConfusionMatrix), - 2); + experiment.GetOutput(crossValidateOutput.ConfusionMatrix), 2); } else if(Kind == MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer) { cvOutput.ClassificationMetrics = ClassificationMetrics.FromMetrics( environment, experiment.GetOutput(crossValidateOutput.OverallMetrics), - experiment.GetOutput(crossValidateOutput.ConfusionMatrix), - 2); + experiment.GetOutput(crossValidateOutput.ConfusionMatrix), 2); } else if (Kind == MacroUtilsTrainerKinds.SignatureRegressorTrainer) { cvOutput.RegressionMetrics = RegressionMetrics.FromOverallMetrics( environment, - experiment.GetOutput(crossValidateOutput.OverallMetrics), - 2); + experiment.GetOutput(crossValidateOutput.OverallMetrics)); } else { diff --git a/src/Microsoft.ML/Models/RegressionMetrics.cs b/src/Microsoft.ML/Models/RegressionMetrics.cs index 00bb74a9ef..19af8f7c3b 100644 --- a/src/Microsoft.ML/Models/RegressionMetrics.cs +++ b/src/Microsoft.ML/Models/RegressionMetrics.cs @@ -19,19 +19,16 @@ private RegressionMetrics() { } - internal static List FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics, int skipRows = 0) + internal static List FromOverallMetrics(IHostEnvironment env, IDataView overallMetrics) { Contracts.AssertValue(env); env.AssertValue(overallMetrics); var metricsEnumerable = overallMetrics.AsEnumerable(env, true, ignoreMissingColumns: true); var enumerator = metricsEnumerable.GetEnumerator(); - while (skipRows-- >= 0) + if (!enumerator.MoveNext()) { - if (!enumerator.MoveNext()) - { - throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); - } + throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); } List metrics = new List(); diff --git a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs index 19f41e2483..2cd148bec1 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs @@ -30,7 +30,7 @@ public sealed class SubGraphInput public sealed class SubGraphOutput { [Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)] - public Var Model; + public Var PredictorModel; [Argument(ArgumentType.AtMostOnce, HelpText = "The transform model", SortOrder = 2)] public Var TransformModel; @@ -203,15 +203,15 @@ public static CommonOutputs.MacroOutput CrossValidate( VarName = mapping[input.Inputs.Data.VarName] }; - if (input.Outputs.Model != null && mapping.ContainsKey(input.Outputs.Model.VarName)) + if (input.Outputs.PredictorModel != null && mapping.ContainsKey(input.Outputs.PredictorModel.VarName)) { - args.Outputs.Model = new Var + args.Outputs.PredictorModel = new Var { - VarName = mapping[input.Outputs.Model.VarName] + VarName = mapping[input.Outputs.PredictorModel.VarName] }; } else - args.Outputs.Model = null; + args.Outputs.PredictorModel = null; if (input.Outputs.TransformModel != null && mapping.ContainsKey(input.Outputs.TransformModel.VarName)) { diff --git a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs index 869fcba3e2..5aefbb49fd 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs @@ -25,7 +25,7 @@ public sealed class SubGraphInput public sealed class SubGraphOutput { [Argument(ArgumentType.AtMostOnce, HelpText = "The predictor model", SortOrder = 1)] - public Var Model; + public Var PredictorModel; [Argument(ArgumentType.AtMostOnce, HelpText = "Transform model", SortOrder = 2)] public Var TransformModel; @@ -129,7 +129,7 @@ public static CommonOutputs.MacroOutput TrainTest( subGraphRunContext.RemoveVariable(dataVariable); // Change the subgraph to use the model variable as output. - varName = input.Outputs.UseTransformModel ? input.Outputs.TransformModel.VarName : input.Outputs.Model.VarName; + varName = input.Outputs.UseTransformModel ? input.Outputs.TransformModel.VarName : input.Outputs.PredictorModel.VarName; if (!subGraphRunContext.TryGetVariable(varName, out dataVariable)) throw env.Except($"Invalid variable name '{varName}'."); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index c7c199f2d1..a385917367 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -319,7 +319,7 @@ public void TestCrossValidationMacro() TransformModel = null }; crossValidate.Inputs.Data = nop.Data; - crossValidate.Outputs.Model = modelCombineOutput.PredictorModel; + crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; var crossValidateOutput = experiment.Add(crossValidate); experiment.Compile(); @@ -410,7 +410,7 @@ public void TestCrossValidationMacroWithMultiClass() TransformModel = null }; crossValidate.Inputs.Data = nop.Data; - crossValidate.Outputs.Model = modelCombineOutput.PredictorModel; + crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; var crossValidateOutput = experiment.Add(crossValidate); experiment.Compile(); @@ -541,7 +541,7 @@ public void TestCrossValidationMacroWithStratification() StratificationColumn = "Strat" }; crossValidate.Inputs.Data = nop.Data; - crossValidate.Outputs.Model = modelCombineOutput.PredictorModel; + crossValidate.Outputs.PredictorModel = modelCombineOutput.PredictorModel; var crossValidateOutput = experiment.Add(crossValidate); experiment.Compile(); experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 34b9317176..2343c7949a 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -1875,7 +1875,7 @@ public void EntryPointTrainTestMacroNoTransformInput() 'Data': '$data1' }, 'Outputs': { - 'Model': '$model' + 'PredictorModel': '$model' } }, 'Outputs': { @@ -1980,7 +1980,7 @@ public void EntryPointTrainTestMacro() 'Data': '$data1' }, 'Outputs': { - 'Model': '$model' + 'PredictorModel': '$model' } }, 'Outputs': { @@ -2108,7 +2108,7 @@ public void EntryPointChainedTrainTestMacros() 'Data': '$data1' }, 'Outputs': { - 'Model': '$model' + 'PredictorModel': '$model' } }, 'Outputs': { @@ -2141,7 +2141,7 @@ public void EntryPointChainedTrainTestMacros() 'Data': '$data4' }, 'Outputs': { - 'Model': '$model2' + 'PredictorModel': '$model2' } }, 'Outputs': { @@ -2274,7 +2274,7 @@ public void EntryPointChainedCrossValMacros() 'Data': '$data6' }, 'Outputs': { - 'Model': '$model' + 'PredictorModel': '$model' } }, 'Outputs': { @@ -2336,7 +2336,7 @@ public void EntryPointChainedCrossValMacros() 'Data': '$data4' }, 'Outputs': { - 'Model': '$model2' + 'PredictorModel': '$model2' } }, 'Outputs': { diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index ed9e04a404..f99fefe378 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -335,13 +335,45 @@ public void CrossValidateSentimentModelTest() var cv = new CrossValidator().CrossValidate(pipeline); + //First two items are average and std. deviation of metrics from the folds. Assert.Equal(2, cv.PredictorModels.Count()); Assert.Null(cv.ClassificationMetrics); Assert.Null(cv.RegressionMetrics); Assert.NotNull(cv.BinaryClassificationMetrics); - Assert.Equal(2, cv.BinaryClassificationMetrics.Count()); + Assert.Equal(4, cv.BinaryClassificationMetrics.Count()); + //Avergae of all folds. BinaryClassificationMetrics metrics = cv.BinaryClassificationMetrics[0]; + Assert.Equal(0.57023626091422708, metrics.Accuracy, 4); + Assert.Equal(0.54960689910161487, metrics.Auc, 1); + Assert.Equal(0.67048277219704255, metrics.Auprc, 2); + Assert.Equal(0, metrics.Entropy, 3); + Assert.Equal(0.68942642723130532, metrics.F1Score, 4); + Assert.Equal(0.97695909611968434, metrics.LogLoss, 3); + Assert.Equal(-3.050726259114541, metrics.LogLossReduction, 3); + Assert.Equal(0.37553879310344829, metrics.NegativePrecision, 3); + Assert.Equal(0.25683962264150945, metrics.NegativeRecall, 3); + Assert.Equal(0.63428539173628362, metrics.PositivePrecision, 3); + Assert.Equal(0.75795196364816619, metrics.PositiveRecall); + Assert.Null(metrics.ConfusionMatrix); + + //Std. Deviation. + metrics = cv.BinaryClassificationMetrics[1]; + Assert.Equal(0.039933230611196011, metrics.Accuracy, 4); + Assert.Equal(0.021066177821462407, metrics.Auc, 1); + Assert.Equal(0.045842033921572725, metrics.Auprc, 2); + Assert.Equal(0, metrics.Entropy, 3); + Assert.Equal(0.030085767890644915, metrics.F1Score, 4); + Assert.Equal(0.032906777175141941, metrics.LogLoss, 3); + Assert.Equal(0.86311349745170118, metrics.LogLossReduction, 3); + Assert.Equal(0.030711206896551647, metrics.NegativePrecision, 3); + Assert.Equal(0.068160377358490579, metrics.NegativeRecall, 3); + Assert.Equal(0.051761119891622735, metrics.PositivePrecision, 3); + Assert.Equal(0.0015417072379052127, metrics.PositiveRecall); + Assert.Null(metrics.ConfusionMatrix); + + //Fold 1. + metrics = cv.BinaryClassificationMetrics[2]; Assert.Equal(0.53030303030303028, metrics.Accuracy, 4); Assert.Equal(0.52854072128015284, metrics.Auc, 1); Assert.Equal(0.62464073827546951, metrics.Auprc, 2); @@ -370,7 +402,8 @@ public void CrossValidateSentimentModelTest() Assert.Equal(10, matrix[1, 1]); Assert.Equal(10, matrix["negative", "negative"]); - metrics = cv.BinaryClassificationMetrics[1]; + //Fold 2. + metrics = cv.BinaryClassificationMetrics[3]; Assert.Equal(0.61016949152542377, metrics.Accuracy, 4); Assert.Equal(0.57067307692307689, metrics.Auc, 1); Assert.Equal(0.71632480611861549, metrics.Auprc, 2); From 9d6fb01aed8dcc6e19a319b779850be837446943 Mon Sep 17 00:00:00 2001 From: Zeeshan Siddiqui Date: Mon, 28 May 2018 00:49:06 -0700 Subject: [PATCH 18/18] PR feedback. --- .../Common/EntryPoints/core_manifest.json | 18 ------------------ src/Microsoft.ML/CSharpApi.cs | 10 ---------- .../Models/BinaryClassificationEvaluator.cs | 3 +-- .../Models/BinaryClassificationMetrics.cs | 13 +++++-------- .../Models/ClassificationEvaluator.cs | 3 +-- .../Models/ClassificationMetrics.cs | 14 ++++++-------- src/Microsoft.ML/Models/CrossValidator.cs | 2 +- src/Microsoft.ML/Models/RegressionEvaluator.cs | 5 +---- src/Microsoft.ML/Models/RegressionMetrics.cs | 13 +++++-------- src/Microsoft.ML/Models/TrainTestEvaluator.cs | 2 +- .../EntryPoints/CrossValidationMacro.cs | 11 +++-------- .../Runtime/EntryPoints/TrainTestMacro.cs | 15 ++++++--------- 12 files changed, 30 insertions(+), 79 deletions(-) diff --git a/ZBaselines/Common/EntryPoints/core_manifest.json b/ZBaselines/Common/EntryPoints/core_manifest.json index 7e7139ef3f..010d4a0afa 100644 --- a/ZBaselines/Common/EntryPoints/core_manifest.json +++ b/ZBaselines/Common/EntryPoints/core_manifest.json @@ -1453,15 +1453,6 @@ "SortOrder": 2.0, "IsNullable": false, "Default": null - }, - { - "Name": "UseTransformModel", - "Type": "Bool", - "Desc": "Indicates to use transform model instead of predictor model.", - "Required": false, - "SortOrder": 3.0, - "IsNullable": false, - "Default": false } ] }, @@ -3071,15 +3062,6 @@ "SortOrder": 2.0, "IsNullable": false, "Default": null - }, - { - "Name": "UseTransformModel", - "Type": "Bool", - "Desc": "Indicates to use transform model instead of predictor model.", - "Required": false, - "SortOrder": 3.0, - "IsNullable": false, - "Default": false } ] }, diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index fa23bda6fa..058e8bafe3 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -2221,11 +2221,6 @@ public sealed partial class CrossValidationMacroSubGraphOutput /// public Var TransformModel { get; set; } = new Var(); - /// - /// Indicates to use transform model instead of predictor model. - /// - public bool UseTransformModel { get; set; } = false; - } /// @@ -3407,11 +3402,6 @@ public sealed partial class TrainTestMacroSubGraphOutput /// public Var TransformModel { get; set; } = new Var(); - /// - /// Indicates to use transform model instead of predictor model. - /// - public bool UseTransformModel { get; set; } = false; - } /// diff --git a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs index 87b9f31f4a..1a670fc854 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs @@ -5,7 +5,6 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Transforms; -using System.Linq; namespace Microsoft.ML.Models { @@ -68,7 +67,7 @@ public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipe var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); - Contracts.Check(metric.Count == 1); + Contracts.Check(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics"); return metric[0]; } diff --git a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs index 49f1df5c57..f536f30ed0 100644 --- a/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/BinaryClassificationMetrics.cs @@ -26,20 +26,17 @@ internal static List FromMetrics(IHostEnvironment e env.AssertValue(confusionMatrix); var metricsEnumerable = overallMetrics.AsEnumerable(env, true, ignoreMissingColumns: true); - var enumerator = metricsEnumerable.GetEnumerator(); - - if (!enumerator.MoveNext()) + if (!metricsEnumerable.GetEnumerator().MoveNext()) { - throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); + throw env.Except("The overall RegressionMetrics didn't have any rows."); } List metrics = new List(); var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator(); int Index = 0; - do + foreach(var metric in metricsEnumerable) { - SerializationClass metric = enumerator.Current; if (Index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext()) { @@ -63,7 +60,7 @@ internal static List FromMetrics(IHostEnvironment e ConfusionMatrix = confusionMatrices.Current, }); - } while (enumerator.MoveNext()); + } return metrics; } @@ -168,7 +165,7 @@ internal static List FromMetrics(IHostEnvironment e /// /// This class contains the public fields necessary to deserialize from IDataView. /// - private class SerializationClass + private sealed class SerializationClass { #pragma warning disable 649 // never assigned [ColumnName(BinaryClassifierEvaluator.Auc)] diff --git a/src/Microsoft.ML/Models/ClassificationEvaluator.cs b/src/Microsoft.ML/Models/ClassificationEvaluator.cs index 2933f3a8b6..8fedc3fb4f 100644 --- a/src/Microsoft.ML/Models/ClassificationEvaluator.cs +++ b/src/Microsoft.ML/Models/ClassificationEvaluator.cs @@ -5,7 +5,6 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Transforms; -using System.Linq; namespace Microsoft.ML.Models { @@ -69,7 +68,7 @@ public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLo var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix); - Contracts.Check(metric.Count == 1); + Contracts.Check(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics"); return metric[0]; } diff --git a/src/Microsoft.ML/Models/ClassificationMetrics.cs b/src/Microsoft.ML/Models/ClassificationMetrics.cs index 8c6efe139e..f3a2416bca 100644 --- a/src/Microsoft.ML/Models/ClassificationMetrics.cs +++ b/src/Microsoft.ML/Models/ClassificationMetrics.cs @@ -26,24 +26,22 @@ internal static List FromMetrics(IHostEnvironment env, ID env.AssertValue(confusionMatrix); var metricsEnumerable = overallMetrics.AsEnumerable(env, true, ignoreMissingColumns: true); - var enumerator = metricsEnumerable.GetEnumerator(); - if (!enumerator.MoveNext()) + if (!metricsEnumerable.GetEnumerator().MoveNext()) { - throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); + throw env.Except("The overall RegressionMetrics didn't have any rows."); } List metrics = new List(); var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator(); int Index = 0; - do + foreach (var metric in metricsEnumerable) { if (Index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext()) { throw env.Except("Confusion matrices didn't have enough matrices."); } - - SerializationClass metric = enumerator.Current; + metrics.Add( new ClassificationMetrics() { @@ -56,7 +54,7 @@ internal static List FromMetrics(IHostEnvironment env, ID ConfusionMatrix = confusionMatrices.Current }); - } while (enumerator.MoveNext()); + } return metrics; } @@ -137,7 +135,7 @@ internal static List FromMetrics(IHostEnvironment env, ID /// /// This class contains the public fields necessary to deserialize from IDataView. /// - private class SerializationClass + private sealed class SerializationClass { #pragma warning disable 649 // never assigned [ColumnName(MultiClassClassifierEvaluator.AccuracyMicro)] diff --git a/src/Microsoft.ML/Models/CrossValidator.cs b/src/Microsoft.ML/Models/CrossValidator.cs index 5dcce75a00..173e03916c 100644 --- a/src/Microsoft.ML/Models/CrossValidator.cs +++ b/src/Microsoft.ML/Models/CrossValidator.cs @@ -105,8 +105,8 @@ public CrossValidationOutput CrossValidate(Lea Nodes = subGraph; TransformModel = null; Inputs.Data = firstTransform.GetInputData(); + Outputs.PredictorModel = null; Outputs.TransformModel = lastTransformModel; - Outputs.UseTransformModel = true; var crossValidateOutput = experiment.Add(this); experiment.Compile(); foreach (ILearningPipelineLoader loader in loaders) diff --git a/src/Microsoft.ML/Models/RegressionEvaluator.cs b/src/Microsoft.ML/Models/RegressionEvaluator.cs index b707e87c6b..2cb05ee092 100644 --- a/src/Microsoft.ML/Models/RegressionEvaluator.cs +++ b/src/Microsoft.ML/Models/RegressionEvaluator.cs @@ -4,10 +4,7 @@ using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Transforms; -using System.Collections.Generic; -using System.Linq; namespace Microsoft.ML.Models { @@ -65,7 +62,7 @@ public RegressionMetrics Evaluate(PredictionModel model, ILearningPipelineLoader var metric = RegressionMetrics.FromOverallMetrics(environment, overallMetrics); - Contracts.Assert(metric.Count == 1); + Contracts.Assert(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics"); return metric[0]; } diff --git a/src/Microsoft.ML/Models/RegressionMetrics.cs b/src/Microsoft.ML/Models/RegressionMetrics.cs index 19af8f7c3b..68f9af2feb 100644 --- a/src/Microsoft.ML/Models/RegressionMetrics.cs +++ b/src/Microsoft.ML/Models/RegressionMetrics.cs @@ -25,16 +25,14 @@ internal static List FromOverallMetrics(IHostEnvironment env, env.AssertValue(overallMetrics); var metricsEnumerable = overallMetrics.AsEnumerable(env, true, ignoreMissingColumns: true); - var enumerator = metricsEnumerable.GetEnumerator(); - if (!enumerator.MoveNext()) + if (!metricsEnumerable.GetEnumerator().MoveNext()) { - throw env.Except("The overall RegressionMetrics didn't have sufficient rows."); + throw env.Except("The overall RegressionMetrics didn't have any rows."); } List metrics = new List(); - do + foreach (var metric in metricsEnumerable) { - SerializationClass metric = enumerator.Current; metrics.Add(new RegressionMetrics() { L1 = metric.L1, @@ -43,8 +41,7 @@ internal static List FromOverallMetrics(IHostEnvironment env, LossFn = metric.LossFn, RSquared = metric.RSquared, }); - - } while (enumerator.MoveNext()); + } return metrics; } @@ -96,7 +93,7 @@ internal static List FromOverallMetrics(IHostEnvironment env, /// /// This class contains the public fields necessary to deserialize from IDataView. /// - private class SerializationClass + private sealed class SerializationClass { #pragma warning disable 649 // never assigned [ColumnName(Runtime.Data.RegressionEvaluator.L1)] diff --git a/src/Microsoft.ML/Models/TrainTestEvaluator.cs b/src/Microsoft.ML/Models/TrainTestEvaluator.cs index 657b27e0f3..19261e82de 100644 --- a/src/Microsoft.ML/Models/TrainTestEvaluator.cs +++ b/src/Microsoft.ML/Models/TrainTestEvaluator.cs @@ -108,8 +108,8 @@ public TrainTestEvaluatorOutput TrainTestEvaluate TransformModel; - - [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to use transform model instead of predictor model.", SortOrder = 3)] - public bool UseTransformModel = false; } public sealed class Arguments @@ -222,9 +219,7 @@ public static CommonOutputs.MacroOutput CrossValidate( } else args.Outputs.TransformModel = null; - - args.Outputs.UseTransformModel = input.Outputs.UseTransformModel; - + // Set train/test trainer kind to match. args.Kind = input.Kind; @@ -240,7 +235,7 @@ public static CommonOutputs.MacroOutput CrossValidate( var outputMap = new Dictionary(); var transformModelVar = new Var(); var predModelVar = new Var(); - if (input.Outputs.UseTransformModel) + if (input.Outputs.PredictorModel == null) { outputMap.Add(nameof(TrainTestMacro.Output.TransformModel), transformModelVar.VarName); transformModelVars[k] = transformModelVar; @@ -302,7 +297,7 @@ public static CommonOutputs.MacroOutput CrossValidate( // Convert predictors from all folds into an array of predictors. - if (input.Outputs.UseTransformModel) + if (input.Outputs.PredictorModel == null) { var outModels = new ML.Data.TransformModelArrayConverter { diff --git a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs index 5aefbb49fd..edd4cf6e5b 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestMacro.cs @@ -29,9 +29,6 @@ public sealed class SubGraphOutput [Argument(ArgumentType.AtMostOnce, HelpText = "Transform model", SortOrder = 2)] public Var TransformModel; - - [Argument(ArgumentType.AtMostOnce, HelpText = "Indicates to use transform model instead of predictor model.", SortOrder = 3)] - public bool UseTransformModel = false; } public sealed class Arguments @@ -129,11 +126,11 @@ public static CommonOutputs.MacroOutput TrainTest( subGraphRunContext.RemoveVariable(dataVariable); // Change the subgraph to use the model variable as output. - varName = input.Outputs.UseTransformModel ? input.Outputs.TransformModel.VarName : input.Outputs.PredictorModel.VarName; + varName = input.Outputs.PredictorModel == null ? input.Outputs.TransformModel.VarName : input.Outputs.PredictorModel.VarName; if (!subGraphRunContext.TryGetVariable(varName, out dataVariable)) throw env.Except($"Invalid variable name '{varName}'."); - string outputVarName = input.Outputs.UseTransformModel ? node.GetOutputVariableName(nameof(Output.TransformModel)) : + string outputVarName = input.Outputs.PredictorModel == null ? node.GetOutputVariableName(nameof(Output.TransformModel)) : node.GetOutputVariableName(nameof(Output.PredictorModel)); foreach (var subGraphNode in subGraphNodes) @@ -153,7 +150,7 @@ public static CommonOutputs.MacroOutput TrainTest( DatasetScorer.Output scoreNodeOutput = null; ML.Models.DatasetTransformer.Output datasetTransformNodeOutput = null; - if (input.Outputs.UseTransformModel) + if (input.Outputs.PredictorModel == null) { //combine the predictor model with any potential transfrom model passed from the outer graph if (transformModelVarName != null && transformModelVarName.VariableName != null) @@ -222,7 +219,7 @@ public static CommonOutputs.MacroOutput TrainTest( { DatasetScorer.Output scoreNodeTrainingOutput = null; ML.Models.DatasetTransformer.Output datasetTransformNodeTrainingOutput = null; - if (input.Outputs.UseTransformModel) + if (input.Outputs.PredictorModel == null) { var datasetTransformerNode = new Models.DatasetTransformer { @@ -252,7 +249,7 @@ public static CommonOutputs.MacroOutput TrainTest( var evalInputOutputTraining = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings); var evalNodeTraining = evalInputOutputTraining.Item1; var evalOutputTraining = evalInputOutputTraining.Item2; - evalNodeTraining.Data.VarName = input.Outputs.UseTransformModel ? datasetTransformNodeTrainingOutput.OutputData.VarName : + evalNodeTraining.Data.VarName = input.Outputs.PredictorModel == null ? datasetTransformNodeTrainingOutput.OutputData.VarName : scoreNodeTrainingOutput.ScoredData.VarName; if (node.OutputMap.TryGetValue(nameof(Output.TrainingWarnings), out outVariableName)) @@ -276,7 +273,7 @@ public static CommonOutputs.MacroOutput TrainTest( var evalInputOutput = MacroUtils.GetEvaluatorInputOutput(input.Kind, settings); var evalNode = evalInputOutput.Item1; var evalOutput = evalInputOutput.Item2; - evalNode.Data.VarName = input.Outputs.UseTransformModel ? datasetTransformNodeOutput.OutputData.VarName : scoreNodeOutput.ScoredData.VarName; + evalNode.Data.VarName = input.Outputs.PredictorModel == null ? datasetTransformNodeOutput.OutputData.VarName : scoreNodeOutput.ScoredData.VarName; if (node.OutputMap.TryGetValue(nameof(Output.Warnings), out outVariableName)) evalOutput.Warnings.VarName = outVariableName;