@@ -14,32 +14,42 @@ public static class RegressionExtensions
14
14
{
15
15
// Public AutoFit entry point for regression, shown as a diff hunk: lines starting with "-"
// are removed, lines starting with "+" are their replacements, unmarked lines are unchanged.
// NOTE(review): the bare numbers between code lines appear to be old/new line numbers carried
// over from a diff view - confirm before treating this text as compilable source.
//
// Signature change visible in this hunk:
//   - label gains the default DefaultColumnNames.Label and validationData gains the default null;
//   - the AutoFitSettings parameter is replaced by timeoutInMinutes
//     (default AutoFitDefaults.TimeOutInMinutes);
//   - a preFeaturizers parameter is added;
//   - purposeOverrides is renamed to columnPurposes.
// Returns the result of the AutoFit overload invoked at the end of the body.
public static RegressionResult AutoFit ( this RegressionContext context ,
16
16
IDataView trainData ,
17
- string label ,
18
- IDataView validationData ,
19
- AutoFitSettings settings = null ,
20
- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
17
+ string label = DefaultColumnNames . Label ,
18
+ IDataView validationData = null ,
19
+ uint timeoutInMinutes = AutoFitDefaults . TimeOutInMinutes ,
20
+ IEstimator < ITransformer > preFeaturizers = null ,
21
+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
21
22
CancellationToken cancellationToken = default ,
22
23
IProgress < RegressionIterationResult > iterationCallback = null )
23
24
{
// Added lines: settings is no longer supplied by the caller, so a fresh AutoFitSettings is
// created here and only its stopping-criteria timeout is set from timeoutInMinutes.
25
+ var settings = new AutoFitSettings ( ) ;
26
+ settings . StoppingCriteria . TimeOutInMinutes = timeoutInMinutes ;
27
+
// Forward to the longer AutoFit overload; the trailing null lands in the position that the
// internal overload in this file declares as debugLogger.
24
28
return AutoFit ( context , trainData , label , validationData , settings ,
25
- purposeOverrides , cancellationToken , iterationCallback , null ) ;
29
+ preFeaturizers , columnPurposes , cancellationToken , iterationCallback , null ) ;
26
30
}
27
31
28
32
internal static RegressionResult AutoFit ( this RegressionContext context ,
29
33
IDataView trainData ,
30
- string label ,
31
- IDataView validationData ,
34
+ string label = DefaultColumnNames . Label ,
35
+ IDataView validationData = null ,
32
36
AutoFitSettings settings = null ,
33
- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
37
+ IEstimator < ITransformer > preFeaturizers = null ,
38
+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
34
39
CancellationToken cancellationToken = default ,
35
40
IProgress < RegressionIterationResult > iterationCallback = null ,
36
41
IDebugLogger debugLogger = null )
37
42
{
38
- UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , purposeOverrides ) ;
43
+ UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , columnPurposes ) ;
44
+
45
+ if ( validationData == null )
46
+ {
47
+ ( trainData , validationData ) = context . TestValidateSplit ( trainData ) ;
48
+ }
39
49
40
50
// run autofit & get all pipelines run in that process
41
51
var ( allPipelines , bestPipeline ) = AutoFitApi . Fit ( trainData , validationData , label ,
42
- settings , TaskKind . Regression , OptimizingMetric . RSquared , purposeOverrides , debugLogger ) ;
52
+ settings , preFeaturizers , TaskKind . Regression , OptimizingMetric . RSquared , columnPurposes , debugLogger ) ;
43
53
44
54
var results = new RegressionIterationResult [ allPipelines . Length ] ;
45
55
for ( var i = 0 ; i < results . Length ; i ++ )
@@ -57,33 +67,43 @@ public static class BinaryClassificationExtensions
57
67
{
58
68
// Public AutoFit entry point for binary classification, shown as a diff hunk: lines starting
// with "-" are removed, lines starting with "+" are their replacements, unmarked lines are
// unchanged.
// NOTE(review): the bare numbers between code lines appear to be old/new line numbers carried
// over from a diff view - confirm before treating this text as compilable source.
//
// Signature change visible in this hunk:
//   - label gains the default DefaultColumnNames.Label and validationData gains the default null;
//   - the AutoFitSettings parameter is replaced by timeoutInMinutes
//     (default AutoFitDefaults.TimeOutInMinutes);
//   - a preFeaturizers parameter is added;
//   - purposeOverrides is renamed to columnPurposes.
// Returns the result of the AutoFit overload invoked at the end of the body.
public static BinaryClassificationResult AutoFit ( this BinaryClassificationContext context ,
59
69
IDataView trainData ,
60
- string label ,
61
- IDataView validationData ,
62
- AutoFitSettings settings = null ,
63
- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
70
+ string label = DefaultColumnNames . Label ,
71
+ IDataView validationData = null ,
72
+ uint timeoutInMinutes = AutoFitDefaults . TimeOutInMinutes ,
73
+ IEstimator < ITransformer > preFeaturizers = null ,
74
+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
64
75
CancellationToken cancellationToken = default ,
65
76
IProgress < BinaryClassificationItertionResult > iterationCallback = null )
66
77
{
// Added lines: settings is no longer supplied by the caller, so a fresh AutoFitSettings is
// created here and only its stopping-criteria timeout is set from timeoutInMinutes.
78
+ var settings = new AutoFitSettings ( ) ;
79
+ settings . StoppingCriteria . TimeOutInMinutes = timeoutInMinutes ;
80
+
// Forward to the longer AutoFit overload; the trailing null lands in the position that the
// internal overload in this file declares as debugLogger.
67
81
return AutoFit ( context , trainData , label , validationData , settings ,
68
- purposeOverrides , cancellationToken , iterationCallback , null ) ;
82
+ preFeaturizers , columnPurposes , cancellationToken , iterationCallback , null ) ;
69
83
}
70
84
71
85
internal static BinaryClassificationResult AutoFit ( this BinaryClassificationContext context ,
72
86
IDataView trainData ,
73
- string label ,
74
- IDataView validationData ,
87
+ string label = DefaultColumnNames . Label ,
88
+ IDataView validationData = null ,
75
89
AutoFitSettings settings = null ,
76
- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
90
+ IEstimator < ITransformer > preFeaturizers = null ,
91
+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
77
92
CancellationToken cancellationToken = default ,
78
93
IProgress < BinaryClassificationItertionResult > iterationCallback = null ,
79
94
IDebugLogger debugLogger = null )
80
95
{
81
- UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , purposeOverrides ) ;
96
+ UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , columnPurposes ) ;
97
+
98
+ if ( validationData == null )
99
+ {
100
+ ( trainData , validationData ) = context . TestValidateSplit ( trainData ) ;
101
+ }
82
102
83
103
// run autofit & get all pipelines run in that process
84
104
var ( allPipelines , bestPipeline ) = AutoFitApi . Fit ( trainData , validationData , label ,
85
- settings , TaskKind . BinaryClassification , OptimizingMetric . Accuracy ,
86
- purposeOverrides , debugLogger ) ;
105
+ settings , preFeaturizers , TaskKind . BinaryClassification , OptimizingMetric . Accuracy ,
106
+ columnPurposes , debugLogger ) ;
87
107
88
108
var results = new BinaryClassificationItertionResult [ allPipelines . Length ] ;
89
109
for ( var i = 0 ; i < results . Length ; i ++ )
@@ -101,32 +121,42 @@ public static class MulticlassExtensions
101
121
{
102
122
// Public AutoFit entry point for multiclass classification, shown as a diff hunk: lines
// starting with "-" are removed, lines starting with "+" are their replacements, unmarked
// lines are unchanged.
// NOTE(review): the bare numbers between code lines appear to be old/new line numbers carried
// over from a diff view - confirm before treating this text as compilable source.
//
// Signature change visible in this hunk:
//   - label gains the default DefaultColumnNames.Label and validationData gains the default null;
//   - the AutoFitSettings parameter is replaced by timeoutInMinutes
//     (default AutoFitDefaults.TimeOutInMinutes);
//   - a preFeaturizers parameter is added;
//   - purposeOverrides is renamed to columnPurposes.
// Returns the result of the AutoFit overload invoked at the end of the body.
public static MulticlassClassificationResult AutoFit ( this MulticlassClassificationContext context ,
103
123
IDataView trainData ,
104
- string label ,
105
- IDataView validationData ,
106
- AutoFitSettings settings = null ,
107
- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
124
+ string label = DefaultColumnNames . Label ,
125
+ IDataView validationData = null ,
126
+ uint timeoutInMinutes = AutoFitDefaults . TimeOutInMinutes ,
127
+ IEstimator < ITransformer > preFeaturizers = null ,
128
+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
108
129
CancellationToken cancellationToken = default ,
109
130
IProgress < MulticlassClassificationIterationResult > iterationCallback = null )
110
131
{
// Added lines: settings is no longer supplied by the caller, so a fresh AutoFitSettings is
// created here and only its stopping-criteria timeout is set from timeoutInMinutes.
132
+ var settings = new AutoFitSettings ( ) ;
133
+ settings . StoppingCriteria . TimeOutInMinutes = timeoutInMinutes ;
134
+
// Forward to the longer AutoFit overload; the trailing null lands in the position that the
// internal overload in this file declares as debugLogger.
111
135
return AutoFit ( context , trainData , label , validationData , settings ,
112
- purposeOverrides , cancellationToken , iterationCallback , null ) ;
136
+ preFeaturizers , columnPurposes , cancellationToken , iterationCallback , null ) ;
113
137
}
114
138
115
139
internal static MulticlassClassificationResult AutoFit ( this MulticlassClassificationContext context ,
116
140
IDataView trainData ,
117
- string label ,
118
- IDataView validationData ,
141
+ string label = DefaultColumnNames . Label ,
142
+ IDataView validationData = null ,
119
143
AutoFitSettings settings = null ,
120
- IEnumerable < ( string , ColumnPurpose ) > purposeOverrides = null ,
144
+ IEstimator < ITransformer > preFeaturizers = null ,
145
+ IEnumerable < ( string , ColumnPurpose ) > columnPurposes = null ,
121
146
CancellationToken cancellationToken = default ,
122
147
IProgress < MulticlassClassificationIterationResult > iterationCallback = null , IDebugLogger debugLogger = null )
123
148
{
124
- UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , purposeOverrides ) ;
149
+ UserInputValidationUtil . ValidateAutoFitArgs ( trainData , label , validationData , settings , columnPurposes ) ;
150
+
151
+ if ( validationData == null )
152
+ {
153
+ ( trainData , validationData ) = context . TestValidateSplit ( trainData ) ;
154
+ }
125
155
126
156
// run autofit & get all pipelines run in that process
127
157
var ( allPipelines , bestPipeline ) = AutoFitApi . Fit ( trainData , validationData , label ,
128
- settings , TaskKind . MulticlassClassification , OptimizingMetric . Accuracy ,
129
- purposeOverrides , debugLogger ) ;
158
+ settings , preFeaturizers , TaskKind . MulticlassClassification , OptimizingMetric . Accuracy ,
159
+ columnPurposes , debugLogger ) ;
130
160
131
161
var results = new MulticlassClassificationIterationResult [ allPipelines . Length ] ;
132
162
for ( var i = 0 ; i < results . Length ; i ++ )
@@ -142,39 +172,39 @@ internal static MulticlassClassificationResult AutoFit(this MulticlassClassifica
142
172
143
173
// Return type of the BinaryClassificationContext AutoFit overloads in this file. Holds two
// readonly fields that are assigned once, in the constructor.
// Diff hunk: "-" lines are removed and "+" lines are added; the only change is renaming the
// public field BestPipeline to BestIteration.
// NOTE(review): the constructor parameter is still called bestPipeline after the rename;
// renaming it would affect callers using named arguments, so it is left as-is here.
public class BinaryClassificationResult
144
174
{
145
- public readonly BinaryClassificationItertionResult BestPipeline ;
175
+ public readonly BinaryClassificationItertionResult BestIteration ;
146
176
public readonly BinaryClassificationItertionResult [ ] IterationResults ;
147
177
// Stores both arguments as given; no copying or validation is performed here.
148
178
public BinaryClassificationResult ( BinaryClassificationItertionResult bestPipeline ,
149
179
BinaryClassificationItertionResult [ ] iterationResults )
150
180
{
151
- BestPipeline = bestPipeline ;
181
+ BestIteration = bestPipeline ;
152
182
IterationResults = iterationResults ;
153
183
}
154
184
}
155
185
156
186
// Return type of the MulticlassClassificationContext AutoFit overloads in this file. Holds two
// readonly fields that are assigned once, in the constructor.
// Diff hunk: "-" lines are removed and "+" lines are added; the only change is renaming the
// public field BestPipeline to BestIteration.
// NOTE(review): the constructor parameter is still called bestPipeline after the rename;
// renaming it would affect callers using named arguments, so it is left as-is here.
public class MulticlassClassificationResult
157
187
{
158
- public readonly MulticlassClassificationIterationResult BestPipeline ;
188
+ public readonly MulticlassClassificationIterationResult BestIteration ;
159
189
public readonly MulticlassClassificationIterationResult [ ] IterationResults ;
160
190
// Stores both arguments as given; no copying or validation is performed here.
161
191
public MulticlassClassificationResult ( MulticlassClassificationIterationResult bestPipeline ,
162
192
MulticlassClassificationIterationResult [ ] iterationResults )
163
193
{
164
- BestPipeline = bestPipeline ;
194
+ BestIteration = bestPipeline ;
165
195
IterationResults = iterationResults ;
166
196
}
167
197
}
168
198
169
199
// Return type of the RegressionContext AutoFit overloads in this file. Holds two readonly
// fields that are assigned once, in the constructor.
// Diff hunk: "-" lines are removed and "+" lines are added; the only change is renaming the
// public field BestPipeline to BestIteration.
// NOTE(review): the constructor parameter is still called bestPipeline after the rename;
// renaming it would affect callers using named arguments, so it is left as-is here.
public class RegressionResult
170
200
{
171
- public readonly RegressionIterationResult BestPipeline ;
201
+ public readonly RegressionIterationResult BestIteration ;
172
202
public readonly RegressionIterationResult [ ] IterationResults ;
173
203
// Stores both arguments as given; no copying or validation is performed here.
174
204
public RegressionResult ( RegressionIterationResult bestPipeline ,
175
205
RegressionIterationResult [ ] iterationResults )
176
206
{
177
- BestPipeline = bestPipeline ;
207
+ BestIteration = bestPipeline ;
178
208
IterationResults = iterationResults ;
179
209
}
180
210
}
0 commit comments