diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
index db4f207fa3..f3d84812ce 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
@@ -87,6 +87,12 @@ public void SaveTo(IHostEnvironment env, Stream outputStream)
     /// </summary>
     public static class CompositeDataReader
     {
+        /// <summary>
+        /// Save the contents to a stream, as a "model file".
+        /// </summary>
+        public static void SaveTo<TSource>(this IDataReader<TSource> reader, IHostEnvironment env, Stream outputStream)
+            => new CompositeDataReader<TSource, ITransformer>(reader).SaveTo(env, outputStream);
+
         /// <summary>
         /// Load the pipeline from stream.
         /// </summary>
diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs
index b15bd1b317..c2593654cf 100644
--- a/src/Microsoft.ML.Data/Training/TrainContext.cs
+++ b/src/Microsoft.ML.Data/Training/TrainContext.cs
@@ -2,9 +2,15 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime;
 using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Transforms;
+using System;
+using System.Collections.Generic;
+using System.Linq;
 
-namespace Microsoft.ML.Runtime.Training
+namespace Microsoft.ML
 {
     /// <summary>
     /// A training context is an object instantiable by a user to do various tasks relating to a particular
@@ -16,6 +22,90 @@ public abstract class TrainContextBase
         protected readonly IHost Host;
         internal IHostEnvironment Environment => Host;
 
+        /// <summary>
+        /// Split the dataset into the train set and test set according to the given fraction.
+        /// Respects the <paramref name="stratificationColumn"/> if provided.
+        /// </summary>
+        /// <param name="data">The dataset to split.</param>
+        /// <param name="testFraction">The fraction of data to go into the test set.</param>
+        /// <param name="stratificationColumn">Optional stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>A pair of datasets, for the train and test set.</returns>
+        public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null)
+        {
+            Host.CheckValue(data, nameof(data));
+            Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
+            Host.CheckValueOrNull(stratificationColumn);
+
+            EnsureStratificationColumn(ref data, ref stratificationColumn);
+
+            var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments()
+            {
+                Column = stratificationColumn,
+                Min = 0,
+                Max = testFraction,
+                Complement = true
+            }, data);
+            var testFilter = new RangeFilter(Host, new RangeFilter.Arguments()
+            {
+                Column = stratificationColumn,
+                Min = 0,
+                Max = testFraction,
+                Complement = false
+            }, data);
+
+            return (trainFilter, testFilter);
+        }
+
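To make the intended call pattern concrete, here is a minimal usage sketch of the split above. It borrows the reader pattern from the tests at the end of this diff; the file path and column layout are illustrative assumptions, not part of the change.

```csharp
// Sketch only: the data file and its column layout are assumed for illustration.
var env = new ConsoleEnvironment(seed: 0);
var ctx = new BinaryClassificationContext(env);

var reader = TextLoader.CreateReader(env,
    c => (label: c.LoadBool(0), features: c.LoadFloat(1, 9)));
IDataView data = reader.Read(new MultiFileSource("data.tsv")).AsDynamic;

// 25% of rows go to the test set. Because "label" is the stratification
// column, rows sharing a label value never straddle the split.
var (trainSet, testSet) = ctx.TrainTestSplit(data, testFraction: 0.25, stratificationColumn: "label");
```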
+        /// <summary>
+        /// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially.
+        /// Return each model and each scored test dataset.
+        /// </summary>
+        protected (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
+            int numFolds, string stratificationColumn)
+        {
+            Host.CheckValue(data, nameof(data));
+            Host.CheckValue(estimator, nameof(estimator));
+            Host.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+            Host.CheckValueOrNull(stratificationColumn);
+
+            EnsureStratificationColumn(ref data, ref stratificationColumn);
+
+            Func<int, (IDataView scores, ITransformer model)> foldFunction =
+                fold =>
+                {
+                    var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments
+                    {
+                        Column = stratificationColumn,
+                        Min = (double)fold / numFolds,
+                        Max = (double)(fold + 1) / numFolds,
+                        Complement = true
+                    }, data);
+                    var testFilter = new RangeFilter(Host, new RangeFilter.Arguments
+                    {
+                        Column = stratificationColumn,
+                        Min = (double)fold / numFolds,
+                        Max = (double)(fold + 1) / numFolds,
+                        Complement = false
+                    }, data);
+
+                    var model = estimator.Fit(trainFilter);
+                    var scoredTest = model.Transform(testFilter);
+                    return (scoredTest, model);
+                };
+
+            // Sequential per-fold training.
+            // REVIEW: we could have a parallel implementation here. We would need to
+            // spawn off a separate host per fold in that case.
+            var result = new List<(IDataView scores, ITransformer model)>();
+            for (int fold = 0; fold < numFolds; fold++)
+                result.Add(foldFunction(fold));
+
+            return result.ToArray();
+        }
+
         protected TrainContextBase(IHostEnvironment env, string registrationName)
         {
             Contracts.CheckValue(env, nameof(env));
@@ -23,6 +113,45 @@ protected TrainContextBase(IHostEnvironment env, string registrationName)
             Host = env.Register(registrationName);
         }
 
+        /// <summary>
+        /// Make sure the provided <paramref name="stratificationColumn"/> is valid
+        /// for <see cref="RangeFilter"/>, hash it if needed, or introduce a new one
+        /// if needed.
+        /// </summary>
+        private void EnsureStratificationColumn(ref IDataView data, ref string stratificationColumn)
+        {
+            // We need to handle two cases: if the stratification column is provided, we use hashJoin to
+            // build a single hash of it. If it is not, we generate a random number.
+
+            if (stratificationColumn == null)
+            {
+                stratificationColumn = data.Schema.GetTempColumnName("StratificationColumn");
+                data = new GenerateNumberTransform(Host, data, stratificationColumn);
+            }
+            else
+            {
+                if (!data.Schema.TryGetColumnIndex(stratificationColumn, out int stratCol))
+                    throw Host.ExceptSchemaMismatch(nameof(stratificationColumn), "stratification", stratificationColumn);
+
+                var type = data.Schema.GetColumnType(stratCol);
+                if (!RangeFilter.IsValidRangeFilterColumnType(Host, type))
+                {
+                    // Hash the stratification column.
+                    // REVIEW: this could currently crash, since Hash only accepts a limited set
+                    // of column types. It used to be HashJoin, but we should probably extend Hash
+                    // instead of having two hash transformations.
+                    var origStratCol = stratificationColumn;
+                    int tmp;
+                    int inc = 0;
+
+                    // Generate a new column with the hashed stratification column.
+                    while (data.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
+                        stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
+                    data = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(data).Transform(data);
+                }
+            }
+        }
+
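The pairs of `RangeFilter` instances above implement a simple bucketing rule over the stratification value. The following standalone sketch restates that rule in plain C# for illustration; it is not ML.NET API, and exact boundary inclusion follows `RangeFilter`'s defaults.

```csharp
// Plain-C# restatement of the fold-assignment rule implemented by the
// RangeFilter pairs above; illustration only, not part of ML.NET.
internal static class FoldMath
{
    // The stratification value lies in [0, 1): either a generated uniform
    // random number, or a hash rescaled to the unit interval.
    public static int FoldOf(double stratValue, int numFolds)
        => (int)(stratValue * numFolds);

    // A row is in a fold's test set iff its value falls in that fold's
    // equal-width bucket; the complement filter yields the train set.
    public static bool IsInTestSet(double stratValue, int fold, int numFolds)
        => (double)fold / numFolds <= stratValue
            && stratValue < (double)(fold + 1) / numFolds;
}
```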
         /// <summary>
         /// Subclasses of <see cref="TrainContextBase"/> will provide little "extension method" hookable objects
         /// (e.g., something like <see cref="BinaryClassificationContext.Trainers"/>). User code will only
@@ -140,6 +269,50 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st
             var eval = new BinaryClassifierEvaluator(Host, new BinaryClassifierEvaluator.Arguments() { });
             return eval.Evaluate(data, label, score, predictedLabel);
         }
+
+        /// <summary>
+        /// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
+        /// and respecting <paramref name="stratificationColumn"/> if provided.
+        /// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics.
+        /// </summary>
+        /// <param name="data">The data to run cross-validation on.</param>
+        /// <param name="estimator">The estimator to fit.</param>
+        /// <param name="numFolds">Number of cross-validation folds.</param>
+        /// <param name="labelColumn">The label column (for evaluation).</param>
+        /// <param name="stratificationColumn">Optional stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
+        public (BinaryClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated(
+            IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
+        {
+            Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+            var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
+            return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+        }
+
+        /// <summary>
+        /// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
+        /// and respecting <paramref name="stratificationColumn"/> if provided.
+        /// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics.
+        /// </summary>
+        /// <param name="data">The data to run cross-validation on.</param>
+        /// <param name="estimator">The estimator to fit.</param>
+        /// <param name="numFolds">Number of cross-validation folds.</param>
+        /// <param name="labelColumn">The label column (for evaluation).</param>
+        /// <param name="stratificationColumn">Optional stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
+        public (BinaryClassifierEvaluator.CalibratedResult metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+            IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
+        {
+            Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+            var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
+            return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+        }
     }
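A hedged usage sketch for the binary overloads above; it assumes an `env`, an `IDataView data` with the default Label column, and an `IEstimator<ITransformer> pipeline` constructed elsewhere.

```csharp
// Sketch: `env`, `data` and `pipeline` are assumed to exist already.
var ctx = new BinaryClassificationContext(env);

// Five folds by default; each entry holds that fold's metrics, model,
// and scored held-out data.
var folds = ctx.CrossValidate(data, pipeline);
var meanAuc = folds.Average(f => f.metrics.Auc);

// Use the non-calibrated variant when the trainer emits no probability.
var rawFolds = ctx.CrossValidateNonCalibrated(data, pipeline);
```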
 
     /// <summary>
@@ -191,6 +364,28 @@ public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string labe
             var eval = new MultiClassClassifierEvaluator(Host, args);
             return eval.Evaluate(data, label, score, predictedLabel);
         }
+
+        /// <summary>
+        /// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
+        /// and respecting <paramref name="stratificationColumn"/> if provided.
+        /// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics.
+        /// </summary>
+        /// <param name="data">The data to run cross-validation on.</param>
+        /// <param name="estimator">The estimator to fit.</param>
+        /// <param name="numFolds">Number of cross-validation folds.</param>
+        /// <param name="labelColumn">The label column (for evaluation).</param>
+        /// <param name="stratificationColumn">Optional stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
+        public (MultiClassClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+            IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
+        {
+            Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+            var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
+            return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+        }
     }
 
     /// <summary>
@@ -233,5 +428,27 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string
             var eval = new RegressionEvaluator(Host, new RegressionEvaluator.Arguments() { });
             return eval.Evaluate(data, label, score);
         }
+
+        /// <summary>
+        /// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
+        /// and respecting <paramref name="stratificationColumn"/> if provided.
+        /// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics.
+        /// </summary>
+        /// <param name="data">The data to run cross-validation on.</param>
+        /// <param name="estimator">The estimator to fit.</param>
+        /// <param name="numFolds">Number of cross-validation folds.</param>
+        /// <param name="labelColumn">The label column (for evaluation).</param>
+        /// <param name="stratificationColumn">Optional stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
+        public (RegressionEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+            IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
+        {
+            Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+            var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
+            return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+        }
     }
 }
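The regression counterpart, sketched under the same assumptions as the previous example (an `env`, a `data` view, and a `pipeline` estimator defined elsewhere).

```csharp
// Sketch under the same assumptions: `env`, `data`, `pipeline` defined elsewhere.
var regCtx = new RegressionContext(env);
var folds = regCtx.CrossValidate(data, pipeline, numFolds: 5, labelColumn: "Label");

// Average a per-fold metric, and keep each fold's model for inspection.
var meanRms = folds.Average(f => f.metrics.Rms);
foreach (var (metrics, model, scoredTestData) in folds)
{
    // e.g. persist `model`, or examine residuals in `scoredTestData`.
}
```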
diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs
new file mode 100644
index 0000000000..8cfda0485e
--- /dev/null
+++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs
@@ -0,0 +1,282 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.StaticPipe;
+using Microsoft.ML.StaticPipe.Runtime;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Microsoft.ML
+{
+    /// <summary>
+    /// Defines static extension methods that allow operations like train-test split, cross-validate,
+    /// sampling etc. with the <see cref="TrainContextBase"/>.
+    /// </summary>
+    public static class TrainingStaticExtensions
+    {
+        /// <summary>
+        /// Split the dataset into the train set and test set according to the given fraction.
+        /// Respects the <paramref name="stratificationColumn"/> if provided.
+        /// </summary>
+        /// <typeparam name="T">The tuple describing the data schema.</typeparam>
+        /// <param name="context">The training context.</param>
+        /// <param name="data">The dataset to split.</param>
+        /// <param name="testFraction">The fraction of data to go into the test set.</param>
+        /// <param name="stratificationColumn">Optional selector for the stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>A pair of datasets, for the train and test set.</returns>
+        public static (DataView<T> trainSet, DataView<T> testSet) TrainTestSplit<T>(this TrainContextBase context,
+            DataView<T> data, double testFraction = 0.1, Func<T, PipelineColumn> stratificationColumn = null)
+        {
+            var env = StaticPipeUtils.GetEnvironment(data);
+            Contracts.AssertValue(env);
+            env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
+            env.CheckValueOrNull(stratificationColumn);
+
+            string stratName = null;
+
+            if (stratificationColumn != null)
+            {
+                var indexer = StaticPipeUtils.GetIndexer(data);
+                var column = stratificationColumn(indexer.Indices);
+                env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+                stratName = indexer.Get(column);
+            }
+
+            var (trainData, testData) = context.TrainTestSplit(data.AsDynamic, testFraction, stratName);
+            return (new DataView<T>(env, trainData, data.Shape), new DataView<T>(env, testData, data.Shape));
+        }
+
+        /// <summary>
+        /// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
+        /// and respecting <paramref name="stratificationColumn"/> if provided.
+        /// Then evaluate each sub-model against <paramref name="label"/> and return metrics.
+        /// </summary>
+        /// <typeparam name="TInShape">The input schema shape.</typeparam>
+        /// <typeparam name="TOutShape">The output schema shape.</typeparam>
+        /// <typeparam name="TTransformer">The type of the trained model.</typeparam>
+        /// <param name="context">The training context.</param>
+        /// <param name="data">The data to run cross-validation on.</param>
+        /// <param name="estimator">The estimator to fit.</param>
+        /// <param name="numFolds">Number of cross-validation folds.</param>
+        /// <param name="label">The label column (for evaluation).</param>
+        /// <param name="stratificationColumn">Optional stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
+        public static (RegressionEvaluator.Result metrics, Transformer<TInShape, TOutShape, TTransformer> model, DataView<TOutShape> scoredTestData)[] CrossValidate<TInShape, TOutShape, TTransformer>(
+            this RegressionContext context,
+            DataView<TInShape> data,
+            Estimator<TInShape, TOutShape, TTransformer> estimator,
+            Func<TOutShape, Scalar<float>> label,
+            int numFolds = 5,
+            Func<TInShape, PipelineColumn> stratificationColumn = null)
+            where TTransformer : class, ITransformer
+        {
+            var env = StaticPipeUtils.GetEnvironment(data);
+            Contracts.AssertValue(env);
+            env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+            env.CheckValue(label, nameof(label));
+            env.CheckValueOrNull(stratificationColumn);
+
+            var outIndexer = StaticPipeUtils.GetIndexer(estimator);
+            var labelColumn = label(outIndexer.Indices);
+            env.CheckParam(labelColumn != null, nameof(label), "Label column not found");
+            var labelName = outIndexer.Get(labelColumn);
+
+            string stratName = null;
+            if (stratificationColumn != null)
+            {
+                var indexer = StaticPipeUtils.GetIndexer(data);
+                var column = stratificationColumn(indexer.Indices);
+                env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+                stratName = indexer.Get(column);
+            }
+
+            var results = context.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName);
+
+            return results.Select(x => (
+                x.metrics,
+                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.model, data.Shape, estimator.Shape),
+                new DataView<TOutShape>(env, x.scoredTestData, estimator.Shape)))
+                .ToArray();
+        }
+
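A statically-typed usage sketch, modeled on the tests at the end of this diff; the file path is an assumption.

```csharp
// Statically-typed sketch; only the file path and column layout are invented.
var env = new ConsoleEnvironment(seed: 0);
var ctx = new RegressionContext(env);

var reader = TextLoader.CreateReader(env,
    c => (label: c.LoadFloat(0), features: c.LoadFloat(1, 4)));
var data = reader.Read(new MultiFileSource("data.txt"));

// Columns are referenced by lambdas over the schema tuple, not by strings.
var (train, test) = ctx.TrainTestSplit(data, testFraction: 0.2, stratificationColumn: r => r.label);
```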
+        /// <summary>
+        /// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
+        /// and respecting <paramref name="stratificationColumn"/> if provided.
+        /// Then evaluate each sub-model against <paramref name="label"/> and return metrics.
+        /// </summary>
+        /// <typeparam name="TInShape">The input schema shape.</typeparam>
+        /// <typeparam name="TOutShape">The output schema shape.</typeparam>
+        /// <typeparam name="TTransformer">The type of the trained model.</typeparam>
+        /// <param name="context">The training context.</param>
+        /// <param name="data">The data to run cross-validation on.</param>
+        /// <param name="estimator">The estimator to fit.</param>
+        /// <param name="numFolds">Number of cross-validation folds.</param>
+        /// <param name="label">The label column (for evaluation).</param>
+        /// <param name="stratificationColumn">Optional stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
+        public static (MultiClassClassifierEvaluator.Result metrics, Transformer<TInShape, TOutShape, TTransformer> model, DataView<TOutShape> scoredTestData)[] CrossValidate<TInShape, TOutShape, TTransformer>(
+            this MulticlassClassificationContext context,
+            DataView<TInShape> data,
+            Estimator<TInShape, TOutShape, TTransformer> estimator,
+            Func<TOutShape, Key<uint>> label,
+            int numFolds = 5,
+            Func<TInShape, PipelineColumn> stratificationColumn = null)
+            where TTransformer : class, ITransformer
+        {
+            var env = StaticPipeUtils.GetEnvironment(data);
+            Contracts.AssertValue(env);
+            env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+            env.CheckValue(label, nameof(label));
+            env.CheckValueOrNull(stratificationColumn);
+
+            var outputIndexer = StaticPipeUtils.GetIndexer(estimator);
+            var labelColumn = label(outputIndexer.Indices);
+            env.CheckParam(labelColumn != null, nameof(label), "Label column not found");
+            var labelName = outputIndexer.Get(labelColumn);
+
+            string stratName = null;
+            if (stratificationColumn != null)
+            {
+                var indexer = StaticPipeUtils.GetIndexer(data);
+                var column = stratificationColumn(indexer.Indices);
+                env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+                stratName = indexer.Get(column);
+            }
+
+            var results = context.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName);
+
+            return results.Select(x => (
+                x.metrics,
+                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.model, data.Shape, estimator.Shape),
+                new DataView<TOutShape>(env, x.scoredTestData, estimator.Shape)))
+                .ToArray();
+        }
+
+        /// <summary>
+        /// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
+        /// and respecting <paramref name="stratificationColumn"/> if provided.
+        /// Then evaluate each sub-model against <paramref name="label"/> and return metrics.
+        /// </summary>
+        /// <typeparam name="TInShape">The input schema shape.</typeparam>
+        /// <typeparam name="TOutShape">The output schema shape.</typeparam>
+        /// <typeparam name="TTransformer">The type of the trained model.</typeparam>
+        /// <param name="context">The training context.</param>
+        /// <param name="data">The data to run cross-validation on.</param>
+        /// <param name="estimator">The estimator to fit.</param>
+        /// <param name="numFolds">Number of cross-validation folds.</param>
+        /// <param name="label">The label column (for evaluation).</param>
+        /// <param name="stratificationColumn">Optional stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
+        public static (BinaryClassifierEvaluator.Result metrics, Transformer<TInShape, TOutShape, TTransformer> model, DataView<TOutShape> scoredTestData)[] CrossValidateNonCalibrated<TInShape, TOutShape, TTransformer>(
+            this BinaryClassificationContext context,
+            DataView<TInShape> data,
+            Estimator<TInShape, TOutShape, TTransformer> estimator,
+            Func<TOutShape, Scalar<bool>> label,
+            int numFolds = 5,
+            Func<TInShape, PipelineColumn> stratificationColumn = null)
+            where TTransformer : class, ITransformer
+        {
+            var env = StaticPipeUtils.GetEnvironment(data);
+            Contracts.AssertValue(env);
+            env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+            env.CheckValue(label, nameof(label));
+            env.CheckValueOrNull(stratificationColumn);
+
+            var outputIndexer = StaticPipeUtils.GetIndexer(estimator);
+            var labelColumn = label(outputIndexer.Indices);
+            env.CheckParam(labelColumn != null, nameof(label), "Label column not found");
+            var labelName = outputIndexer.Get(labelColumn);
+
+            string stratName = null;
+            if (stratificationColumn != null)
+            {
+                var indexer = StaticPipeUtils.GetIndexer(data);
+                var column = stratificationColumn(indexer.Indices);
+                env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+                stratName = indexer.Get(column);
+            }
+
+            var results = context.CrossValidateNonCalibrated(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName);
+
+            return results.Select(x => (
+                x.metrics,
+                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.model, data.Shape, estimator.Shape),
+                new DataView<TOutShape>(env, x.scoredTestData, estimator.Shape)))
+                .ToArray();
+        }
+
+        /// <summary>
+        /// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
+        /// and respecting <paramref name="stratificationColumn"/> if provided.
+        /// Then evaluate each sub-model against <paramref name="label"/> and return metrics.
+        /// </summary>
+        /// <typeparam name="TInShape">The input schema shape.</typeparam>
+        /// <typeparam name="TOutShape">The output schema shape.</typeparam>
+        /// <typeparam name="TTransformer">The type of the trained model.</typeparam>
+        /// <param name="context">The training context.</param>
+        /// <param name="data">The data to run cross-validation on.</param>
+        /// <param name="estimator">The estimator to fit.</param>
+        /// <param name="numFolds">Number of cross-validation folds.</param>
+        /// <param name="label">The label column (for evaluation).</param>
+        /// <param name="stratificationColumn">Optional stratification column.</param>
+        /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
+        /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage
+        /// from the train to the test set.</remarks>
+        /// <returns>Per-fold results: metrics, models, scored datasets.</returns>
+        public static (BinaryClassifierEvaluator.CalibratedResult metrics, Transformer<TInShape, TOutShape, TTransformer> model, DataView<TOutShape> scoredTestData)[] CrossValidate<TInShape, TOutShape, TTransformer>(
+            this BinaryClassificationContext context,
+            DataView<TInShape> data,
+            Estimator<TInShape, TOutShape, TTransformer> estimator,
+            Func<TOutShape, Scalar<bool>> label,
+            int numFolds = 5,
+            Func<TInShape, PipelineColumn> stratificationColumn = null)
+            where TTransformer : class, ITransformer
+        {
+            var env = StaticPipeUtils.GetEnvironment(data);
+            Contracts.AssertValue(env);
+            env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+            env.CheckValue(label, nameof(label));
+            env.CheckValueOrNull(stratificationColumn);
+
+            var outputIndexer = StaticPipeUtils.GetIndexer(estimator);
+            var labelColumn = label(outputIndexer.Indices);
+            env.CheckParam(labelColumn != null, nameof(label), "Label column not found");
+            var labelName = outputIndexer.Get(labelColumn);
+
+            string stratName = null;
+            if (stratificationColumn != null)
+            {
+                var indexer = StaticPipeUtils.GetIndexer(data);
+                var column = stratificationColumn(indexer.Indices);
+                env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+                stratName = indexer.Get(column);
+            }
+
+            var results = context.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName);
+
+            return results.Select(x => (
+                x.metrics,
+                new Transformer<TInShape, TOutShape, TTransformer>(env, (TTransformer)x.model, data.Shape, estimator.Shape),
+                new DataView<TOutShape>(env, x.scoredTestData, estimator.Shape)))
+                .ToArray();
+        }
+    }
+}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
index c11e846bfb..113092ed26 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
@@ -236,7 +236,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
         {
             Contracts.CheckValue(label, nameof(label));
             Contracts.CheckValue(features, nameof(features));
-            Contracts.CheckValue(loss, nameof(loss));
+            Contracts.CheckValueOrNull(loss);
             Contracts.CheckValueOrNull(weights);
             Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative, if specified.");
             Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified.");
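The relaxed null check above is what lets callers omit the loss argument entirely; the multiclass test added below does exactly that. Restated:

```csharp
// Mirrors the CrossValidate test below: no `loss` argument is supplied, so
// null flows into the relaxed CheckValueOrNull and the trainer's default
// loss is used.
var est = reader.MakeNewEstimator()
    .Append(r => (label: r.label.ToKey(), r.features))
    .Append(r => (r.label, preds: ctx.Trainers.Sdca(
        r.label,
        r.features,
        maxIterations: 2)));
```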
diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
index ada8d755db..b873e8262d 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
@@ -5,6 +5,7 @@
 using Microsoft.ML.Runtime.Data;
 using Microsoft.ML.Runtime.Data.IO;
 using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.RunTests;
 using Microsoft.ML.StaticPipe;
 using Microsoft.ML.TestFramework;
 using Microsoft.ML.Transforms;
@@ -13,6 +14,7 @@
 using System.Collections.Generic;
 using System.Collections.Immutable;
 using System.IO;
+using System.Linq;
 using System.Text;
 using Xunit;
 using Xunit.Abstractions;
@@ -613,5 +615,35 @@ public void FeatureSelection()
             type = schema.GetColumnType(bagofwordMiCol);
             Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
         }
+
+        [Fact]
+        public void TrainTestSplit()
+        {
+            var env = new ConsoleEnvironment(seed: 0);
+            var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
+            var dataSource = new MultiFileSource(dataPath);
+
+            var ctx = new BinaryClassificationContext(env);
+
+            var reader = TextLoader.CreateReader(env,
+                c => (label: c.LoadFloat(0), features: c.LoadFloat(1, 4)));
+            var data = reader.Read(dataSource);
+
+            var (train, test) = ctx.TrainTestSplit(data, 0.5);
+
+            // Just make sure that the train set is about the same size as the test set.
+            var trainCount = train.GetColumn(r => r.label).Count();
+            var testCount = test.GetColumn(r => r.label).Count();
+
+            Assert.InRange(trainCount * 1.0 / testCount, 0.8, 1.2);
+
+            // Now stratify by label. Silly thing to do.
+            (train, test) = ctx.TrainTestSplit(data, 0.5, stratificationColumn: r => r.label);
+            var trainLabels = train.GetColumn(r => r.label).Distinct();
+            var testLabels = test.GetColumn(r => r.label).Distinct();
+            Assert.True(trainLabels.Count() > 0);
+            Assert.True(testLabels.Count() > 0);
+            Assert.False(trainLabels.Intersect(testLabels).Any());
+        }
     }
 }
\ No newline at end of file
diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
index e3fc685c6d..55a5b6e76b 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
@@ -13,6 +14,7 @@
 using Microsoft.ML.Runtime.Training;
 using Microsoft.ML.Trainers;
 using System;
+using System.Linq;
 using Xunit;
 using Xunit.Abstractions;
@@ -73,7 +74,7 @@ public void SdcaRegressionNameCollision()
         var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
         var dataSource = new MultiFileSource(dataPath);
         var ctx = new RegressionContext(env);
-
+        // Here we introduce another column called "Score" to collide with the name of the default output. Heh heh heh...
         var reader = TextLoader.CreateReader(env,
             c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10), Score: c.LoadText(2)),
@@ -263,6 +264,30 @@ public void SdcaMulticlass()
             Assert.True(metrics.TopKAccuracy > 0);
         }
 
+        [Fact]
+        public void CrossValidate()
+        {
+            var env = new ConsoleEnvironment(seed: 0);
+            var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
+            var dataSource = new MultiFileSource(dataPath);
+
+            var ctx = new MulticlassClassificationContext(env);
+            var reader = TextLoader.CreateReader(env,
+                c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));
+
+            var est = reader.MakeNewEstimator()
+                .Append(r => (label: r.label.ToKey(), r.features))
+                .Append(r => (r.label, preds: ctx.Trainers.Sdca(
+                    r.label,
+                    r.features,
+                    maxIterations: 2)));
+
+            var results = ctx.CrossValidate(reader.Read(dataSource), est, r => r.label)
+                .Select(x => x.metrics).ToArray();
+            Assert.Equal(5, results.Length);
+            Assert.True(results.All(x => x.LogLoss > 0));
+        }
+
         [Fact]
         public void FastTreeBinaryClassification()
         {