Skip to content

Commit 56607ba

Browse files
authored
Get rid of value tuples in TrainTest and CrossValidation (#2507)
1 parent f269adc commit 56607ba

File tree

11 files changed

+154
-68
lines changed

11 files changed

+154
-68
lines changed

docs/code/MlNetCookBook.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -825,20 +825,20 @@ var pipeline =
825825
.Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent());
826826

827827
// Split the data 90:10 into train and test sets, train and evaluate.
828-
var (trainData, testData) = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
828+
var split = mlContext.MulticlassClassification.TrainTestSplit(data, testFraction: 0.1);
829829

830830
// Train the model.
831-
var model = pipeline.Fit(trainData);
831+
var model = pipeline.Fit(split.TrainSet);
832832
// Compute quality metrics on the test set.
833-
var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(testData));
833+
var metrics = mlContext.MulticlassClassification.Evaluate(model.Transform(split.TestSet));
834834
Console.WriteLine(metrics.AccuracyMicro);
835835

836836
// Now run the 5-fold cross-validation experiment, using the same pipeline.
837837
var cvResults = mlContext.MulticlassClassification.CrossValidate(data, pipeline, numFolds: 5);
838838

839839
// The results object is an array of 5 elements. For each of the 5 folds, we have metrics, model and scored test data.
840840
// Let's compute the average micro-accuracy.
841-
var microAccuracies = cvResults.Select(r => r.metrics.AccuracyMicro);
841+
var microAccuracies = cvResults.Select(r => r.Metrics.AccuracyMicro);
842842
Console.WriteLine(microAccuracies.Average());
843843

