Skip to content

Commit 11ff554

Browse files
author
Pete Luferenko
committed
CrossValidator class
1 parent 7a66e19 commit 11ff554

File tree

2 files changed

+82
-53
lines changed

2 files changed

+82
-53
lines changed

test/Microsoft.ML.Tests/Scenarios/Api/CrossValidation.cs

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,12 @@ void New_CrossValidation()
105105
ConvergenceTolerance = 1f
106106
}, "Features", "Label"));
107107

108-
(var models, var metrics) = MyHelperExtensions.CrossValidateBinary(env, data, pipeline, "Label", numFolds: 2);
108+
var cv = new MyCrossValidation.BinaryCrossValidator(env)
109+
{
110+
NumFolds = 2
111+
};
112+
113+
var cvResult = cv.CrossValidate(data, pipeline);
109114
}
110115
}
111116
}

test/Microsoft.ML.Tests/Scenarios/Api/Wrappers.cs

Lines changed: 76 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -187,11 +187,11 @@ public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel
187187
public TModel InnerModel { get; }
188188
}
189189

190-
public class BinaryScorerWrapper<TModel>: ScorerWrapper<TModel>
191-
where TModel: IPredictor
190+
public class BinaryScorerWrapper<TModel> : ScorerWrapper<TModel>
191+
where TModel : IPredictor
192192
{
193193
public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, BinaryClassifierScorer.Arguments args)
194-
:base(env, MakeScorer(env, inputSchema, featureColumn, model, args), model, featureColumn)
194+
: base(env, MakeScorer(env, inputSchema, featureColumn, model, args), model, featureColumn)
195195
{
196196
}
197197

@@ -235,7 +235,7 @@ public SchemaShape GetOutputSchema()
235235
}
236236

237237
public abstract class TrainerBase<TTransformer, TModel> : IEstimator<TTransformer>
238-
where TTransformer: ScorerWrapper<TModel>
238+
where TTransformer : ScorerWrapper<TModel>
239239
where TModel : IPredictor
240240
{
241241
protected readonly IHostEnvironment _env;
@@ -414,7 +414,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
414414
}
415415
}
416416

417-
public sealed class MySdca : TrainerBase<BinaryScorerWrapper<IPredictor>,IPredictor>
417+
public sealed class MySdca : TrainerBase<BinaryScorerWrapper<IPredictor>, IPredictor>
418418
{
419419
private readonly LinearClassificationTrainer.Arguments _args;
420420

@@ -428,7 +428,7 @@ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args,
428428

429429
public ITransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData);
430430

431-
protected override BinaryScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data)
431+
protected override BinaryScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data)
432432
=> new BinaryScorerWrapper<IPredictor>(_env, predictor, data.Data.Schema, _featureCol, new BinaryClassifierScorer.Arguments());
433433
}
434434

@@ -512,68 +512,92 @@ public BinaryClassificationMetrics Evaluate(IDataView data, string labelColumn,
512512
}
513513
}
514514

