Skip to content

Commit a869547

Browse files
authored
Make sure seed works for stratification column in TrainTest and CrossValidate (#2241)
1 parent aedc4f7 commit a869547

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

src/Microsoft.ML.Data/TrainCatalog.cs

+6-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,12 @@ private void EnsureStratificationColumn(ref IDataView data, ref string stratific
152152
// Generate a new column with the hashed stratification column.
153153
while (data.Schema.TryGetColumnIndex(stratificationColumn, out tmp))
154154
stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc);
155-
data = new HashingEstimator(Host, origStratCol, stratificationColumn, 30).Fit(data).Transform(data);
155+
HashingTransformer.ColumnInfo columnInfo;
156+
if (seed.HasValue)
157+
columnInfo = new HashingTransformer.ColumnInfo(origStratCol, stratificationColumn, 30, seed.Value);
158+
else
159+
columnInfo = new HashingTransformer.ColumnInfo(origStratCol, stratificationColumn, 30);
160+
data = new HashingEstimator(Host, columnInfo).Fit(data).Transform(data);
156161
}
157162
}
158163
}

test/Microsoft.ML.Tests/Scenarios/Api/TestApi.cs

+63
Original file line numberDiff line numberDiff line change
@@ -289,5 +289,68 @@ private List<BreastCancerExample> ReadBreastCancerExamples()
289289
.ToList();
290290
return data;
291291
}
292+
293+
[Fact]
294+
public void TestTrainTestSplit()
295+
{
296+
var mlContext = new MLContext(0);
297+
298+
var dataPath = GetDataPath("adult.tiny.with-schema.txt");
299+
// Create the reader: define the data columns and where to find them in the text file.
300+
var input = mlContext.Data.ReadFromTextFile(dataPath, new[] {
301+
new TextLoader.Column("Label", DataKind.BL, 0),
302+
new TextLoader.Column("Workclass", DataKind.TX, 1),
303+
new TextLoader.Column("Education", DataKind.TX,2),
304+
new TextLoader.Column("Age", DataKind.R4,9)
305+
}, hasHeader: true);
306+
// This function accepts a data view and returns the contents of its "Workclass" column as a list of strings.
307+
Func<IDataView, List<string>> getWorkclass = (IDataView view) =>
308+
{
309+
return view.GetColumn<ReadOnlyMemory<char>>(mlContext, "Workclass").Select(x => x.ToString()).ToList();
310+
};
311+
312+
// Let's test that the train-test split works properly with a seed.
313+
// In order to do that, let's split the same dataset, but in one case we will use the default seed value,
314+
// and in the other case we set the seed to a specific value.
315+
var (simpleTrain, simpleTest) = mlContext.BinaryClassification.TrainTestSplit(input);
316+
var (simpleTrainWithSeed, simpleTestWithSeed) = mlContext.BinaryClassification.TrainTestSplit(input, seed: 10);
317+
318+
// Since the test fraction is 0.1, it's much faster to compare the test subsets of the splits.
319+
var simpleTestWorkClass = getWorkclass(simpleTest);
320+
321+
var simpleWithSeedTestWorkClass = getWorkclass(simpleTestWithSeed);
322+
// Validate we get different test sets.
323+
Assert.NotEqual(simpleTestWorkClass, simpleWithSeedTestWorkClass);
324+
325+
// Now let's do the same thing, but with a stratificationColumn specified.
326+
// Rows with the same value in this stratificationColumn should end up in the same subset (train or test).
327+
// So let's split the dataset by the "Workclass" column.
328+
var (stratTrain, stratTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass");
329+
var stratTrainWorkclass = getWorkclass(stratTrain);
330+
var stratTestWorkClass = getWorkclass(stratTest);
331+
// Let's get unique values for "Workclass" column from train subset.
332+
var uniqueTrain = stratTrainWorkclass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList();
333+
// and from test subset.
334+
var uniqueTest = stratTestWorkClass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList();
335+
// Validate that the train and test subsets share no workclass values, since we use that column as the stratification column.
336+
Assert.True(Enumerable.Intersect(uniqueTrain, uniqueTest).Count() == 0);
337+
338+
// Let's do the same thing, but this time we will choose a different seed.
339+
// The stratification column should still split the dataset properly, with no value appearing in both subsets.
340+
var (stratWithSeedTrain, stratWithSeedTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn:"Workclass", seed: 1000000);
341+
var stratTrainWithSeedWorkclass = getWorkclass(stratWithSeedTrain);
342+
var stratTestWithSeedWorkClass = getWorkclass(stratWithSeedTest);
343+
// Let's get unique values for "Workclass" column from train subset.
344+
var uniqueSeedTrain = stratTrainWithSeedWorkclass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList();
345+
// and from test subset.
346+
var uniqueSeedTest = stratTestWithSeedWorkClass.GroupBy(x => x.ToString()).Select(x => x.First()).ToList();
347+
348+
// Validate that the train and test subsets share no workclass values, since we use that column as the stratification column.
349+
Assert.True(Enumerable.Intersect(uniqueSeedTrain, uniqueSeedTest).Count() == 0);
350+
// Validate that we get different test subsets for the same stratification column when using different seeds.
351+
Assert.NotEqual(uniqueTest, uniqueSeedTest);
352+
353+
}
354+
292355
}
293356
}

0 commit comments

Comments
 (0)