diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
index db4f207fa3..f3d84812ce 100644
--- a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
@@ -87,6 +87,12 @@ public void SaveTo(IHostEnvironment env, Stream outputStream)
///
public static class CompositeDataReader
{
+ ///
+ /// Save the contents to a stream, as a "model file".
+ ///
+ public static void SaveTo(this IDataReader reader, IHostEnvironment env, Stream outputStream)
+ => new CompositeDataReader(reader).SaveTo(env, outputStream);
+
///
/// Load the pipeline from stream.
///
diff --git a/src/Microsoft.ML.Data/Training/TrainContext.cs b/src/Microsoft.ML.Data/Training/TrainContext.cs
index b15bd1b317..c2593654cf 100644
--- a/src/Microsoft.ML.Data/Training/TrainContext.cs
+++ b/src/Microsoft.ML.Data/Training/TrainContext.cs
@@ -2,9 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Transforms;
+using System;
+using System.Collections.Generic;
+using System.Linq;
-namespace Microsoft.ML.Runtime.Training
+namespace Microsoft.ML
{
///
/// A training context is an object instantiable by a user to do various tasks relating to a particular
@@ -16,6 +22,90 @@ public abstract class TrainContextBase
protected readonly IHost Host;
internal IHostEnvironment Environment => Host;
+ ///
+ /// Split the dataset into the train set and test set according to the given fraction.
+ /// Respects the if provided.
+ ///
+ /// The dataset to split.
+ /// The fraction of data to go into the test set.
+ /// Optional stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// A pair of datasets, for the train and test set.
+ public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null)
+ {
+ Host.CheckValue(data, nameof(data));
+ Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
+ Host.CheckValueOrNull(stratificationColumn);
+
+ EnsureStratificationColumn(ref data, ref stratificationColumn);
+
+ var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments()
+ {
+ Column = stratificationColumn,
+ Min = 0,
+ Max = testFraction,
+ Complement = true
+ }, data);
+ var testFilter = new RangeFilter(Host, new RangeFilter.Arguments()
+ {
+ Column = stratificationColumn,
+ Min = 0,
+ Max = testFraction,
+ Complement = false
+ }, data);
+
+ return (trainFilter, testFilter);
+ }
+
+ ///
+ /// Train the on folds of the data sequentially.
+ /// Return each model and each scored test dataset.
+ ///
+ protected (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator estimator,
+ int numFolds, string stratificationColumn)
+ {
+ Host.CheckValue(data, nameof(data));
+ Host.CheckValue(estimator, nameof(estimator));
+ Host.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+ Host.CheckValueOrNull(stratificationColumn);
+
+ EnsureStratificationColumn(ref data, ref stratificationColumn);
+
+ Func foldFunction =
+ fold =>
+ {
+ var trainFilter = new RangeFilter(Host, new RangeFilter.Arguments
+ {
+ Column = stratificationColumn,
+ Min = (double)fold / numFolds,
+ Max = (double)(fold + 1) / numFolds,
+ Complement = true
+ }, data);
+ var testFilter = new RangeFilter(Host, new RangeFilter.Arguments
+ {
+ Column = stratificationColumn,
+ Min = (double)fold / numFolds,
+ Max = (double)(fold + 1) / numFolds,
+ Complement = false
+ }, data);
+
+ var model = estimator.Fit(trainFilter);
+ var scoredTest = model.Transform(testFilter);
+ return (scoredTest, model);
+ };
+
+ // Sequential per-fold training.
+ // REVIEW: we could have a parallel implementation here. We would need to
+ // spawn off a separate host per fold in that case.
+ var result = new List<(IDataView scores, ITransformer model)>();
+ for (int fold = 0; fold < numFolds; fold++)
+ result.Add(foldFunction(fold));
+
+ return result.ToArray();
+ }
+
protected TrainContextBase(IHostEnvironment env, string registrationName)
{
Contracts.CheckValue(env, nameof(env));
@@ -23,6 +113,45 @@ protected TrainContextBase(IHostEnvironment env, string registrationName)
Host = env.Register(registrationName);
}
+ ///
+ /// Make sure the provided is valid
+ /// for , hash it if needed, or introduce a new one
+ /// if needed.
+ ///
+ private void EnsureStratificationColumn(ref IDataView data, ref string stratificationColumn)
+ {
+ // We need to handle two cases: if the stratification column is provided, we use hashJoin to
+ // build a single hash of it. If it is not, we generate a random number.
+
+ if (stratificationColumn == null)
+ {
+ stratificationColumn = data.Schema.GetTempColumnName("StratificationColumn");
+ data = new GenerateNumberTransform(Host, data, stratificationColumn);
+ }
+ else
+ {
+ if (!data.Schema.TryGetColumnIndex(stratificationColumn, out int stratCol))
+ throw Host.ExceptSchemaMismatch(nameof(stratificationColumn), "stratification", stratificationColumn);
+
+ var type = data.Schema.GetColumnType(stratCol);
+ if (!RangeFilter.IsValidRangeFilterColumnType(Host, type))
+ {
+ // Hash the stratification column.
+ // REVIEW: this could currently crash, since Hash only accepts a limited set
+ // of column types. It used to be HashJoin, but we should probably extend Hash
+ // instead of having two hash transformations.
+ var origStratCol = stratificationColumn;
+ int tmp;
+ int inc = 0;
+
+ // Generate a new column with the hashed stratification column.
+ while (data.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
+ stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
+ data = new HashEstimator(Host, origStratCol, stratificationColumn, 30).Fit(data).Transform(data);
+ }
+ }
+ }
+
///
/// Subclasses of will provide little "extension method" hookable objects
/// (e.g., something like ). User code will only
@@ -140,6 +269,50 @@ public BinaryClassifierEvaluator.Result EvaluateNonCalibrated(IDataView data, st
var eval = new BinaryClassifierEvaluator(Host, new BinaryClassifierEvaluator.Arguments() { });
return eval.Evaluate(data, label, score, predictedLabel);
}
+
+ ///
+ /// Run cross-validation over folds of , by fitting ,
+ /// and respecting if provided.
+ /// Then evaluate each sub-model against and return metrics.
+ ///
+ /// The data to run cross-validation on.
+ /// The estimator to fit.
+ /// Number of cross-validation folds.
+ /// The label column (for evaluation).
+ /// Optional stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// Per-fold results: metrics, models, scored datasets.
+ public (BinaryClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated(
+ IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
+ {
+ Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+ var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
+ return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+ }
+
+ ///
+ /// Run cross-validation over folds of , by fitting ,
+ /// and respecting if provided.
+ /// Then evaluate each sub-model against and return metrics.
+ ///
+ /// The data to run cross-validation on.
+ /// The estimator to fit.
+ /// Number of cross-validation folds.
+ /// The label column (for evaluation).
+ /// Optional stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// Per-fold results: metrics, models, scored datasets.
+ public (BinaryClassifierEvaluator.CalibratedResult metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+ IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
+ {
+ Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+ var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
+ return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+ }
}
///
@@ -191,6 +364,28 @@ public MultiClassClassifierEvaluator.Result Evaluate(IDataView data, string labe
var eval = new MultiClassClassifierEvaluator(Host, args);
return eval.Evaluate(data, label, score, predictedLabel);
}
+
+ ///
+ /// Run cross-validation over folds of , by fitting ,
+ /// and respecting if provided.
+ /// Then evaluate each sub-model against and return metrics.
+ ///
+ /// The data to run cross-validation on.
+ /// The estimator to fit.
+ /// Number of cross-validation folds.
+ /// The label column (for evaluation).
+ /// Optional stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// Per-fold results: metrics, models, scored datasets.
+ public (MultiClassClassifierEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+ IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
+ {
+ Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+ var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
+ return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+ }
}
///
@@ -233,5 +428,27 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string
var eval = new RegressionEvaluator(Host, new RegressionEvaluator.Arguments() { });
return eval.Evaluate(data, label, score);
}
+
+ ///
+ /// Run cross-validation over folds of , by fitting ,
+ /// and respecting if provided.
+ /// Then evaluate each sub-model against and return metrics.
+ ///
+ /// The data to run cross-validation on.
+ /// The estimator to fit.
+ /// Number of cross-validation folds.
+ /// The label column (for evaluation).
+ /// Optional stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// Per-fold results: metrics, models, scored datasets.
+ public (RegressionEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
+ IDataView data, IEstimator estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
+ {
+ Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
+ var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
+ return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
+ }
}
}
diff --git a/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs
new file mode 100644
index 0000000000..8cfda0485e
--- /dev/null
+++ b/src/Microsoft.ML.Data/Training/TrainingStaticExtensions.cs
@@ -0,0 +1,282 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.StaticPipe;
+using Microsoft.ML.StaticPipe.Runtime;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Microsoft.ML
+{
+ ///
+ /// Defines static extension methods that allow operations like train-test split, cross-validate,
+ /// sampling etc. with the .
+ ///
+ public static class TrainingStaticExtensions
+ {
+ ///
+ /// Split the dataset into the train set and test set according to the given fraction.
+ /// Respects the if provided.
+ ///
+ /// The tuple describing the data schema.
+ /// The training context.
+ /// The dataset to split.
+ /// The fraction of data to go into the test set.
+ /// Optional selector for the stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// A pair of datasets, for the train and test set.
+ public static (DataView trainSet, DataView testSet) TrainTestSplit(this TrainContextBase context,
+ DataView data, double testFraction = 0.1, Func stratificationColumn = null)
+ {
+ var env = StaticPipeUtils.GetEnvironment(data);
+ Contracts.AssertValue(env);
+ env.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
+ env.CheckValueOrNull(stratificationColumn);
+
+ string stratName = null;
+
+ if (stratificationColumn != null)
+ {
+ var indexer = StaticPipeUtils.GetIndexer(data);
+ var column = stratificationColumn(indexer.Indices);
+ env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+ stratName = indexer.Get(column);
+ }
+
+ var (trainData, testData) = context.TrainTestSplit(data.AsDynamic, testFraction, stratName);
+ return (new DataView(env, trainData, data.Shape), new DataView(env, testData, data.Shape));
+ }
+
+ ///
+ /// Run cross-validation over folds of , by fitting ,
+ /// and respecting if provided.
+ /// Then evaluate each sub-model against and return metrics.
+ ///
+ /// The input schema shape.
+ /// The output schema shape.
+ /// The type of the trained model.
+ /// The training context.
+ /// The data to run cross-validation on.
+ /// The estimator to fit.
+ /// Number of cross-validation folds.
+ /// The label column (for evaluation).
+ /// Optional stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// Per-fold results: metrics, models, scored datasets.
+ public static (RegressionEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidate(
+ this RegressionContext context,
+ DataView data,
+ Estimator estimator,
+ Func> label,
+ int numFolds = 5,
+ Func stratificationColumn = null)
+ where TTransformer : class, ITransformer
+ {
+ var env = StaticPipeUtils.GetEnvironment(data);
+ Contracts.AssertValue(env);
+ env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+ env.CheckValue(label, nameof(label));
+ env.CheckValueOrNull(stratificationColumn);
+
+ var outIndexer = StaticPipeUtils.GetIndexer(estimator);
+ var labelColumn = label(outIndexer.Indices);
+ env.CheckParam(labelColumn != null, nameof(label), "Label column not found");
+ var labelName = outIndexer.Get(labelColumn);
+
+ string stratName = null;
+ if (stratificationColumn != null)
+ {
+ var indexer = StaticPipeUtils.GetIndexer(data);
+ var column = stratificationColumn(indexer.Indices);
+ env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+ stratName = indexer.Get(column);
+ }
+
+ var results = context.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName);
+
+ return results.Select(x => (
+ x.metrics,
+ new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape),
+ new DataView(env, x.scoredTestData, estimator.Shape)))
+ .ToArray();
+ }
+
+ ///
+ /// Run cross-validation over folds of , by fitting ,
+ /// and respecting if provided.
+ /// Then evaluate each sub-model against and return metrics.
+ ///
+ /// The input schema shape.
+ /// The output schema shape.
+ /// The type of the trained model.
+ /// The training context.
+ /// The data to run cross-validation on.
+ /// The estimator to fit.
+ /// Number of cross-validation folds.
+ /// The label column (for evaluation).
+ /// Optional stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// Per-fold results: metrics, models, scored datasets.
+ public static (MultiClassClassifierEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidate(
+ this MulticlassClassificationContext context,
+ DataView data,
+ Estimator estimator,
+ Func> label,
+ int numFolds = 5,
+ Func stratificationColumn = null)
+ where TTransformer : class, ITransformer
+ {
+ var env = StaticPipeUtils.GetEnvironment(data);
+ Contracts.AssertValue(env);
+ env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+ env.CheckValue(label, nameof(label));
+ env.CheckValueOrNull(stratificationColumn);
+
+ var outputIndexer = StaticPipeUtils.GetIndexer(estimator);
+ var labelColumn = label(outputIndexer.Indices);
+ env.CheckParam(labelColumn != null, nameof(label), "Label column not found");
+ var labelName = outputIndexer.Get(labelColumn);
+
+ string stratName = null;
+ if (stratificationColumn != null)
+ {
+ var indexer = StaticPipeUtils.GetIndexer(data);
+ var column = stratificationColumn(indexer.Indices);
+ env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+ stratName = indexer.Get(column);
+ }
+
+ var results = context.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName);
+
+ return results.Select(x => (
+ x.metrics,
+ new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape),
+ new DataView(env, x.scoredTestData, estimator.Shape)))
+ .ToArray();
+ }
+
+ ///
+ /// Run cross-validation over folds of , by fitting ,
+ /// and respecting if provided.
+ /// Then evaluate each sub-model against and return metrics.
+ ///
+ /// The input schema shape.
+ /// The output schema shape.
+ /// The type of the trained model.
+ /// The training context.
+ /// The data to run cross-validation on.
+ /// The estimator to fit.
+ /// Number of cross-validation folds.
+ /// The label column (for evaluation).
+ /// Optional stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// Per-fold results: metrics, models, scored datasets.
+ public static (BinaryClassifierEvaluator.Result metrics, Transformer model, DataView scoredTestData)[] CrossValidateNonCalibrated(
+ this BinaryClassificationContext context,
+ DataView data,
+ Estimator estimator,
+ Func> label,
+ int numFolds = 5,
+ Func stratificationColumn = null)
+ where TTransformer : class, ITransformer
+ {
+ var env = StaticPipeUtils.GetEnvironment(data);
+ Contracts.AssertValue(env);
+ env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+ env.CheckValue(label, nameof(label));
+ env.CheckValueOrNull(stratificationColumn);
+
+ var outputIndexer = StaticPipeUtils.GetIndexer(estimator);
+ var labelColumn = label(outputIndexer.Indices);
+ env.CheckParam(labelColumn != null, nameof(label), "Label column not found");
+ var labelName = outputIndexer.Get(labelColumn);
+
+ string stratName = null;
+ if (stratificationColumn != null)
+ {
+ var indexer = StaticPipeUtils.GetIndexer(data);
+ var column = stratificationColumn(indexer.Indices);
+ env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+ stratName = indexer.Get(column);
+ }
+
+ var results = context.CrossValidateNonCalibrated(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName);
+
+ return results.Select(x => (
+ x.metrics,
+ new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape),
+ new DataView(env, x.scoredTestData, estimator.Shape)))
+ .ToArray();
+ }
+
+ ///
+ /// Run cross-validation over folds of , by fitting ,
+ /// and respecting if provided.
+ /// Then evaluate each sub-model against and return metrics.
+ ///
+ /// The input schema shape.
+ /// The output schema shape.
+ /// The type of the trained model.
+ /// The training context.
+ /// The data to run cross-validation on.
+ /// The estimator to fit.
+ /// Number of cross-validation folds.
+ /// The label column (for evaluation).
+ /// Optional stratification column.
+ /// If two examples share the same value of the (if provided),
+ /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
+ /// train to the test set.
+ /// Per-fold results: metrics, models, scored datasets.
+ public static (BinaryClassifierEvaluator.CalibratedResult metrics, Transformer model, DataView scoredTestData)[] CrossValidate(
+ this BinaryClassificationContext context,
+ DataView data,
+ Estimator estimator,
+ Func> label,
+ int numFolds = 5,
+ Func stratificationColumn = null)
+ where TTransformer : class, ITransformer
+ {
+ var env = StaticPipeUtils.GetEnvironment(data);
+ Contracts.AssertValue(env);
+ env.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
+ env.CheckValue(label, nameof(label));
+ env.CheckValueOrNull(stratificationColumn);
+
+ var outputIndexer = StaticPipeUtils.GetIndexer(estimator);
+ var labelColumn = label(outputIndexer.Indices);
+ env.CheckParam(labelColumn != null, nameof(label), "Label column not found");
+ var labelName = outputIndexer.Get(labelColumn);
+
+ string stratName = null;
+ if (stratificationColumn != null)
+ {
+ var indexer = StaticPipeUtils.GetIndexer(data);
+ var column = stratificationColumn(indexer.Indices);
+ env.CheckParam(column != null, nameof(stratificationColumn), "Stratification column not found");
+ stratName = indexer.Get(column);
+ }
+
+ var results = context.CrossValidate(data.AsDynamic, estimator.AsDynamic, numFolds, labelName, stratName);
+
+ return results.Select(x => (
+ x.metrics,
+ new Transformer(env, (TTransformer)x.model, data.Shape, estimator.Shape),
+ new DataView(env, x.scoredTestData, estimator.Shape)))
+ .ToArray();
+ }
+ }
+}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
index c11e846bfb..113092ed26 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/SdcaStatic.cs
@@ -236,7 +236,7 @@ public static (Vector score, Key predictedLabel)
{
Contracts.CheckValue(label, nameof(label));
Contracts.CheckValue(features, nameof(features));
- Contracts.CheckValue(loss, nameof(loss));
+ Contracts.CheckValueOrNull(loss);
Contracts.CheckValueOrNull(weights);
Contracts.CheckParam(!(l2Const < 0), nameof(l2Const), "Must not be negative, if specified.");
Contracts.CheckParam(!(l1Threshold < 0), nameof(l1Threshold), "Must not be negative, if specified.");
diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
index ada8d755db..b873e8262d 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs
@@ -5,6 +5,7 @@
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Data.IO;
using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.RunTests;
using Microsoft.ML.StaticPipe;
using Microsoft.ML.TestFramework;
using Microsoft.ML.Transforms;
@@ -13,6 +14,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
+using System.Linq;
using System.Text;
using Xunit;
using Xunit.Abstractions;
@@ -613,5 +615,35 @@ public void FeatureSelection()
type = schema.GetColumnType(bagofwordMiCol);
Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber);
}
+
+ [Fact]
+ public void TrainTestSplit()
+ {
+ var env = new ConsoleEnvironment(seed: 0);
+ var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
+ var dataSource = new MultiFileSource(dataPath);
+
+ var ctx = new BinaryClassificationContext(env);
+
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadFloat(0), features: c.LoadFloat(1, 4)));
+ var data = reader.Read(dataSource);
+
+ var (train, test) = ctx.TrainTestSplit(data, 0.5);
+
+ // Just make sure that the train is about the same size as the test set.
+ var trainCount = train.GetColumn(r => r.label).Count();
+ var testCount = test.GetColumn(r => r.label).Count();
+
+ Assert.InRange(trainCount * 1.0 / testCount, 0.8, 1.2);
+
+ // Now stratify by label. Silly thing to do.
+ (train, test) = ctx.TrainTestSplit(data, 0.5, stratificationColumn: r => r.label);
+ var trainLabels = train.GetColumn(r => r.label).Distinct();
+ var testLabels = test.GetColumn(r => r.label).Distinct();
+ Assert.True(trainLabels.Count() > 0);
+ Assert.True(testLabels.Count() > 0);
+ Assert.False(trainLabels.Intersect(testLabels).Any());
+ }
}
}
\ No newline at end of file
diff --git a/test/Microsoft.ML.StaticPipelineTesting/Training.cs b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
index e3fc685c6d..55a5b6e76b 100644
--- a/test/Microsoft.ML.StaticPipelineTesting/Training.cs
+++ b/test/Microsoft.ML.StaticPipelineTesting/Training.cs
@@ -13,6 +13,7 @@
using Microsoft.ML.Runtime.Training;
using Microsoft.ML.Trainers;
using System;
+using System.Linq;
using Xunit;
using Xunit.Abstractions;
@@ -73,7 +74,7 @@ public void SdcaRegressionNameCollision()
var dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var dataSource = new MultiFileSource(dataPath);
var ctx = new RegressionContext(env);
-
+
// Here we introduce another column called "Score" to collide with the name of the default output. Heh heh heh...
var reader = TextLoader.CreateReader(env,
c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10), Score: c.LoadText(2)),
@@ -263,6 +264,30 @@ public void SdcaMulticlass()
Assert.True(metrics.TopKAccuracy > 0);
}
+ [Fact]
+ public void CrossValidate()
+ {
+ var env = new ConsoleEnvironment(seed: 0);
+ var dataPath = GetDataPath(TestDatasets.iris.trainFilename);
+ var dataSource = new MultiFileSource(dataPath);
+
+ var ctx = new MulticlassClassificationContext(env);
+ var reader = TextLoader.CreateReader(env,
+ c => (label: c.LoadText(0), features: c.LoadFloat(1, 4)));
+
+ var est = reader.MakeNewEstimator()
+ .Append(r => (label: r.label.ToKey(), r.features))
+ .Append(r => (r.label, preds: ctx.Trainers.Sdca(
+ r.label,
+ r.features,
+ maxIterations: 2)));
+
+ var results = ctx.CrossValidate(reader.Read(dataSource), est, r => r.label)
+ .Select(x => x.metrics).ToArray();
+ Assert.Equal(5, results.Length);
+ Assert.True(results.All(x => x.LogLoss > 0));
+ }
+
[Fact]
public void FastTreeBinaryClassification()
{