Skip to content

Commit e4a64cf

Browse files
First public api propsal (#52)
* Includes following 1) Final proposal for 0.1 public API surface 2) Prefeaturization 3) Splitting train data into train and validate when validation data is null 4) Providing end to end samples one each for regression, binaryclassification and multiclass classification * Incorporating code review feedbacks
1 parent 41c663c commit e4a64cf

21 files changed

+200991
-310
lines changed

src/AutoML/API/AutoFitSettings.cs

+27-17
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,38 @@
77

88
namespace Microsoft.ML.Auto
99
{
10-
public class AutoFitSettings
10+
internal static class AutoFitDefaults
1111
{
12+
public const uint TimeOutInMinutes = 24 * 60;
13+
public const uint MaxIterations = 1000;
14+
}
15+
16+
internal class AutoFitSettings
17+
{
18+
// All the following settings only capture the surface area of capabilities we want to ship in future.
19+
// However, most certainly they will not ship using following types and structures
20+
// These should remain internal until we have rationalized
21+
1222
public ExperimentStoppingCriteria StoppingCriteria = new ExperimentStoppingCriteria();
1323
internal IterationStoppingCriteria IterationStoppingCriteria;
1424
internal Concurrency Concurrency;
1525
internal Filters Filters;
1626
internal CrossValidationSettings CrossValidationSettings;
1727
internal OptimizingMetric OptimizingMetric;
18-
internal bool EnableEnsembling;
19-
internal bool EnableModelExplainability;
20-
internal bool EnableAutoTransformation;
28+
internal bool DisableEnsembling;
29+
internal bool CaclculateModelExplainability;
30+
internal bool DisableFeaturization;
2131

22-
// spec question: Are following automatic or a user setting?
23-
internal bool EnableSubSampling;
24-
internal bool EnableCaching;
32+
internal bool DisableSubSampling;
33+
internal bool DisableCaching;
2534
internal bool ExternalizeTraining;
26-
internal TraceLevel TraceLevel; // Should this be controlled through code or appconfig?
35+
internal TraceLevel TraceLevel;
2736
}
2837

29-
public class ExperimentStoppingCriteria
38+
internal class ExperimentStoppingCriteria
3039
{
31-
public int MaxIterations = 100;
32-
public int TimeOutInMinutes = 300;
40+
public uint TimeOutInMinutes = AutoFitDefaults.TimeOutInMinutes;
41+
public uint MaxIterations = AutoFitDefaults.MaxIterations;
3342
internal bool StopAfterConverging;
3443
internal double ExperimentExitScore;
3544
}
@@ -40,19 +49,20 @@ internal class Filters
4049
internal IEnumerable<Trainers> BlackListTrainers;
4150
internal IEnumerable<Transformers> WhitelistTransformers;
4251
internal IEnumerable<Transformers> BlacklistTransformers;
43-
internal bool PreferExplainability;
44-
internal bool PreferInferenceSpeed;
45-
internal bool PreferSmallDeploymentSize;
46-
internal bool PreferSmallMemoryFootprint;
52+
internal uint? Explainability;
53+
internal uint? InferenceSpeed;
54+
internal uint? DeploymentSize;
55+
internal uint? TrainingMemorySize;
56+
internal bool? GpuTraining;
4757
}
4858

49-
public class IterationStoppingCriteria
59+
internal class IterationStoppingCriteria
5060
{
5161
internal int TimeOutInSeconds;
5262
internal bool TerminateOnLowAccuracy;
5363
}
5464

