Skip to content

Commit a810d8d

Browse files
authored
Sample revs; ColumnInformation property name revs; pre-featurizer fixes (dotnet#346)
1 parent cab5809 commit a810d8d

32 files changed

+291
-256
lines changed

src/Microsoft.ML.Auto/API/ColumnInference.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ public sealed class ColumnInferenceResults
1616

1717
public sealed class ColumnInformation
1818
{
19-
public string LabelColumn { get; set; } = DefaultColumnNames.Label;
20-
public string ExampleWeightColumn { get; set; }
21-
public string SamplingKeyColumn { get; set; }
22-
public ICollection<string> CategoricalColumns { get; } = new Collection<string>();
23-
public ICollection<string> NumericColumns { get; } = new Collection<string>();
24-
public ICollection<string> TextColumns { get; } = new Collection<string>();
25-
public ICollection<string> IgnoredColumns { get; } = new Collection<string>();
19+
public string LabelColumnName { get; set; } = DefaultColumnNames.Label;
20+
public string ExampleWeightColumnName { get; set; }
21+
public string SamplingKeyColumnName { get; set; }
22+
public ICollection<string> CategoricalColumnNames { get; } = new Collection<string>();
23+
public ICollection<string> NumericColumnNames { get; } = new Collection<string>();
24+
public ICollection<string> TextColumnNames { get; } = new Collection<string>();
25+
public ICollection<string> IgnoredColumnNames { get; } = new Collection<string>();
2626
}
2727
}

src/Microsoft.ML.Auto/API/ExperimentBase.cs

+58-18
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, string lab
3737
{
3838
var columnInformation = new ColumnInformation()
3939
{
40-
LabelColumn = labelColumn,
41-
SamplingKeyColumn = samplingKeyColumn
40+
LabelColumnName = labelColumn,
41+
SamplingKeyColumnName = samplingKeyColumn
4242
};
4343
return Execute(trainData, columnInformation, preFeaturizers, progressHandler);
4444
}
@@ -56,51 +56,51 @@ public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, ColumnInfo
5656
if (rowCount < crossValRowCountThreshold)
5757
{
5858
const int numCrossValFolds = 10;
59-
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, columnInformation?.SamplingKeyColumn);
59+
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, columnInformation?.SamplingKeyColumnName);
6060
return ExecuteCrossValSummary(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
6161
}
6262
else
6363
{
64-
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumn);
64+
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName);
6565
return ExecuteTrainValidate(splitResult.trainData, columnInformation, splitResult.validationData, preFeaturizer, progressHandler);
6666
}
6767
}
6868

6969
public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
7070
{
71-
var columnInformation = new ColumnInformation() { LabelColumn = labelColumn };
71+
var columnInformation = new ColumnInformation() { LabelColumnName = labelColumn };
7272
return Execute(trainData, validationData, columnInformation, preFeaturizer, progressHandler);
7373
}
7474

7575
public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
7676
{
7777
if (validationData == null)
7878
{
79-
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumn);
79+
var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName);
8080
trainData = splitResult.trainData;
8181
validationData = splitResult.validationData;
8282
}
8383
return ExecuteTrainValidate(trainData, columnInformation, validationData, preFeaturizer, progressHandler);
8484
}
8585

86-
public IEnumerable<CrossValidationRunDetails<TMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizers = null, IProgress<CrossValidationRunDetails<TMetrics>> progressHandler = null)
86+
public IEnumerable<CrossValidationRunDetails<TMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<CrossValidationRunDetails<TMetrics>> progressHandler = null)
8787
{
8888
UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds);
89-
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumn);
90-
return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizers, progressHandler);
89+
var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumnName);
90+
return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
9191
}
9292

