Skip to content

Commit bbbd341

Browse files
AutoFit return type is now an IEnumerable (dotnet#55)
AutoFit's return type is now an IEnumerable. This enables many good things: implementing a variety of early stopping criteria (see sample); early discard of models that are no good, which improves memory usage efficiency (see sample); and no need to implement a callback to get results back. Getting the best score is now outside of the API implementation; it is a simple math function to compare scores (see sample). Also templatized the return type for better type safety throughout the code.
1 parent e6fa88e commit bbbd341

17 files changed

+365
-357
lines changed

src/AutoML/API/MLContextAutoFitExtensions.cs

+31-138
Original file line number | Diff line number | Diff line change
@@ -12,32 +12,28 @@ namespace Microsoft.ML.Auto
1212
{
1313
public static class RegressionExtensions
1414
{
15-
public static RegressionResult AutoFit(this RegressionContext context,
15+
public static IEnumerable<IterationResult<RegressionMetrics>> AutoFit(this RegressionContext context,
1616
IDataView trainData,
1717
string label = DefaultColumnNames.Label,
1818
IDataView validationData = null,
1919
uint timeoutInMinutes = AutoFitDefaults.TimeOutInMinutes,
2020
IEstimator<ITransformer> preFeaturizers = null,
21-
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
22-
CancellationToken cancellationToken = default,
23-
IProgress<RegressionIterationResult> iterationCallback = null)
21+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null)
2422
{
2523
var settings = new AutoFitSettings();
2624
settings.StoppingCriteria.TimeOutInMinutes = timeoutInMinutes;
2725

2826
return AutoFit(context, trainData, label, validationData, settings,
29-
preFeaturizers, columnPurposes, cancellationToken, iterationCallback, null);
27+
preFeaturizers, columnPurposes, null);
3028
}
3129

32-
internal static RegressionResult AutoFit(this RegressionContext context,
30+
internal static IEnumerable<IterationResult<RegressionMetrics>> AutoFit(this RegressionContext context,
3331
IDataView trainData,
3432
string label = DefaultColumnNames.Label,
3533
IDataView validationData = null,
3634
AutoFitSettings settings = null,
3735
IEstimator<ITransformer> preFeaturizers = null,
3836
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
39-
CancellationToken cancellationToken = default,
40-
IProgress<RegressionIterationResult> iterationCallback = null,
4137
IDebugLogger debugLogger = null)
4238
{
4339
UserInputValidationUtil.ValidateAutoFitArgs(trainData, label, validationData, settings, columnPurposes);
@@ -48,49 +44,38 @@ internal static RegressionResult AutoFit(this RegressionContext context,
4844
}
4945

5046
// run autofit & get all pipelines run in that process
51-
var (allPipelines, bestPipeline) = AutoFitApi.Fit(trainData, validationData, label,
52-
settings, preFeaturizers, TaskKind.Regression, OptimizingMetric.RSquared, columnPurposes, debugLogger);
47+
var autoFitter = new AutoFitter<RegressionMetrics>(TaskKind.Regression, trainData, label, validationData,
48+
settings, preFeaturizers, columnPurposes,
49+
OptimizingMetric.RSquared, debugLogger);
5350

54-
var results = new RegressionIterationResult[allPipelines.Length];
55-
for (var i = 0; i < results.Length; i++)
56-
{
57-
var iterationResult = allPipelines[i];
58-
var result = new RegressionIterationResult(iterationResult.Model, (RegressionMetrics)iterationResult.EvaluatedMetrics, iterationResult.ScoredValidationData, iterationResult.Pipeline.ToPipeline());
59-
results[i] = result;
60-
}
61-
var bestResult = new RegressionIterationResult(bestPipeline.Model, (RegressionMetrics)bestPipeline.EvaluatedMetrics, bestPipeline.ScoredValidationData, bestPipeline.Pipeline.ToPipeline());
62-
return new RegressionResult(bestResult, results);
51+
return autoFitter.Fit();
6352
}
6453
}
6554

6655
public static class BinaryClassificationExtensions
6756
{
68-
public static BinaryClassificationResult AutoFit(this BinaryClassificationContext context,
57+
public static IEnumerable<IterationResult<BinaryClassificationMetrics>> AutoFit(this BinaryClassificationContext context,
6958
IDataView trainData,
7059
string label = DefaultColumnNames.Label,
7160
IDataView validationData = null,
7261
uint timeoutInMinutes = AutoFitDefaults.TimeOutInMinutes,
7362
IEstimator<ITransformer> preFeaturizers = null,
74-
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
75-
CancellationToken cancellationToken = default,
76-
IProgress<BinaryClassificationItertionResult> iterationCallback = null)
63+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null)
7764
{
7865
var settings = new AutoFitSettings();
7966
settings.StoppingCriteria.TimeOutInMinutes = timeoutInMinutes;
8067

8168
return AutoFit(context, trainData, label, validationData, settings,
82-
preFeaturizers, columnPurposes, cancellationToken, iterationCallback, null);
69+
preFeaturizers, columnPurposes, null);
8370
}
8471

85-
internal static BinaryClassificationResult AutoFit(this BinaryClassificationContext context,
72+
internal static IEnumerable<IterationResult<BinaryClassificationMetrics>> AutoFit(this BinaryClassificationContext context,
8673
IDataView trainData,
8774
string label = DefaultColumnNames.Label,
8875
IDataView validationData = null,
8976
AutoFitSettings settings = null,
9077
IEstimator<ITransformer> preFeaturizers = null,
9178
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
92-
CancellationToken cancellationToken = default,
93-
IProgress<BinaryClassificationItertionResult> iterationCallback = null,
9479
IDebugLogger debugLogger = null)
9580
{
9681
UserInputValidationUtil.ValidateAutoFitArgs(trainData, label, validationData, settings, columnPurposes);
@@ -101,159 +86,67 @@ internal static BinaryClassificationResult AutoFit(this BinaryClassificationCont
10186
}
10287

10388
// run autofit & get all pipelines run in that process
104-
var (allPipelines, bestPipeline) = AutoFitApi.Fit(trainData, validationData, label,
105-
settings, preFeaturizers, TaskKind.BinaryClassification, OptimizingMetric.Accuracy,
106-
columnPurposes, debugLogger);
107-
108-
var results = new BinaryClassificationItertionResult[allPipelines.Length];
109-
for (var i = 0; i < results.Length; i++)
110-
{
111-
var iterationResult = allPipelines[i];
112-
var result = new BinaryClassificationItertionResult(iterationResult.Model, (BinaryClassificationMetrics)iterationResult.EvaluatedMetrics, iterationResult.ScoredValidationData, iterationResult.Pipeline.ToPipeline());
113-
results[i] = result;
114-
}
115-
var bestResult = new BinaryClassificationItertionResult(bestPipeline.Model, (BinaryClassificationMetrics)bestPipeline.EvaluatedMetrics, bestPipeline.ScoredValidationData, bestPipeline.Pipeline.ToPipeline());
116-
return new BinaryClassificationResult(bestResult, results);
89+
var autoFitter = new AutoFitter<BinaryClassificationMetrics>(TaskKind.BinaryClassification, trainData, label, validationData,
90+
settings, preFeaturizers, columnPurposes,
91+
OptimizingMetric.RSquared, debugLogger);
92+
93+
return autoFitter.Fit();
11794
}
11895
}
11996

12097
public static class MulticlassExtensions
12198
{
122-
public static MulticlassClassificationResult AutoFit(this MulticlassClassificationContext context,
99+
public static IEnumerable<IterationResult<MultiClassClassifierMetrics>> AutoFit(this MulticlassClassificationContext context,
123100
IDataView trainData,
124101
string label = DefaultColumnNames.Label,
125102
IDataView validationData = null,
126103
uint timeoutInMinutes = AutoFitDefaults.TimeOutInMinutes,
127104
IEstimator<ITransformer> preFeaturizers = null,
128-
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
129-
CancellationToken cancellationToken = default,
130-
IProgress<MulticlassClassificationIterationResult> iterationCallback = null)
105+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null)
131106
{
132107
var settings = new AutoFitSettings();
133108
settings.StoppingCriteria.TimeOutInMinutes = timeoutInMinutes;
134109

135110
return AutoFit(context, trainData, label, validationData, settings,
136-
preFeaturizers, columnPurposes, cancellationToken, iterationCallback, null);
111+
preFeaturizers, columnPurposes, null);
137112
}
138113

139-
internal static MulticlassClassificationResult AutoFit(this MulticlassClassificationContext context,
114+
internal static IEnumerable<IterationResult<MultiClassClassifierMetrics>> AutoFit(this MulticlassClassificationContext context,
140115
IDataView trainData,
141116
string label = DefaultColumnNames.Label,
142117
IDataView validationData = null,
143118
AutoFitSettings settings = null,
144119
IEstimator<ITransformer> preFeaturizers = null,
145120
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
146-
CancellationToken cancellationToken = default,
147-
IProgress<MulticlassClassificationIterationResult> iterationCallback = null, IDebugLogger debugLogger = null)
121+
IDebugLogger debugLogger = null)
148122
{
149123
UserInputValidationUtil.ValidateAutoFitArgs(trainData, label, validationData, settings, columnPurposes);
150124

151125
if (validationData == null)
152126
{
153127
(trainData, validationData) = context.TestValidateSplit(trainData);
154128
}
155-
129+
156130
// run autofit & get all pipelines run in that process
157-
var (allPipelines, bestPipeline) = AutoFitApi.Fit(trainData, validationData, label,
158-
settings, preFeaturizers, TaskKind.MulticlassClassification, OptimizingMetric.Accuracy,
159-
columnPurposes, debugLogger);
160-
161-
var results = new MulticlassClassificationIterationResult[allPipelines.Length];
162-
for (var i = 0; i < results.Length; i++)
163-
{
164-
var iterationResult = allPipelines[i];
165-
var result = new MulticlassClassificationIterationResult(iterationResult.Model, (MultiClassClassifierMetrics)iterationResult.EvaluatedMetrics, iterationResult.ScoredValidationData, iterationResult.Pipeline.ToPipeline());
166-
results[i] = result;
167-
}
168-
var bestResult = new MulticlassClassificationIterationResult(bestPipeline.Model, (MultiClassClassifierMetrics)bestPipeline.EvaluatedMetrics, bestPipeline.ScoredValidationData, bestPipeline.Pipeline.ToPipeline());
169-
return new MulticlassClassificationResult(bestResult, results);
170-
}
171-
}
172-
173-
public class BinaryClassificationResult
174-
{
175-
public readonly BinaryClassificationItertionResult BestIteration;
176-
public readonly BinaryClassificationItertionResult[] IterationResults;
177-
178-
public BinaryClassificationResult(BinaryClassificationItertionResult bestPipeline,
179-
BinaryClassificationItertionResult[] iterationResults)
180-
{
181-
BestIteration = bestPipeline;
182-
IterationResults = iterationResults;
183-
}
184-
}
185-
186-
public class MulticlassClassificationResult
187-
{
188-
public readonly MulticlassClassificationIterationResult BestIteration;
189-
public readonly MulticlassClassificationIterationResult[] IterationResults;
190-
191-
public MulticlassClassificationResult(MulticlassClassificationIterationResult bestPipeline,
192-
MulticlassClassificationIterationResult[] iterationResults)
193-
{
194-
BestIteration = bestPipeline;
195-
IterationResults = iterationResults;
196-
}
197-
}
198-
199-
public class RegressionResult
200-
{
201-
public readonly RegressionIterationResult BestIteration;
202-
public readonly RegressionIterationResult[] IterationResults;
203-
204-
public RegressionResult(RegressionIterationResult bestPipeline,
205-
RegressionIterationResult[] iterationResults)
206-
{
207-
BestIteration = bestPipeline;
208-
IterationResults = iterationResults;
209-
}
210-
}
211-
212-
public class BinaryClassificationItertionResult
213-
{
214-
public readonly BinaryClassificationMetrics Metrics;
215-
public readonly ITransformer Model;
216-
public readonly IDataView ScoredValidationData;
217-
internal readonly Pipeline Pipeline;
218-
219-
internal BinaryClassificationItertionResult(ITransformer model, BinaryClassificationMetrics metrics, IDataView scoredValidationData, Pipeline pipeline)
220-
{
221-
Model = model;
222-
ScoredValidationData = scoredValidationData;
223-
Metrics = metrics;
224-
Pipeline = pipeline;
225-
}
226-
}
227-
228-
public class MulticlassClassificationIterationResult
229-
{
230-
public readonly MultiClassClassifierMetrics Metrics;
231-
public readonly ITransformer Model;
232-
public readonly IDataView ScoredValidationData;
233-
internal readonly Pipeline Pipeline;
234-
235-
internal MulticlassClassificationIterationResult(ITransformer model, MultiClassClassifierMetrics metrics, IDataView scoredValidationData, Pipeline pipeline)
236-
{
237-
Model = model;
238-
Metrics = metrics;
239-
ScoredValidationData = scoredValidationData;
240-
Pipeline = pipeline;
131+
var autoFitter = new AutoFitter<MultiClassClassifierMetrics>(TaskKind.MulticlassClassification, trainData, label, validationData,
132+
settings, preFeaturizers, columnPurposes, OptimizingMetric.RSquared, debugLogger);
133+
return autoFitter.Fit();
241134
}
242135
}
243136

244-
public class RegressionIterationResult
137+
public class IterationResult<T>
245138
{
246-
public readonly RegressionMetrics Metrics;
139+
public readonly T Metrics;
247140
public readonly ITransformer Model;
248-
public readonly IDataView ScoredValidationData;
141+
public readonly Exception Exception;
249142
internal readonly Pipeline Pipeline;
250143

251-
internal RegressionIterationResult(ITransformer model, RegressionMetrics metrics, IDataView scoredValidationData, Pipeline pipeline)
144+
internal IterationResult(ITransformer model, T metrics, Pipeline pipeline, Exception exception)
252145
{
253146
Model = model;
254147
Metrics = metrics;
255-
ScoredValidationData = scoredValidationData;
256148
Pipeline = pipeline;
149+
Exception = exception;
257150
}
258151
}
259152
}

src/AutoML/API/Pipeline.cs

+3-3
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ internal Pipeline()
2323

2424
public IEstimator<ITransformer> ToEstimator()
2525
{
26-
var inferredPipeline = InferredPipeline.FromPipeline(this);
26+
var inferredPipeline = SuggestedPipeline.FromPipeline(this);
2727
return inferredPipeline.ToEstimator();
2828
}
2929
}
@@ -87,7 +87,7 @@ internal CustomProperty()
8787
}
8888
}
8989

90-
internal class PipelineRunResult
90+
internal class PipelineScore
9191
{
9292
public readonly double Score;
9393

@@ -99,7 +99,7 @@ internal class PipelineRunResult
9999

100100
internal readonly Pipeline Pipeline;
101101

102-
internal PipelineRunResult(Pipeline pipeline, double score, bool runSucceeded)
102+
internal PipelineScore(Pipeline pipeline, double score, bool runSucceeded)
103103
{
104104
Pipeline = pipeline;
105105
Score = score;

src/AutoML/AutoFitter/AutoFitApi.cs

-53
This file was deleted.

0 commit comments

Comments
 (0)