55-
public class Concurrency
65+
internal class Concurrency
5666
{
5767
internal int MaxConcurrentIterations;
5868
internal int MaxCoresPerIteration;

src/AutoML/API/MLContextAutoFitExtensions.cs

+68-38
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,42 @@ public static class RegressionExtensions
1414
{
1515
public static RegressionResult AutoFit(this RegressionContext context,
1616
IDataView trainData,
17-
string label,
18-
IDataView validationData,
19-
AutoFitSettings settings = null,
20-
IEnumerable<(string, ColumnPurpose)> purposeOverrides = null,
17+
string label = DefaultColumnNames.Label,
18+
IDataView validationData = null,
19+
uint timeoutInMinutes = AutoFitDefaults.TimeOutInMinutes,
20+
IEstimator<ITransformer> preFeaturizers = null,
21+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
2122
CancellationToken cancellationToken = default,
2223
IProgress<RegressionIterationResult> iterationCallback = null)
2324
{
25+
var settings = new AutoFitSettings();
26+
settings.StoppingCriteria.TimeOutInMinutes = timeoutInMinutes;
27+
2428
return AutoFit(context, trainData, label, validationData, settings,
25-
purposeOverrides, cancellationToken, iterationCallback, null);
29+
preFeaturizers, columnPurposes, cancellationToken, iterationCallback, null);
2630
}
2731

2832
internal static RegressionResult AutoFit(this RegressionContext context,
2933
IDataView trainData,
30-
string label,
31-
IDataView validationData,
34+
string label = DefaultColumnNames.Label,
35+
IDataView validationData = null,
3236
AutoFitSettings settings = null,
33-
IEnumerable<(string, ColumnPurpose)> purposeOverrides = null,
37+
IEstimator<ITransformer> preFeaturizers = null,
38+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
3439
CancellationToken cancellationToken = default,
3540
IProgress<RegressionIterationResult> iterationCallback = null,
3641
IDebugLogger debugLogger = null)
3742
{
38-
UserInputValidationUtil.ValidateAutoFitArgs(trainData, label, validationData, settings, purposeOverrides);
43+
UserInputValidationUtil.ValidateAutoFitArgs(trainData, label, validationData, settings, columnPurposes);
44+
45+
if (validationData == null)
46+
{
47+
(trainData, validationData) = context.TestValidateSplit(trainData);
48+
}
3949

4050
// run autofit & get all pipelines run in that process
4151
var (allPipelines, bestPipeline) = AutoFitApi.Fit(trainData, validationData, label,
42-
settings, TaskKind.Regression, OptimizingMetric.RSquared, purposeOverrides, debugLogger);
52+
settings, preFeaturizers, TaskKind.Regression, OptimizingMetric.RSquared, columnPurposes, debugLogger);
4353

4454
var results = new RegressionIterationResult[allPipelines.Length];
4555
for (var i = 0; i < results.Length; i++)
@@ -57,33 +67,43 @@ public static class BinaryClassificationExtensions
5767
{
5868
public static BinaryClassificationResult AutoFit(this BinaryClassificationContext context,
5969
IDataView trainData,
60-
string label,
61-
IDataView validationData,
62-
AutoFitSettings settings = null,
63-
IEnumerable<(string, ColumnPurpose)> purposeOverrides = null,
70+
string label = DefaultColumnNames.Label,
71+
IDataView validationData = null,
72+
uint timeoutInMinutes = AutoFitDefaults.TimeOutInMinutes,
73+
IEstimator<ITransformer> preFeaturizers = null,
74+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
6475
CancellationToken cancellationToken = default,
6576
IProgress<BinaryClassificationItertionResult> iterationCallback = null)
6677
{
78+
var settings = new AutoFitSettings();
79+
settings.StoppingCriteria.TimeOutInMinutes = timeoutInMinutes;
80+
6781
return AutoFit(context, trainData, label, validationData, settings,
68-
purposeOverrides, cancellationToken, iterationCallback, null);
82+
preFeaturizers, columnPurposes, cancellationToken, iterationCallback, null);
6983
}
7084

7185
internal static BinaryClassificationResult AutoFit(this BinaryClassificationContext context,
7286
IDataView trainData,
73-
string label,
74-
IDataView validationData,
87+
string label = DefaultColumnNames.Label,
88+
IDataView validationData = null,
7589
AutoFitSettings settings = null,
76-
IEnumerable<(string, ColumnPurpose)> purposeOverrides = null,
90+
IEstimator<ITransformer> preFeaturizers = null,
91+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
7792
CancellationToken cancellationToken = default,
7893
IProgress<BinaryClassificationItertionResult> iterationCallback = null,
7994
IDebugLogger debugLogger = null)
8095
{
81-
UserInputValidationUtil.ValidateAutoFitArgs(trainData, label, validationData, settings, purposeOverrides);
96+
UserInputValidationUtil.ValidateAutoFitArgs(trainData, label, validationData, settings, columnPurposes);
97+
98+
if (validationData == null)
99+
{
100+
(trainData, validationData) = context.TestValidateSplit(trainData);
101+
}
82102

83103
// run autofit & get all pipelines run in that process
84104
var (allPipelines, bestPipeline) = AutoFitApi.Fit(trainData, validationData, label,
85-
settings, TaskKind.BinaryClassification, OptimizingMetric.Accuracy,
86-
purposeOverrides, debugLogger);
105+
settings, preFeaturizers, TaskKind.BinaryClassification, OptimizingMetric.Accuracy,
106+
columnPurposes, debugLogger);
87107

88108
var results = new BinaryClassificationItertionResult[allPipelines.Length];
89109
for (var i = 0; i < results.Length; i++)
@@ -101,32 +121,42 @@ public static class MulticlassExtensions
101121
{
102122
public static MulticlassClassificationResult AutoFit(this MulticlassClassificationContext context,
103123
IDataView trainData,
104-
string label,
105-
IDataView validationData,
106-
AutoFitSettings settings = null,
107-
IEnumerable<(string, ColumnPurpose)> purposeOverrides = null,
124+
string label = DefaultColumnNames.Label,
125+
IDataView validationData = null,
126+
uint timeoutInMinutes = AutoFitDefaults.TimeOutInMinutes,
127+
IEstimator<ITransformer> preFeaturizers = null,
128+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
108129
CancellationToken cancellationToken = default,
109130
IProgress<MulticlassClassificationIterationResult> iterationCallback = null)
110131
{
132+
var settings = new AutoFitSettings();
133+
settings.StoppingCriteria.TimeOutInMinutes = timeoutInMinutes;
134+
111135
return AutoFit(context, trainData, label, validationData, settings,
112-
purposeOverrides, cancellationToken, iterationCallback, null);
136+
preFeaturizers, columnPurposes, cancellationToken, iterationCallback, null);
113137
}
114138

115139
internal static MulticlassClassificationResult AutoFit(this MulticlassClassificationContext context,
116140
IDataView trainData,
117-
string label,
118-
IDataView validationData,
141+
string label = DefaultColumnNames.Label,
142+
IDataView validationData = null,
119143
AutoFitSettings settings = null,
120-
IEnumerable<(string, ColumnPurpose)> purposeOverrides = null,
144+
IEstimator<ITransformer> preFeaturizers = null,
145+
IEnumerable<(string, ColumnPurpose)> columnPurposes = null,
121146
CancellationToken cancellationToken = default,
122147
IProgress<MulticlassClassificationIterationResult> iterationCallback = null, IDebugLogger debugLogger = null)
123148
{
124-
UserInputValidationUtil.ValidateAutoFitArgs(trainData, label, validationData, settings, purposeOverrides);
149+
UserInputValidationUtil.ValidateAutoFitArgs(trainData, label, validationData, settings, columnPurposes);
150+
151+
if (validationData == null)
152+
{
153+
(trainData, validationData) = context.TestValidateSplit(trainData);
154+
}
125155

126156
// run autofit & get all pipelines run in that process
127157
var (allPipelines, bestPipeline) = AutoFitApi.Fit(trainData, validationData, label,
128-
settings, TaskKind.MulticlassClassification, OptimizingMetric.Accuracy,
129-
purposeOverrides, debugLogger);
158+
settings, preFeaturizers, TaskKind.MulticlassClassification, OptimizingMetric.Accuracy,
159+
columnPurposes, debugLogger);
130160

131161
var results = new MulticlassClassificationIterationResult[allPipelines.Length];
132162
for (var i = 0; i < results.Length; i++)
@@ -142,39 +172,39 @@ internal static MulticlassClassificationResult AutoFit(this MulticlassClassifica
142172

143173
public class BinaryClassificationResult
144174
{
145-
public readonly BinaryClassificationItertionResult BestPipeline;
175+
public readonly BinaryClassificationItertionResult BestIteration;
146176
public readonly BinaryClassificationItertionResult[] IterationResults;
147177

148178
public BinaryClassificationResult(BinaryClassificationItertionResult bestPipeline,
149179
BinaryClassificationItertionResult[] iterationResults)
150180
{
151-
BestPipeline = bestPipeline;
181+
BestIteration = bestPipeline;
152182
IterationResults = iterationResults;
153183
}
154184
}
155185

156186
public class MulticlassClassificationResult
157187
{
158-
public readonly MulticlassClassificationIterationResult BestPipeline;
188+
public readonly MulticlassClassificationIterationResult BestIteration;
159189
public readonly MulticlassClassificationIterationResult[] IterationResults;
160190

161191
public MulticlassClassificationResult(MulticlassClassificationIterationResult bestPipeline,
162192
MulticlassClassificationIterationResult[] iterationResults)
163193
{
164-
BestPipeline = bestPipeline;
194+
BestIteration = bestPipeline;
165195
IterationResults = iterationResults;
166196
}
167197
}
168198

169199
public class RegressionResult
170200
{
171-
public readonly RegressionIterationResult BestPipeline;
201+
public readonly RegressionIterationResult BestIteration;
172202
public readonly RegressionIterationResult[] IterationResults;
173203

174204
public RegressionResult(RegressionIterationResult bestPipeline,
175205
RegressionIterationResult[] iterationResults)
176206
{
177-
BestPipeline = bestPipeline;
207+
BestIteration = bestPipeline;
178208
IterationResults = iterationResults;
179209
}
180210
}

src/AutoML/AutoFitter/AutoFitApi.cs

+20-1
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,29 @@
44

55
using System.Collections.Generic;
66
using System.Linq;
7+
using Microsoft.ML.Core.Data;
78
using Microsoft.ML.Data;
89

910
namespace Microsoft.ML.Auto
1011
{
1112
internal static class AutoFitApi
1213
{
1314
public static (InferredPipelineRunResult[] allPipelines, InferredPipelineRunResult bestPipeline) Fit(IDataView trainData,
14-
IDataView validationData, string label, AutoFitSettings settings, TaskKind task, OptimizingMetric metric,
15+
IDataView validationData, string label, AutoFitSettings settings, IEstimator<ITransformer> preFeaturizers, TaskKind task, OptimizingMetric metric,
1516
IEnumerable<(string, ColumnPurpose)> purposeOverrides, IDebugLogger debugLogger)
1617
{
1718
// hack: init new MLContext
1819
var mlContext = new MLContext();
1920

21+
ITransformer preprocessorTransform = null;
22+
if (preFeaturizers != null)
23+
{
24+
// preprocess train and validation data
25+
preprocessorTransform = preFeaturizers.Fit(trainData);
26+
trainData = preprocessorTransform.Transform(trainData);
27+
validationData = preprocessorTransform.Transform(validationData);
28+
}
29+
2030
var purposeOverridesDict = purposeOverrides?.ToDictionary(p => p.Item1, p => p.Item2);
2131
var optimizingMetricfInfo = new OptimizingMetricInfo(metric);
2232

@@ -25,6 +35,15 @@ public static (InferredPipelineRunResult[] allPipelines, InferredPipelineRunResu
2535
label, trainData, validationData, purposeOverridesDict, debugLogger);
2636
var allPipelines = autoFitter.Fit();
2737

38+
// apply preprocessor to returned models
39+
if (preprocessorTransform != null)
40+
{
41+
for (var i = 0; i < allPipelines.Length; i++)
42+
{
43+
allPipelines[i].Model = preprocessorTransform.Append(allPipelines[i].Model);
44+
}
45+
}
46+
2847
var bestScore = allPipelines.Max(p => p.Score);
2948
var bestPipeline = allPipelines.First(p => p.Score == bestScore);
3049

src/AutoML/AutoFitter/AutoFitter.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public InferredPipelineRunResult[] Fit()
4848
do
4949
{
5050
// get next pipeline
51-
var iterationsRemaining = _settings.StoppingCriteria.MaxIterations - _history.Count;
51+
var iterationsRemaining = (int)_settings.StoppingCriteria.MaxIterations - _history.Count;
5252
var pipeline = PipelineSuggester.GetNextInferredPipeline(_history, columns, _task, iterationsRemaining, _optimizingMetricInfo.IsMaximizing);
5353

5454
// break if no candidates returned, means no valid pipeline available

src/AutoML/AutoMlUtils.cs

+14
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@ public static IDataView Take(this IDataView data, int count)
3030
return new CacheDataView(context, filter, Enumerable.Range(0, data.Schema.Count).ToArray());
3131
}
3232

33+
public static IDataView DropLastColumn(this IDataView data)
34+
{
35+
return new MLContext().Transforms.DropColumns(data.Schema[data.Schema.Count - 1].Name).Fit(data).Transform(data);
36+
}
37+
38+
public static (IDataView testData, IDataView validationData) TestValidateSplit(this TrainContextBase context, IDataView trainData)
39+
{
40+
IDataView validationData;
41+
(trainData, validationData) = context.TrainTestSplit(trainData);
42+
trainData = trainData.DropLastColumn();
43+
validationData = validationData.DropLastColumn();
44+
return (trainData, validationData);
45+
}
46+
3347
public static IDataView Skip(this IDataView data, int count)
3448
{
3549
var context = new MLContext();

src/AutoML/Utils/UserInputValidationUtil.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
112112
{
113113
if(validationData == null)
114114
{
115-
throw new ArgumentNullException("Validation data cannot be null", nameof(validationData));
115+
return;
116116
}
117117

118118
const string schemaMismatchError = "Training data and validation data schemas do not match.";

0 commit comments

Comments
 (0)