9393
public IEnumerable<CrossValidationRunDetails<TMetrics>> Execute(IDataView trainData,
9494
uint numberOfCVFolds, string labelColumn = DefaultColumnNames.Label,
95-
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizers = null,
95+
string samplingKeyColumn = null, IEstimator<ITransformer> preFeaturizer = null,
9696
Progress<CrossValidationRunDetails<TMetrics>> progressHandler = null)
9797
{
9898
var columnInformation = new ColumnInformation()
9999
{
100-
LabelColumn = labelColumn,
101-
SamplingKeyColumn = samplingKeyColumn
100+
LabelColumnName = labelColumn,
101+
SamplingKeyColumnName = samplingKeyColumn
102102
};
103-
return Execute(trainData, numberOfCVFolds, columnInformation, preFeaturizers, progressHandler);
103+
return Execute(trainData, numberOfCVFolds, columnInformation, preFeaturizer, progressHandler);
104104
}
105105

106106
private IEnumerable<RunDetails<TMetrics>> ExecuteTrainValidate(IDataView trainData,
@@ -111,8 +111,18 @@ private IEnumerable<RunDetails<TMetrics>> ExecuteTrainValidate(IDataView trainDa
111111
{
112112
columnInfo = columnInfo ?? new ColumnInformation();
113113
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
114-
var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.LabelColumn, _metricsAgent,
115-
preFeaturizer, _settings.DebugLogger);
114+
115+
// Apply pre-featurizer
116+
ITransformer preprocessorTransform = null;
117+
if (preFeaturizer != null)
118+
{
119+
preprocessorTransform = preFeaturizer.Fit(trainData);
120+
trainData = preprocessorTransform.Transform(trainData);
121+
validationData = preprocessorTransform.Transform(validationData);
122+
}
123+
124+
var runner = new TrainValidateRunner<TMetrics>(Context, trainData, validationData, columnInfo.LabelColumnName, _metricsAgent,
125+
preFeaturizer, preprocessorTransform, _settings.DebugLogger);
116126
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainData, columnInfo);
117127
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
118128
}
@@ -125,8 +135,13 @@ private IEnumerable<CrossValidationRunDetails<TMetrics>> ExecuteCrossVal(IDataVi
125135
{
126136
columnInfo = columnInfo ?? new ColumnInformation();
127137
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
128-
var runner = new CrossValRunner<TMetrics>(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
129-
columnInfo.LabelColumn, _settings.DebugLogger);
138+
139+
// Apply pre-featurizer
140+
ITransformer[] preprocessorTransforms = null;
141+
(trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);
142+
143+
var runner = new CrossValRunner<TMetrics>(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
144+
preprocessorTransforms, columnInfo.LabelColumnName, _settings.DebugLogger);
130145
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
131146
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
132147
}
@@ -139,8 +154,13 @@ private IEnumerable<RunDetails<TMetrics>> ExecuteCrossValSummary(IDataView[] tra
139154
{
140155
columnInfo = columnInfo ?? new ColumnInformation();
141156
UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]);
157+
158+
// Apply pre-featurizer
159+
ITransformer[] preprocessorTransforms = null;
160+
(trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);
161+
142162
var runner = new CrossValSummaryRunner<TMetrics>(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
143-
columnInfo.LabelColumn, _optimizingMetricInfo, _settings.DebugLogger);
163+
preprocessorTransforms, columnInfo.LabelColumnName, _optimizingMetricInfo, _settings.DebugLogger);
144164
var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
145165
return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
146166
}
@@ -158,5 +178,25 @@ private IEnumerable<TRunDetails> Execute<TRunDetails>(ColumnInformation columnIn
158178

159179
return experiment.Execute();
160180
}
181+
182+
private static (IDataView[] trainDatasets, IDataView[] validDatasets, ITransformer[] preprocessorTransforms)
183+
ApplyPreFeaturizerCrossVal(IDataView[] trainDatasets, IDataView[] validDatasets, IEstimator<ITransformer> preFeaturizer)
184+
{
185+
if (preFeaturizer == null)
186+
{
187+
return (trainDatasets, validDatasets, null);
188+
}
189+
190+
var preprocessorTransforms = new ITransformer[trainDatasets.Length];
191+
for (var i = 0; i < trainDatasets.Length; i++)
192+
{
193+
// Preprocess train and validation data
194+
preprocessorTransforms[i] = preFeaturizer.Fit(trainDatasets[i]);
195+
trainDatasets[i] = preprocessorTransforms[i].Transform(trainDatasets[i]);
196+
validDatasets[i] = preprocessorTransforms[i].Transform(validDatasets[i]);
197+
}
198+
199+
return (trainDatasets, validDatasets, preprocessorTransforms);
200+
}
161201
}
162202
}