515-
public static class MyHelperExtensions
515+
public static class MyCrossValidation
516516
{
517-
public static void SaveAsBinary(this IDataView data, IHostEnvironment env, Stream stream)
517+
public sealed class BinaryCrossValidationMetrics
518518
{
519-
var saver = new BinarySaver(env, new BinarySaver.Arguments());
520-
using (var ch = env.Start("SaveData"))
521-
DataSaverUtils.SaveDataView(ch, saver, data, stream);
519+
public readonly ITransformer[] FoldModels;
520+
public readonly BinaryClassificationMetrics[] FoldMetrics;
521+
522+
public BinaryCrossValidationMetrics(ITransformer[] models, BinaryClassificationMetrics[] metrics)
523+
{
524+
FoldModels = models;
525+
FoldMetrics = metrics;
526+
}
522527
}
523528

524-
public static IDataView FitAndTransform(this IEstimator<ITransformer> est, IDataView data) => est.Fit(data).Transform(data);
529+
public sealed class BinaryCrossValidator
530+
{
531+
private readonly IHostEnvironment _env;
525532

526-
public static IDataView FitAndRead<TSource>(this IDataReaderEstimator<TSource, IDataReader<TSource>> est, TSource source)
527-
=> est.Fit(source).Read(source);
533+
public int NumFolds { get; set; } = 2;
528534

529-
public static (ITransformer[] Models, BinaryClassificationMetrics[] Metrics) CrossValidateBinary(IHostEnvironment env, IDataView trainData, IEstimator<ITransformer> estimator,
530-
string labelColumn,
531-
int numFolds = 2,
532-
string stratificationColumn = null,
533-
bool cache = false)
534-
{
535-
var models = new ITransformer[numFolds];
536-
var metrics = new BinaryClassificationMetrics[numFolds];
535+
public string StratificationColumn { get; set; }
537536

538-
if (stratificationColumn == null)
537+
public string LabelColumn { get; set; } = DefaultColumnNames.Label;
538+
539+
public BinaryCrossValidator(IHostEnvironment env)
539540
{
540-
stratificationColumn = "StratificationColumn";
541-
var random = new GenerateNumberTransform(env, trainData, stratificationColumn);
542-
trainData = random;
541+
_env = env;
543542
}
544-
else
545-
throw new NotImplementedException();
546-
547-
IDataView cachedTrain = trainData;
548-
if (cache)
549-
cachedTrain = new CacheDataView(env, trainData, prefetch: null);
550543

551-
var evaluator = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { });
552-
553-
for (int fold = 0; fold < numFolds; fold++)
544+
public BinaryCrossValidationMetrics CrossValidate(IDataView trainData, IEstimator<ITransformer> estimator)
554545
{
555-
var trainFilter = new RangeFilter(env, new RangeFilter.Arguments()
546+
var models = new ITransformer[NumFolds];
547+
var metrics = new BinaryClassificationMetrics[NumFolds];
548+
549+
if (StratificationColumn == null)
556550
{
557-
Column = stratificationColumn,
558-
Min = (Double)fold / numFolds,
559-
Max = (Double)(fold + 1) / numFolds,
560-
Complement = true
561-
}, cachedTrain);
562-
var testFilter = new RangeFilter(env, new RangeFilter.Arguments()
551+
StratificationColumn = "StratificationColumn";
552+
var random = new GenerateNumberTransform(_env, trainData, StratificationColumn);
553+
trainData = random;
554+
}
555+
else
556+
throw new NotImplementedException();
557+
558+
var evaluator = new MyBinaryClassifierEvaluator(_env, new BinaryClassifierEvaluator.Arguments() { });
559+
560+
for (int fold = 0; fold < NumFolds; fold++)
563561
{
564-
Column = stratificationColumn,
565-
Min = (Double)fold / numFolds,
566-
Max = (Double)(fold + 1) / numFolds,
567-
Complement = false
568-
}, cachedTrain);
569-
570-
models[fold] = estimator.Fit(trainFilter);
571-
var scoredTest = models[fold].Transform(testFilter);
572-
metrics[fold] = evaluator.Evaluate(scoredTest, labelColumn: labelColumn, probabilityColumn: "Probability");
562+
var trainFilter = new RangeFilter(_env, new RangeFilter.Arguments()
563+
{
564+
Column = StratificationColumn,
565+
Min = (Double)fold / NumFolds,
566+
Max = (Double)(fold + 1) / NumFolds,
567+
Complement = true
568+
}, trainData);
569+
var testFilter = new RangeFilter(_env, new RangeFilter.Arguments()
570+
{
571+
Column = StratificationColumn,
572+
Min = (Double)fold / NumFolds,
573+
Max = (Double)(fold + 1) / NumFolds,
574+
Complement = false
575+
}, trainData);
576+
577+
models[fold] = estimator.Fit(trainFilter);
578+
var scoredTest = models[fold].Transform(testFilter);
579+
metrics[fold] = evaluator.Evaluate(scoredTest, labelColumn: LabelColumn, probabilityColumn: "Probability");
580+
}
581+
582+
return new BinaryCrossValidationMetrics(models, metrics);
583+
573584
}
585+
}
586+
}
574587

575-
return (models, metrics);
588+
589+
public static class MyHelperExtensions
590+
{
591+
public static void SaveAsBinary(this IDataView data, IHostEnvironment env, Stream stream)
592+
{
593+
var saver = new BinarySaver(env, new BinarySaver.Arguments());
594+
using (var ch = env.Start("SaveData"))
595+
DataSaverUtils.SaveDataView(ch, saver, data, stream);
576596
}
577597

598+
public static IDataView FitAndTransform(this IEstimator<ITransformer> est, IDataView data) => est.Fit(data).Transform(data);
599+
600+
public static IDataView FitAndRead<TSource>(this IDataReaderEstimator<TSource, IDataReader<TSource>> est, TSource source)
601+
=> est.Fit(source).Read(source);
578602
}
579603
}

0 commit comments

Comments
 (0)