844844
```

docs/samples/Microsoft.ML.Samples/Dynamic/Calibrator.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public static void Calibration()
4343
var data = reader.Read(dataFile);
4444

4545
// Split the dataset into two parts: one used for training the model (split.TrainSet), the other (split.TestSet) used to train the calibrator.
46-
var (trainData, calibratorTrainingData) = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
46+
var split = mlContext.BinaryClassification.TrainTestSplit(data, testFraction: 0.1);
4747

4848
// Featurize the text column through the FeaturizeText API.
4949
// Then append the StochasticDualCoordinateAscentBinary binary classifier, setting the "Label" column as the label of the dataset, and
@@ -56,12 +56,12 @@ public static void Calibration()
5656
loss: new HingeLoss())); // By specifying loss: new HingeLoss(), StochasticDualCoordinateAscent will train a support vector machine (SVM).
5757

5858
// Fit the pipeline, and get a transformer that knows how to score new data.
59-
var transformer = pipeline.Fit(trainData);
59+
var transformer = pipeline.Fit(split.TrainSet);
6060
IPredictor model = transformer.LastTransformer.Model;
6161

6262
// Let's score the new data. The score will give us a numerical estimation of the chance that the particular sample
6363
// bears positive sentiment. This estimate is relative to the numbers obtained.
64-
var scoredData = transformer.Transform(calibratorTrainingData);
64+
var scoredData = transformer.Transform(split.TestSet);
6565
var scoredDataPreview = scoredData.Preview();
6666

6767
PrintRowViewValues(scoredDataPreview);

docs/samples/Microsoft.ML.Samples/Dynamic/LogisticRegression.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public static void LogisticRegression()
5757

5858
IDataView data = reader.Read(dataFilePath);
5959

60-
var (trainData, testData) = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);
60+
var split = ml.BinaryClassification.TrainTestSplit(data, testFraction: 0.2);
6161

6262
var pipeline = ml.Transforms.Concatenate("Text", "workclass", "education", "marital-status",
6363
"relationship", "ethnicity", "sex", "native-country")
@@ -66,9 +66,9 @@ public static void LogisticRegression()
6666
"education-num", "capital-gain", "capital-loss", "hours-per-week"))
6767
.Append(ml.BinaryClassification.Trainers.LogisticRegression());
6868

69-
var model = pipeline.Fit(trainData);
69+
var model = pipeline.Fit(split.TrainSet);
7070

71-
var dataWithPredictions = model.Transform(testData);
71+
var dataWithPredictions = model.Transform(split.TestSet);
7272

7373
var metrics = ml.BinaryClassification.Evaluate(dataWithPredictions);
7474

docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public static void SDCA_BinaryClassification()
5454
// Step 3: Run Cross-Validation on this pipeline.
5555
var cvResults = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment");
5656

57-
var accuracies = cvResults.Select(r => r.metrics.Accuracy);
57+
var accuracies = cvResults.Select(r => r.Metrics.Accuracy);
5858
Console.WriteLine(accuracies.Average());
5959

6060
// If we wanted to specify more advanced parameters for the algorithm,
@@ -70,7 +70,7 @@ public static void SDCA_BinaryClassification()
7070

7171
// Run Cross-Validation on this second pipeline.
7272
var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3);
73-
accuracies = cvResults_advancedPipeline.Select(r => r.metrics.Accuracy);
73+
accuracies = cvResults_advancedPipeline.Select(r => r.Metrics.Accuracy);
7474
Console.WriteLine(accuracies.Average());
7575

7676
}

src/Microsoft.ML.Data/TrainCatalog.cs

Lines changed: 105 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,31 @@ public abstract class TrainCatalogBase
2525
[BestFriend]
2626
internal IHostEnvironment Environment => Host;
2727

28+
/// <summary>
29+
/// A pair of datasets, for the train and test set.
30+
/// </summary>
31+
public struct TrainTestData
32+
{
33+
/// <summary>
34+
/// Training set.
35+
/// </summary>
36+
public readonly IDataView TrainSet;
37+
/// <summary>
38+
/// Testing set.
39+
/// </summary>
40+
public readonly IDataView TestSet;
41+
/// <summary>
42+
/// Create pair of datasets.
43+
/// </summary>
44+
/// <param name="trainSet">Training set.</param>
45+
/// <param name="testSet">Testing set.</param>
46+
internal TrainTestData(IDataView trainSet, IDataView testSet)
47+
{
48+
TrainSet = trainSet;
49+
TestSet = testSet;
50+
}
51+
}
52+
2853
/// <summary>
2954
/// Split the dataset into the train set and test set according to the given fraction.
3055
/// Respects the <paramref name="stratificationColumn"/> if provided.
@@ -37,8 +62,7 @@ public abstract class TrainCatalogBase
3762
/// <param name="seed">Optional parameter used in combination with the <paramref name="stratificationColumn"/>.
3863
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed as value.
3964
/// If the seed is not provided either, the default value will be used.</param>
40-
/// <returns>A pair of datasets, for the train and test set.</returns>
41-
public (IDataView trainSet, IDataView testSet) TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null)
65+
public TrainTestData TrainTestSplit(IDataView data, double testFraction = 0.1, string stratificationColumn = null, uint? seed = null)
4266
{
4367
Host.CheckValue(data, nameof(data));
4468
Host.CheckParam(0 < testFraction && testFraction < 1, nameof(testFraction), "Must be between 0 and 1 exclusive");
@@ -61,14 +85,71 @@ public abstract class TrainCatalogBase
6185
Complement = false
6286
}, data);
6387

64-
return (trainFilter, testFilter);
88+
return new TrainTestData(trainFilter, testFilter);
89+
}
90+
91+
/// <summary>
92+
/// Results for a specific cross-validation fold.
93+
/// </summary>
94+
protected internal struct CrossValidationResult
95+
{
96+
/// <summary>
97+
/// Model trained during this cross-validation fold.
98+
/// </summary>
99+
public readonly ITransformer Model;
100+
/// <summary>
101+
/// Test set scored with <see cref="Model"/> for this fold.
102+
/// </summary>
103+
public readonly IDataView Scores;
104+
/// <summary>
105+
/// Fold number.
106+
/// </summary>
107+
public readonly int Fold;
108+
109+
public CrossValidationResult(ITransformer model, IDataView scores, int fold)
110+
{
111+
Model = model;
112+
Scores = scores;
113+
Fold = fold;
114+
}
115+
}
116+
/// <summary>
117+
/// Results of running cross-validation for a single fold, including the metrics computed on that fold.
118+
/// </summary>
119+
/// <typeparam name="T">Type of metric class.</typeparam>
120+
public sealed class CrossValidationResult<T> where T : class
121+
{
122+
/// <summary>
123+
/// Metrics for this cross-validation fold.
124+
/// </summary>
125+
public readonly T Metrics;
126+
/// <summary>
127+
/// Model trained during this cross-validation fold.
128+
/// </summary>
129+
public readonly ITransformer Model;
130+
/// <summary>
131+
/// The scored hold-out set for this fold.
132+
/// </summary>
133+
public readonly IDataView ScoredHoldOutSet;
134+
/// <summary>
135+
/// Fold number.
136+
/// </summary>
137+
public readonly int Fold;
138+
139+
internal CrossValidationResult(ITransformer model, T metrics, IDataView scores, int fold)
140+
{
141+
Model = model;
142+
Metrics = metrics;
143+
ScoredHoldOutSet = scores;
144+
Fold = fold;
145+
}
65146
}
66147

67148
/// <summary>
68149
/// Train the <paramref name="estimator"/> on <paramref name="numFolds"/> folds of the data sequentially.
69150
/// Return each model and each scored test dataset.
70151
/// </summary>
71-
protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
152+
protected internal CrossValidationResult[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
72153
int numFolds, string stratificationColumn, uint? seed = null)
73154
{
74155
Host.CheckValue(data, nameof(data));
@@ -78,7 +159,7 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
78159

79160
EnsureStratificationColumn(ref data, ref stratificationColumn, seed);
80161

81-
Func<int, (IDataView scores, ITransformer model)> foldFunction =
162+
Func<int, CrossValidationResult> foldFunction =
82163
fold =>
83164
{
84165
var trainFilter = new RangeFilter(Host, new RangeFilter.Options
@@ -98,17 +179,17 @@ protected internal (IDataView scoredTestSet, ITransformer model)[] CrossValidate
98179

99180
var model = estimator.Fit(trainFilter);
100181
var scoredTest = model.Transform(testFilter);
101-
return (scoredTest, model);
182+
return new CrossValidationResult(model, scoredTest, fold);
102183
};
103184

104185
// Sequential per-fold training.
105186
// REVIEW: we could have a parallel implementation here. We would need to
106187
// spawn off a separate host per fold in that case.
107-
var result = new List<(IDataView scores, ITransformer model)>();
188+
var result = new CrossValidationResult[numFolds];
108189
for (int fold = 0; fold < numFolds; fold++)
109-
result.Add(foldFunction(fold));
190+
result[fold] = foldFunction(fold);
110191

111-
return result.ToArray();
192+
return result;
112193
}
113194

114195
protected internal TrainCatalogBase(IHostEnvironment env, string registrationName)
@@ -263,13 +344,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
263344
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed as value.
264345
/// If the seed is not provided either, the default value will be used.</param>
265346
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
266-
public (BinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidateNonCalibrated(
347+
public CrossValidationResult<BinaryClassificationMetrics>[] CrossValidateNonCalibrated(
267348
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
268349
string stratificationColumn = null, uint? seed = null)
269350
{
270351
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
271352
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
272-
return result.Select(x => (EvaluateNonCalibrated(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
353+
return result.Select(x => new CrossValidationResult<BinaryClassificationMetrics>(x.Model,
354+
EvaluateNonCalibrated(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
273355
}
274356

275357
/// <summary>
@@ -287,13 +369,14 @@ public BinaryClassificationMetrics EvaluateNonCalibrated(IDataView data, string
287369
/// train to the test set.</remarks>
288370
/// <param name="seed">If <paramref name="stratificationColumn"/> is not present in the dataset, a randomly filled column is generated based on the provided <paramref name="seed"/>.</param>
289371
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
290-
public (CalibratedBinaryClassificationMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
372+
public CrossValidationResult<CalibratedBinaryClassificationMetrics>[] CrossValidate(
291373
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
292374
string stratificationColumn = null, uint? seed = null)
293375
{
294376
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
295377
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
296-
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
378+
return result.Select(x => new CrossValidationResult<CalibratedBinaryClassificationMetrics>(x.Model,
379+
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
297380
}
298381
}
299382

@@ -369,12 +452,13 @@ public ClusteringMetrics Evaluate(IDataView data,
369452
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed as value.
370453
/// If the seed is not provided either, the default value will be used.</param>
371454
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
372-
public (ClusteringMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
455+
public CrossValidationResult<ClusteringMetrics>[] CrossValidate(
373456
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = null, string featuresColumn = null,
374457
string stratificationColumn = null, uint? seed = null)
375458
{
376459
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
377-
return result.Select(x => (Evaluate(x.scoredTestSet, label: labelColumn, features: featuresColumn), x.model, x.scoredTestSet)).ToArray();
460+
return result.Select(x => new CrossValidationResult<ClusteringMetrics>(x.Model,
461+
Evaluate(x.Scores, label: labelColumn, features: featuresColumn), x.Scores, x.Fold)).ToArray();
378462
}
379463
}
380464

@@ -444,13 +528,14 @@ public MultiClassClassifierMetrics Evaluate(IDataView data, string label = Defau
444528
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed as value.
445529
/// If the seed is not provided either, the default value will be used.</param>
446530
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
447-
public (MultiClassClassifierMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
531+
public CrossValidationResult<MultiClassClassifierMetrics>[] CrossValidate(
448532
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
449533
string stratificationColumn = null, uint? seed = null)
450534
{
451535
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
452536
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
453-
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
537+
return result.Select(x => new CrossValidationResult<MultiClassClassifierMetrics>(x.Model,
538+
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
454539
}
455540
}
456541

@@ -511,13 +596,14 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
511596
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed as value.
512597
/// If the seed is not provided either, the default value will be used.</param>
513598
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
514-
public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
599+
public CrossValidationResult<RegressionMetrics>[] CrossValidate(
515600
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
516601
string stratificationColumn = null, uint? seed = null)
517602
{
518603
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
519604
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
520-
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
605+
return result.Select(x => new CrossValidationResult<RegressionMetrics>(x.Model,
606+
Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
521607
}
522608
}
523609

src/Microsoft.ML.Recommender/RecommenderCatalog.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,13 @@ public RegressionMetrics Evaluate(IDataView data, string label = DefaultColumnNa
128128
/// If the <paramref name="stratificationColumn"/> is not provided, the random numbers generated to create it will use this seed as value.
129129
/// If the seed is not provided either, the default value will be used.</param>
130130
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
131-
public (RegressionMetrics metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
131+
public CrossValidationResult<RegressionMetrics>[] CrossValidate(
132132
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label,
133133
string stratificationColumn = null, uint? seed = null)
134134
{
135135
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
136136
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn, seed);
137-
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
137+
return result.Select(x => new CrossValidationResult<RegressionMetrics>(x.Model, Evaluate(x.Scores, labelColumn), x.Scores, x.Fold)).ToArray();
138138
}
139139
}
140140
}

0 commit comments

Comments
 (0)