src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
2424
typeInference.Columns[labelColumnIndex].SuggestedName = DefaultColumnNames.Label;
2525
}
2626

27-
var columnInfo = new ColumnInformation() { LabelColumn = typeInference.Columns[labelColumnIndex].SuggestedName };
27+
var columnInfo = new ColumnInformation() { LabelColumnName = typeInference.Columns[labelColumnIndex].SuggestedName };
2828

2929
return InferColumns(context, path, columnInfo, hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
3030
}
3131

3232
public static ColumnInferenceResults InferColumns(MLContext context, string path, string labelColumn,
3333
char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
3434
{
35-
var columnInfo = new ColumnInformation() { LabelColumn = labelColumn };
35+
var columnInfo = new ColumnInformation() { LabelColumnName = labelColumn };
3636
return InferColumns(context, path, columnInfo, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
3737
}
3838

@@ -41,7 +41,7 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
4141
{
4242
var sample = TextFileSample.CreateFromFullFile(path);
4343
var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
44-
var typeInference = InferColumnTypes(context, sample, splitInference, true, null, columnInfo.LabelColumn);
44+
var typeInference = InferColumnTypes(context, sample, splitInference, true, null, columnInfo.LabelColumnName);
4545
return InferColumns(context, path, columnInfo, true, splitInference, typeInference, trimWhitespace, groupColumns);
4646
}
4747

src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs

+14-14
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,37 @@ internal static class ColumnInformationUtil
1212
{
1313
internal static ColumnPurpose? GetColumnPurpose(this ColumnInformation columnInfo, string columnName)
1414
{
15-
if (columnName == columnInfo.LabelColumn)
15+
if (columnName == columnInfo.LabelColumnName)
1616
{
1717
return ColumnPurpose.Label;
1818
}
1919

20-
if (columnName == columnInfo.ExampleWeightColumn)
20+
if (columnName == columnInfo.ExampleWeightColumnName)
2121
{
2222
return ColumnPurpose.Weight;
2323
}
2424

25-
if (columnName == columnInfo.SamplingKeyColumn)
25+
if (columnName == columnInfo.SamplingKeyColumnName)
2626
{
2727
return ColumnPurpose.SamplingKey;
2828
}
2929

30-
if (columnInfo.CategoricalColumns.Contains(columnName))
30+
if (columnInfo.CategoricalColumnNames.Contains(columnName))
3131
{
3232
return ColumnPurpose.CategoricalFeature;
3333
}
3434

35-
if (columnInfo.NumericColumns.Contains(columnName))
35+
if (columnInfo.NumericColumnNames.Contains(columnName))
3636
{
3737
return ColumnPurpose.NumericFeature;
3838
}
3939

40-
if (columnInfo.TextColumns.Contains(columnName))
40+
if (columnInfo.TextColumnNames.Contains(columnName))
4141
{
4242
return ColumnPurpose.TextFeature;
4343
}
4444

45-
if (columnInfo.IgnoredColumns.Contains(columnName))
45+
if (columnInfo.IgnoredColumnNames.Contains(columnName))
4646
{
4747
return ColumnPurpose.Ignore;
4848
}
@@ -59,25 +59,25 @@ internal static ColumnInformation BuildColumnInfo(IEnumerable<(string name, Colu
5959
switch (column.purpose)
6060
{
6161
case ColumnPurpose.Label:
62-
columnInfo.LabelColumn = column.name;
62+
columnInfo.LabelColumnName = column.name;
6363
break;
6464
case ColumnPurpose.Weight:
65-
columnInfo.ExampleWeightColumn = column.name;
65+
columnInfo.ExampleWeightColumnName = column.name;
6666
break;
6767
case ColumnPurpose.SamplingKey:
68-
columnInfo.SamplingKeyColumn = column.name;
68+
columnInfo.SamplingKeyColumnName = column.name;
6969
break;
7070
case ColumnPurpose.CategoricalFeature:
71-
columnInfo.CategoricalColumns.Add(column.name);
71+
columnInfo.CategoricalColumnNames.Add(column.name);
7272
break;
7373
case ColumnPurpose.Ignore:
74-
columnInfo.IgnoredColumns.Add(column.name);
74+
columnInfo.IgnoredColumnNames.Add(column.name);
7575
break;
7676
case ColumnPurpose.NumericFeature:
77-
columnInfo.NumericColumns.Add(column.name);
77+
columnInfo.NumericColumnNames.Add(column.name);
7878
break;
7979
case ColumnPurpose.TextFeature:
80-
columnInfo.TextColumns.Add(column.name);
80+
columnInfo.TextColumnNames.Add(column.name);
8181
break;
8282
}
8383
}

src/Microsoft.ML.Auto/Experiment/Runners/CrossValRunner.cs

+4-14
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ public CrossValRunner(MLContext context,
2626
IDataView[] validDatasets,
2727
IMetricsAgent<TMetrics> metricsAgent,
2828
IEstimator<ITransformer> preFeaturizer,
29+
ITransformer[] preprocessorTransforms,
2930
string labelColumn,
3031
IDebugLogger logger)
3132
{
@@ -34,21 +35,10 @@ public CrossValRunner(MLContext context,
3435
_validDatasets = validDatasets;
3536
_metricsAgent = metricsAgent;
3637
_preFeaturizer = preFeaturizer;
38+
_preprocessorTransforms = preprocessorTransforms;
3739
_labelColumn = labelColumn;
3840
_logger = logger;
3941
_modelInputSchema = trainDatasets[0].Schema;
40-
41-
if (_preFeaturizer != null)
42-
{
43-
_preprocessorTransforms = new ITransformer[_trainDatasets.Length];
44-
for (var i = 0; i < _trainDatasets.Length; i++)
45-
{
46-
// Preprocess train and validation data
47-
_preprocessorTransforms[i] = _preFeaturizer.Fit(_trainDatasets[i]);
48-
_trainDatasets[i] = _preprocessorTransforms[i].Transform(_trainDatasets[i]);
49-
_validDatasets[i] = _preprocessorTransforms[i].Transform(_validDatasets[i]);
50-
}
51-
}
5242
}
5343

5444
public (SuggestedPipelineRunDetails suggestedPipelineRunDetails, CrossValidationRunDetails<TMetrics> runDetails)
@@ -60,15 +50,15 @@ public CrossValRunner(MLContext context,
6050
{
6151
var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
6252
var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
63-
_labelColumn, _metricsAgent, _preFeaturizer, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger);
53+
_labelColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger);
6454
trainResults.Add(new SuggestedPipelineTrainResult<TMetrics>(trainResult.model, trainResult.metrics, trainResult.exception, trainResult.score));
6555
}
6656

6757
var avgScore = CalcAverageScore(trainResults.Select(r => r.Score));
6858
var allRunsSucceeded = trainResults.All(r => r.Exception == null);
6959

7060
var suggestedPipelineRunDetails = new SuggestedPipelineCrossValRunDetails<TMetrics>(pipeline, avgScore, allRunsSucceeded, trainResults);
71-
var runDetails = suggestedPipelineRunDetails.ToIterationResult();
61+
var runDetails = suggestedPipelineRunDetails.ToIterationResult(_preFeaturizer);
7262
return (suggestedPipelineRunDetails, runDetails);
7363
}
7464

0 commit comments

Comments
 (0)