@@ -37,8 +37,8 @@ public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, string lab
37
37
{
38
38
var columnInformation = new ColumnInformation ( )
39
39
{
40
- LabelColumn = labelColumn ,
41
- SamplingKeyColumn = samplingKeyColumn
40
+ LabelColumnName = labelColumn ,
41
+ SamplingKeyColumnName = samplingKeyColumn
42
42
} ;
43
43
return Execute ( trainData , columnInformation , preFeaturizers , progressHandler ) ;
44
44
}
@@ -56,51 +56,51 @@ public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, ColumnInfo
56
56
if ( rowCount < crossValRowCountThreshold )
57
57
{
58
58
const int numCrossValFolds = 10 ;
59
- var splitResult = SplitUtil . CrossValSplit ( Context , trainData , numCrossValFolds , columnInformation ? . SamplingKeyColumn ) ;
59
+ var splitResult = SplitUtil . CrossValSplit ( Context , trainData , numCrossValFolds , columnInformation ? . SamplingKeyColumnName ) ;
60
60
return ExecuteCrossValSummary ( splitResult . trainDatasets , columnInformation , splitResult . validationDatasets , preFeaturizer , progressHandler ) ;
61
61
}
62
62
else
63
63
{
64
- var splitResult = SplitUtil . TrainValidateSplit ( Context , trainData , columnInformation ? . SamplingKeyColumn ) ;
64
+ var splitResult = SplitUtil . TrainValidateSplit ( Context , trainData , columnInformation ? . SamplingKeyColumnName ) ;
65
65
return ExecuteTrainValidate ( splitResult . trainData , columnInformation , splitResult . validationData , preFeaturizer , progressHandler ) ;
66
66
}
67
67
}
68
68
69
69
/// <summary>
/// Convenience overload: wraps <paramref name="labelColumn"/> in a new
/// <see cref="ColumnInformation"/> and forwards to the overload that takes one.
/// </summary>
/// <param name="trainData">Training data, forwarded unchanged.</param>
/// <param name="validationData">Validation data, forwarded unchanged.</param>
/// <param name="labelColumn">Value assigned to <c>LabelColumnName</c> on the new column information.</param>
/// <param name="preFeaturizer">Optional estimator, forwarded unchanged.</param>
/// <param name="progressHandler">Optional progress receiver, forwarded unchanged.</param>
/// <returns>The result of the <see cref="ColumnInformation"/>-based overload.</returns>
public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
    => Execute(
        trainData,
        validationData,
        new ColumnInformation() { LabelColumnName = labelColumn },
        preFeaturizer,
        progressHandler);
