@@ -12,32 +12,28 @@ namespace Microsoft.ML.Auto
12
12
{
13
13
public static class RegressionExtensions
14
14
{
15
- public static RegressionResult AutoFit ( this RegressionContext context ,
15
+ public static IEnumerable < IterationResult < RegressionMetrics > > AutoFit ( this RegressionContext context ,
16
16
IDataView trainData ,
17
17
string label = DefaultColumnNames . Label ,
18
18
IDataView validationData = null ,
19
19
uint timeoutInMinutes = AutoFitDefaults . TimeOutInMinutes ,
20
20
IEstimator < ITransformer > preFeaturizers = null ,
21
- IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
22
- CancellationToken cancellationToken = default ,
23
- IProgress < RegressionIterationResult > iterationCallback = null )
21
+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null )
24
22
{
25
23
var settings = new AutoFitSettings ( ) ;
26
24
settings . StoppingCriteria . TimeOutInMinutes = timeoutInMinutes ;
27
25
28
26
return AutoFit ( context , trainData , label , validationData , settings ,
29
- preFeaturizers , columnPurposes , cancellationToken , iterationCallback , null ) ;
27
+ preFeaturizers , columnPurposes , null ) ;
30
28
}
31
29
32
- internal static RegressionResult AutoFit ( this RegressionContext context ,
30
+ internal static IEnumerable < IterationResult < RegressionMetrics > > AutoFit ( this RegressionContext context ,
33
31
IDataView trainData ,
34
32
string label = DefaultColumnNames . Label ,
35
33
IDataView validationData = null ,
36
34
AutoFitSettings settings = null ,
37
35
IEstimator < ITransformer > preFeaturizers = null ,
38
36
IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
39
- CancellationToken cancellationToken = default ,
40
- IProgress < RegressionIterationResult > iterationCallback = null ,
41
37
IDebugLogger debugLogger = null )
42
38
{
43
39
UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , columnPurposes ) ;
@@ -48,49 +44,38 @@ internal static RegressionResult AutoFit(this RegressionContext context,
48
44
}
49
45
50
46
// run autofit & get all pipelines run in that process
51
- var ( allPipelines , bestPipeline ) = AutoFitApi . Fit ( trainData , validationData , label ,
52
- settings , preFeaturizers , TaskKind . Regression , OptimizingMetric . RSquared , columnPurposes , debugLogger ) ;
47
+ var autoFitter = new AutoFitter < RegressionMetrics > ( TaskKind . Regression , trainData , label , validationData ,
48
+ settings , preFeaturizers , columnPurposes ,
49
+ OptimizingMetric . RSquared , debugLogger ) ;
53
50
54
- var results = new RegressionIterationResult [ allPipelines . Length ] ;
55
- for ( var i = 0 ; i < results . Length ; i ++ )
56
- {
57
- var iterationResult = allPipelines [ i ] ;
58
- var result = new RegressionIterationResult ( iterationResult . Model , ( RegressionMetrics ) iterationResult . EvaluatedMetrics , iterationResult . ScoredValidationData , iterationResult . Pipeline . ToPipeline ( ) ) ;
59
- results [ i ] = result ;
60
- }
61
- var bestResult = new RegressionIterationResult ( bestPipeline . Model , ( RegressionMetrics ) bestPipeline . EvaluatedMetrics , bestPipeline . ScoredValidationData , bestPipeline . Pipeline . ToPipeline ( ) ) ;
62
- return new RegressionResult ( bestResult , results ) ;
51
+ return autoFitter . Fit ( ) ;
63
52
}
64
53
}
65
54
66
55
public static class BinaryClassificationExtensions
67
56
{
68
- public static BinaryClassificationResult AutoFit ( this BinaryClassificationContext context ,
57
+ public static IEnumerable < IterationResult < BinaryClassificationMetrics > > AutoFit ( this BinaryClassificationContext context ,
69
58
IDataView trainData ,
70
59
string label = DefaultColumnNames . Label ,
71
60
IDataView validationData = null ,
72
61
uint timeoutInMinutes = AutoFitDefaults . TimeOutInMinutes ,
73
62
IEstimator < ITransformer > preFeaturizers = null ,
74
- IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
75
- CancellationToken cancellationToken = default ,
76
- IProgress < BinaryClassificationItertionResult > iterationCallback = null )
63
+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null )
77
64
{
78
65
var settings = new AutoFitSettings ( ) ;
79
66
settings . StoppingCriteria . TimeOutInMinutes = timeoutInMinutes ;
80
67
81
68
return AutoFit ( context , trainData , label , validationData , settings ,
82
- preFeaturizers , columnPurposes , cancellationToken , iterationCallback , null ) ;
69
+ preFeaturizers , columnPurposes , null ) ;
83
70
}
84
71
85
- internal static BinaryClassificationResult AutoFit ( this BinaryClassificationContext context ,
72
+ internal static IEnumerable < IterationResult < BinaryClassificationMetrics > > AutoFit ( this BinaryClassificationContext context ,
86
73
IDataView trainData ,
87
74
string label = DefaultColumnNames . Label ,
88
75
IDataView validationData = null ,
89
76
AutoFitSettings settings = null ,
90
77
IEstimator < ITransformer > preFeaturizers = null ,
91
78
IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
92
- CancellationToken cancellationToken = default ,
93
- IProgress < BinaryClassificationItertionResult > iterationCallback = null ,
94
79
IDebugLogger debugLogger = null )
95
80
{
96
81
UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , columnPurposes ) ;
@@ -101,159 +86,67 @@ internal static BinaryClassificationResult AutoFit(this BinaryClassificationCont
101
86
}
102
87
103
88
// run autofit & get all pipelines run in that process
104
- var ( allPipelines , bestPipeline ) = AutoFitApi . Fit ( trainData , validationData , label ,
105
- settings , preFeaturizers , TaskKind . BinaryClassification , OptimizingMetric . Accuracy ,
106
- columnPurposes , debugLogger ) ;
107
-
108
- var results = new BinaryClassificationItertionResult [ allPipelines . Length ] ;
109
- for ( var i = 0 ; i < results . Length ; i ++ )
110
- {
111
- var iterationResult = allPipelines [ i ] ;
112
- var result = new BinaryClassificationItertionResult ( iterationResult . Model , ( BinaryClassificationMetrics ) iterationResult . EvaluatedMetrics , iterationResult . ScoredValidationData , iterationResult . Pipeline . ToPipeline ( ) ) ;
113
- results [ i ] = result ;
114
- }
115
- var bestResult = new BinaryClassificationItertionResult ( bestPipeline . Model , ( BinaryClassificationMetrics ) bestPipeline . EvaluatedMetrics , bestPipeline . ScoredValidationData , bestPipeline . Pipeline . ToPipeline ( ) ) ;
116
- return new BinaryClassificationResult ( bestResult , results ) ;
89
+ var autoFitter = new AutoFitter < BinaryClassificationMetrics > ( TaskKind . BinaryClassification , trainData , label , validationData ,
90
+ settings , preFeaturizers , columnPurposes ,
91
+ OptimizingMetric . RSquared , debugLogger ) ;
92
+
93
+ return autoFitter . Fit ( ) ;
117
94
}
118
95
}
119
96
120
97
public static class MulticlassExtensions
121
98
{
122
- public static MulticlassClassificationResult AutoFit ( this MulticlassClassificationContext context ,
99
+ public static IEnumerable < IterationResult < MultiClassClassifierMetrics > > AutoFit ( this MulticlassClassificationContext context ,
123
100
IDataView trainData ,
124
101
string label = DefaultColumnNames . Label ,
125
102
IDataView validationData = null ,
126
103
uint timeoutInMinutes = AutoFitDefaults . TimeOutInMinutes ,
127
104
IEstimator < ITransformer > preFeaturizers = null ,
128
- IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
129
- CancellationToken cancellationToken = default ,
130
- IProgress < MulticlassClassificationIterationResult > iterationCallback = null )
105
+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null )
131
106
{
132
107
var settings = new AutoFitSettings ( ) ;
133
108
settings . StoppingCriteria . TimeOutInMinutes = timeoutInMinutes ;
134
109
135
110
return AutoFit ( context , trainData , label , validationData , settings ,
136
- preFeaturizers , columnPurposes , cancellationToken , iterationCallback , null ) ;
111
+ preFeaturizers , columnPurposes , null ) ;
137
112
}
138
113
139
- internal static MulticlassClassificationResult AutoFit ( this MulticlassClassificationContext context ,
114
+ internal static IEnumerable < IterationResult < MultiClassClassifierMetrics > > AutoFit ( this MulticlassClassificationContext context ,
140
115
IDataView trainData ,
141
116
string label = DefaultColumnNames . Label ,
142
117
IDataView validationData = null ,
143
118
AutoFitSettings settings = null ,
144
119
IEstimator < ITransformer > preFeaturizers = null ,
145
120
IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
146
- CancellationToken cancellationToken = default ,
147
- IProgress < MulticlassClassificationIterationResult > iterationCallback = null , IDebugLogger debugLogger = null )
121
+ IDebugLogger debugLogger = null )
148
122
{
149
123
UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , columnPurposes ) ;
150
124
151
125
if ( validationData == null )
152
126
{
153
127
( trainData , validationData ) = context . TestValidateSplit ( trainData ) ;
154
128
}
155
-
129
+
156
130
// run autofit & get all pipelines run in that process
157
- var ( allPipelines , bestPipeline ) = AutoFitApi . Fit ( trainData , validationData , label ,
158
- settings , preFeaturizers , TaskKind . MulticlassClassification , OptimizingMetric . Accuracy ,
159
- columnPurposes , debugLogger ) ;
160
-
161
- var results = new MulticlassClassificationIterationResult [ allPipelines . Length ] ;
162
- for ( var i = 0 ; i < results . Length ; i ++ )
163
- {
164
- var iterationResult = allPipelines [ i ] ;
165
- var result = new MulticlassClassificationIterationResult ( iterationResult . Model , ( MultiClassClassifierMetrics ) iterationResult . EvaluatedMetrics , iterationResult . ScoredValidationData , iterationResult . Pipeline . ToPipeline ( ) ) ;
166
- results [ i ] = result ;
167
- }
168
- var bestResult = new MulticlassClassificationIterationResult ( bestPipeline . Model , ( MultiClassClassifierMetrics ) bestPipeline . EvaluatedMetrics , bestPipeline . ScoredValidationData , bestPipeline . Pipeline . ToPipeline ( ) ) ;
169
- return new MulticlassClassificationResult ( bestResult , results ) ;
170
- }
171
- }
172
-
173
- public class BinaryClassificationResult
174
- {
175
- public readonly BinaryClassificationItertionResult BestIteration ;
176
- public readonly BinaryClassificationItertionResult [ ] IterationResults ;
177
-
178
- public BinaryClassificationResult ( BinaryClassificationItertionResult bestPipeline ,
179
- BinaryClassificationItertionResult [ ] iterationResults )
180
- {
181
- BestIteration = bestPipeline ;
182
- IterationResults = iterationResults ;
183
- }
184
- }
185
-
186
- public class MulticlassClassificationResult
187
- {
188
- public readonly MulticlassClassificationIterationResult BestIteration ;
189
- public readonly MulticlassClassificationIterationResult [ ] IterationResults ;
190
-
191
- public MulticlassClassificationResult ( MulticlassClassificationIterationResult bestPipeline ,
192
- MulticlassClassificationIterationResult [ ] iterationResults )
193
- {
194
- BestIteration = bestPipeline ;
195
- IterationResults = iterationResults ;
196
- }
197
- }
198
-
199
- public class RegressionResult
200
- {
201
- public readonly RegressionIterationResult BestIteration ;
202
- public readonly RegressionIterationResult [ ] IterationResults ;
203
-
204
- public RegressionResult ( RegressionIterationResult bestPipeline ,
205
- RegressionIterationResult [ ] iterationResults )
206
- {
207
- BestIteration = bestPipeline ;
208
- IterationResults = iterationResults ;
209
- }
210
- }
211
-
212
- public class BinaryClassificationItertionResult
213
- {
214
- public readonly BinaryClassificationMetrics Metrics ;
215
- public readonly ITransformer Model ;
216
- public readonly IDataView ScoredValidationData ;
217
- internal readonly Pipeline Pipeline ;
218
-
219
- internal BinaryClassificationItertionResult ( ITransformer model , BinaryClassificationMetrics metrics , IDataView scoredValidationData , Pipeline pipeline )
220
- {
221
- Model = model ;
222
- ScoredValidationData = scoredValidationData ;
223
- Metrics = metrics ;
224
- Pipeline = pipeline ;
225
- }
226
- }
227
-
228
- public class MulticlassClassificationIterationResult
229
- {
230
- public readonly MultiClassClassifierMetrics Metrics ;
231
- public readonly ITransformer Model ;
232
- public readonly IDataView ScoredValidationData ;
233
- internal readonly Pipeline Pipeline ;
234
-
235
- internal MulticlassClassificationIterationResult ( ITransformer model , MultiClassClassifierMetrics metrics , IDataView scoredValidationData , Pipeline pipeline )
236
- {
237
- Model = model ;
238
- Metrics = metrics ;
239
- ScoredValidationData = scoredValidationData ;
240
- Pipeline = pipeline ;
131
+ var autoFitter = new AutoFitter < MultiClassClassifierMetrics > ( TaskKind . MulticlassClassification , trainData , label , validationData ,
132
+ settings , preFeaturizers , columnPurposes , OptimizingMetric . RSquared , debugLogger ) ;
133
+ return autoFitter . Fit ( ) ;
241
134
}
242
135
}
243
136
244
- public class RegressionIterationResult
137
+ public class IterationResult < T >
245
138
{
246
- public readonly RegressionMetrics Metrics ;
139
+ public readonly T Metrics ;
247
140
public readonly ITransformer Model ;
248
- public readonly IDataView ScoredValidationData ;
141
+ public readonly Exception Exception ;
249
142
internal readonly Pipeline Pipeline ;
250
143
251
- internal RegressionIterationResult ( ITransformer model , RegressionMetrics metrics , IDataView scoredValidationData , Pipeline pipeline )
144
+ internal IterationResult ( ITransformer model , T metrics , Pipeline pipeline , Exception exception )
252
145
{
253
146
Model = model ;
254
147
Metrics = metrics ;
255
- ScoredValidationData = scoredValidationData ;
256
148
Pipeline = pipeline ;
149
+ Exception = exception ;
257
150
}
258
151
}
259
152
}
0 commit comments