@@ -289,5 +289,68 @@ private List<BreastCancerExample> ReadBreastCancerExamples()
289
289
. ToList ( ) ;
290
290
return data ;
291
291
}
292
+
293
/// <summary>
/// Checks that <c>TrainTestSplit</c> reacts to the <c>seed</c> argument and that, when a
/// stratification column is supplied, no value of that column appears in both the train
/// and the test subset.
/// </summary>
[Fact]
public void TestTrainTestSplit()
{
    var mlContext = new MLContext(0);

    var dataPath = GetDataPath("adult.tiny.with-schema.txt");
    // Create the reader: define the data columns and where to find them in the text file.
    var input = mlContext.Data.ReadFromTextFile(dataPath, new[] {
        new TextLoader.Column("Label", DataKind.BL, 0),
        new TextLoader.Column("Workclass", DataKind.TX, 1),
        new TextLoader.Column("Education", DataKind.TX, 2),
        new TextLoader.Column("Age", DataKind.R4, 9)
    }, hasHeader: true);

    // Reads the "Workclass" column of a data view and returns its content as a list of strings.
    List<string> GetWorkclass(IDataView view) =>
        view.GetColumn<ReadOnlyMemory<char>>(mlContext, "Workclass").Select(x => x.ToString()).ToList();

    // Returns the distinct "Workclass" values of a data view, in order of first appearance.
    List<string> GetUniqueWorkclass(IDataView view) =>
        GetWorkclass(view).Distinct().ToList();

    // Let's test that the train-test split works properly with a seed.
    // In order to do that, split the same dataset twice: once with the default seed value
    // and once with the seed set to a specific value.
    // Only the test halves are compared, so the train halves are discarded.
    var (_, simpleTest) = mlContext.BinaryClassification.TrainTestSplit(input);
    var (_, simpleTestWithSeed) = mlContext.BinaryClassification.TrainTestSplit(input, seed: 10);

    // Since test fraction is 0.1, it's much faster to compare test subsets of split.
    var simpleTestWorkClass = GetWorkclass(simpleTest);
    var simpleWithSeedTestWorkClass = GetWorkclass(simpleTestWithSeed);
    // Validate we get different test sets.
    Assert.NotEqual(simpleTestWorkClass, simpleWithSeedTestWorkClass);

    // Now let's do the same thing but in the presence of a stratificationColumn.
    // Rows with the same value in the stratificationColumn should end up in the same subset (train or test).
    // So let's break the dataset by the "Workclass" column.
    var (stratTrain, stratTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass");
    // Unique values of the "Workclass" column from the train subset, and from the test subset.
    var uniqueTrain = GetUniqueWorkclass(stratTrain);
    var uniqueTest = GetUniqueWorkclass(stratTest);
    // Validate we don't have an intersection between workclass values, since that column is the stratification column.
    // Assert.Empty reports the offending values on failure, unlike a bare boolean check.
    Assert.Empty(uniqueTrain.Intersect(uniqueTest));

    // Let's do the same thing, but this time with a different seed.
    // The stratification column should still break the dataset properly, without the same values in both subsets.
    var (stratWithSeedTrain, stratWithSeedTest) = mlContext.BinaryClassification.TrainTestSplit(input, stratificationColumn: "Workclass", seed: 1000000);
    var uniqueSeedTrain = GetUniqueWorkclass(stratWithSeedTrain);
    var uniqueSeedTest = GetUniqueWorkclass(stratWithSeedTest);

    // Validate we don't have an intersection between workclass values, since that column is the stratification column.
    Assert.Empty(uniqueSeedTrain.Intersect(uniqueSeedTest));
    // Validate we got different test results on the same stratification column with different seeds.
    Assert.NotEqual(uniqueTest, uniqueSeedTest);
}
354
+
292
355
}
293
356
}
0 commit comments