74
74
75
75
/// <summary>
/// Runs a train/validate experiment. When no validation set is supplied, one is
/// carved out of <paramref name="trainData"/> via <c>SplitUtil.TrainValidateSplit</c>.
/// </summary>
/// <param name="trainData">Training data; also the source of the split when <paramref name="validationData"/> is null.</param>
/// <param name="validationData">Validation data, or null to request an automatic split.</param>
/// <param name="columnInformation">Column information; may be null (its sampling key is read null-conditionally).</param>
/// <param name="preFeaturizer">Optional estimator, forwarded unchanged.</param>
/// <param name="progressHandler">Optional progress receiver, forwarded unchanged.</param>
/// <returns>The result of <c>ExecuteTrainValidate</c>.</returns>
public IEnumerable<RunDetails<TMetrics>> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetails<TMetrics>> progressHandler = null)
{
    // Caller supplied both sets: nothing to split.
    if (validationData != null)
    {
        return ExecuteTrainValidate(trainData, columnInformation, validationData, preFeaturizer, progressHandler);
    }

    // No validation set: derive train and validation portions from the training data.
    var split = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName);
    return ExecuteTrainValidate(split.trainData, columnInformation, split.validationData, preFeaturizer, progressHandler);
}
85
85
86
/// <summary>
/// Runs a cross-validation experiment: validates the fold count, splits
/// <paramref name="trainData"/> into folds and hands them to <c>ExecuteCrossVal</c>.
/// </summary>
/// <param name="trainData">Data to split into cross-validation folds.</param>
/// <param name="numberOfCVFolds">Fold count; checked by <c>UserInputValidationUtil.ValidateNumberOfCVFoldsArg</c>.</param>
/// <param name="columnInformation">Column information; may be null (its sampling key is read null-conditionally).</param>
/// <param name="preFeaturizer">Optional estimator, forwarded unchanged.</param>
/// <param name="progressHandler">Optional progress receiver, forwarded unchanged.</param>
/// <returns>The result of <c>ExecuteCrossVal</c>.</returns>
public IEnumerable<CrossValidationRunDetails<TMetrics>> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator<ITransformer> preFeaturizer = null, IProgress<CrossValidationRunDetails<TMetrics>> progressHandler = null)
{
    UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds);

    var samplingKey = columnInformation?.SamplingKeyColumnName;
    var folds = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, samplingKey);

    return ExecuteCrossVal(
        folds.trainDatasets,
        columnInformation,
        folds.validationDatasets,
        preFeaturizer,
        progressHandler);
}
92
92
93
93
/// <summary>
/// Convenience cross-validation overload: packs the label and sampling-key column
/// names into a new <see cref="ColumnInformation"/> and forwards to the overload
/// that takes one.
/// </summary>
/// <param name="trainData">Training data, forwarded unchanged.</param>
/// <param name="numberOfCVFolds">Fold count, forwarded unchanged.</param>
/// <param name="labelColumn">Value assigned to <c>LabelColumnName</c>.</param>
/// <param name="samplingKeyColumn">Value assigned to <c>SamplingKeyColumnName</c>; may be null.</param>
/// <param name="preFeaturizer">Optional estimator, forwarded unchanged.</param>
/// <param name="progressHandler">Optional progress receiver, forwarded unchanged.</param>
/// <returns>The result of the <see cref="ColumnInformation"/>-based overload.</returns>
// NOTE(review): progressHandler is the concrete Progress<T> here, while the sibling
// overloads accept IProgress<T> - confirm whether that difference is intentional.
public IEnumerable<CrossValidationRunDetails<TMetrics>> Execute(
    IDataView trainData,
    uint numberOfCVFolds,
    string labelColumn = DefaultColumnNames.Label,
    string samplingKeyColumn = null,
    IEstimator<ITransformer> preFeaturizer = null,
    Progress<CrossValidationRunDetails<TMetrics>> progressHandler = null)
{
    var columns = new ColumnInformation();
    columns.LabelColumnName = labelColumn;
    columns.SamplingKeyColumnName = samplingKeyColumn;

    return Execute(trainData, numberOfCVFolds, columns, preFeaturizer, progressHandler);
}
105
105
106
106
private IEnumerable < RunDetails < TMetrics > > ExecuteTrainValidate ( IDataView trainData ,
@@ -111,8 +111,18 @@ private IEnumerable<RunDetails<TMetrics>> ExecuteTrainValidate(IDataView trainDa
111
111
{
112
112
columnInfo = columnInfo ?? new ColumnInformation ( ) ;
113
113
UserInputValidationUtil . ValidateExperimentExecuteArgs ( trainData , columnInfo , validationData ) ;
114
- var runner = new TrainValidateRunner < TMetrics > ( Context , trainData , validationData , columnInfo . LabelColumn , _metricsAgent ,
115
- preFeaturizer , _settings . DebugLogger ) ;
114
+
115
+ // Apply pre-featurizer
116
+ ITransformer preprocessorTransform = null ;
117
+ if ( preFeaturizer != null )
118
+ {
119
+ preprocessorTransform = preFeaturizer . Fit ( trainData ) ;
120
+ trainData = preprocessorTransform . Transform ( trainData ) ;
121
+ validationData = preprocessorTransform . Transform ( validationData ) ;
122
+ }
123
+
124
+ var runner = new TrainValidateRunner < TMetrics > ( Context , trainData , validationData , columnInfo . LabelColumnName , _metricsAgent ,
125
+ preFeaturizer , preprocessorTransform , _settings . DebugLogger ) ;
116
126
var columns = DatasetColumnInfoUtil . GetDatasetColumnInfo ( Context , trainData , columnInfo ) ;
117
127
return Execute ( columnInfo , columns , preFeaturizer , progressHandler , runner ) ;
118
128
}
@@ -125,8 +135,13 @@ private IEnumerable<CrossValidationRunDetails<TMetrics>> ExecuteCrossVal(IDataVi
125
135
{
126
136
columnInfo = columnInfo ?? new ColumnInformation ( ) ;
127
137
UserInputValidationUtil . ValidateExperimentExecuteArgs ( trainDatasets [ 0 ] , columnInfo , validationDatasets [ 0 ] ) ;
128
- var runner = new CrossValRunner < TMetrics > ( Context , trainDatasets , validationDatasets , _metricsAgent , preFeaturizer ,
129
- columnInfo . LabelColumn , _settings . DebugLogger ) ;
138
+
139
+ // Apply pre-featurizer
140
+ ITransformer [ ] preprocessorTransforms = null ;
141
+ ( trainDatasets , validationDatasets , preprocessorTransforms ) = ApplyPreFeaturizerCrossVal ( trainDatasets , validationDatasets , preFeaturizer ) ;
142
+
143
+ var runner = new CrossValRunner < TMetrics > ( Context , trainDatasets , validationDatasets , _metricsAgent , preFeaturizer ,
144
+ preprocessorTransforms , columnInfo . LabelColumnName , _settings . DebugLogger ) ;
130
145
var columns = DatasetColumnInfoUtil . GetDatasetColumnInfo ( Context , trainDatasets [ 0 ] , columnInfo ) ;
131
146
return Execute ( columnInfo , columns , preFeaturizer , progressHandler , runner ) ;
132
147
}
@@ -139,8 +154,13 @@ private IEnumerable<RunDetails<TMetrics>> ExecuteCrossValSummary(IDataView[] tra
139
154
{
140
155
columnInfo = columnInfo ?? new ColumnInformation ( ) ;
141
156
UserInputValidationUtil . ValidateExperimentExecuteArgs ( trainDatasets [ 0 ] , columnInfo , validationDatasets [ 0 ] ) ;
157
+
158
+ // Apply pre-featurizer
159
+ ITransformer [ ] preprocessorTransforms = null ;
160
+ ( trainDatasets , validationDatasets , preprocessorTransforms ) = ApplyPreFeaturizerCrossVal ( trainDatasets , validationDatasets , preFeaturizer ) ;
161
+
142
162
var runner = new CrossValSummaryRunner < TMetrics > ( Context , trainDatasets , validationDatasets , _metricsAgent , preFeaturizer ,
143
- columnInfo . LabelColumn , _optimizingMetricInfo , _settings . DebugLogger ) ;
163
+ preprocessorTransforms , columnInfo . LabelColumnName , _optimizingMetricInfo , _settings . DebugLogger ) ;
144
164
var columns = DatasetColumnInfoUtil . GetDatasetColumnInfo ( Context , trainDatasets [ 0 ] , columnInfo ) ;
145
165
return Execute ( columnInfo , columns , preFeaturizer , progressHandler , runner ) ;
146
166
}
@@ -158,5 +178,25 @@ private IEnumerable<TRunDetails> Execute<TRunDetails>(ColumnInformation columnIn
158
178
159
179
return experiment . Execute ( ) ;
160
180
}
181
+
182
/// <summary>
/// Fits <paramref name="preFeaturizer"/> on each training fold and applies the
/// resulting transformer to that fold's training and validation data.
/// </summary>
/// <param name="trainDatasets">Training folds; fold <c>i</c> is used to fit transformer <c>i</c>.</param>
/// <param name="validDatasets">Validation folds, indexed in step with <paramref name="trainDatasets"/>.</param>
/// <param name="preFeaturizer">Estimator to fit per fold, or null to skip preprocessing.</param>
/// <returns>
/// The transformed training folds, the transformed validation folds and the fitted
/// transformer for each fold. When <paramref name="preFeaturizer"/> is null the input
/// arrays are returned as-is and the transformer array is null.
/// </returns>
private static (IDataView[] trainDatasets, IDataView[] validDatasets, ITransformer[] preprocessorTransforms)
    ApplyPreFeaturizerCrossVal(IDataView[] trainDatasets, IDataView[] validDatasets, IEstimator<ITransformer> preFeaturizer)
{
    if (preFeaturizer == null)
    {
        return (trainDatasets, validDatasets, null);
    }

    // Work on shallow copies so the caller's arrays are never overwritten;
    // results are communicated only through the returned tuple.
    var transformedTrain = (IDataView[])trainDatasets.Clone();
    var transformedValid = (IDataView[])validDatasets.Clone();
    var preprocessorTransforms = new ITransformer[trainDatasets.Length];

    for (var i = 0; i < trainDatasets.Length; i++)
    {
        // Fit on the fold's training data only, then apply the same fitted
        // transformer to both the training and the validation data of that fold.
        var foldTransform = preFeaturizer.Fit(trainDatasets[i]);
        preprocessorTransforms[i] = foldTransform;
        transformedTrain[i] = foldTransform.Transform(trainDatasets[i]);
        transformedValid[i] = foldTransform.Transform(validDatasets[i]);
    }

    return (transformedTrain, transformedValid, preprocessorTransforms);
}
161
201
}
162
202
}
0 commit comments