Skip to content

Commit 50f8f62

Browse files
authored
Refactor the orchestration of AutoML calls (dotnet#272)
1 parent fe503d3 commit 50f8f62

File tree

5 files changed

+250
-178
lines changed

5 files changed

+250
-178
lines changed

src/mlnet/AutoML/AutoMLEngine.cs

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Data.DataView;
6+
using Microsoft.ML.Auto;
7+
using Microsoft.ML.CLI.Data;
8+
using Microsoft.ML.CLI.Utilities;
9+
using NLog;
10+
11+
namespace Microsoft.ML.CLI.CodeGenerator
12+
{
13+
/// <summary>
/// <see cref="IAutoMLEngine"/> implementation that forwards column inference and model
/// exploration to the AutoML API reached through <c>context.Auto()</c>, using the values
/// carried by the supplied <see cref="NewCommandSettings"/>.
/// </summary>
internal class AutoMLEngine : IAutoMLEngine
{
    private static readonly Logger logger = LogManager.GetCurrentClassLogger();
    private readonly NewCommandSettings settings;
    private readonly TaskKind taskKind;

    /// <summary>
    /// Creates an engine bound to the given command settings.
    /// </summary>
    /// <param name="settings">Settings of the command being executed; the task kind is derived from <c>settings.MlTask</c>.</param>
    public AutoMLEngine(NewCommandSettings settings)
    {
        this.settings = settings;
        this.taskKind = Utils.GetTaskKind(settings.MlTask);
    }

    /// <summary>
    /// Infers the columns of the training dataset referenced by the settings.
    /// </summary>
    /// <param name="context">ML context used to reach the AutoML API.</param>
    /// <returns>The column inference results produced by AutoML.</returns>
    public ColumnInferenceResults InferColumns(MLContext context)
    {
        logger.Log(LogLevel.Info, Strings.InferColumns);
        var dataset = settings.Dataset.FullName;

        // The label is identified either by name or by index; each has its own overload.
        if (settings.LabelColumnName != null)
        {
            return context.Auto().InferColumns(dataset, settings.LabelColumnName, groupColumns: false);
        }

        return context.Auto().InferColumns(dataset, settings.LabelColumnIndex, hasHeader: settings.HasHeader, groupColumns: false);
    }

    /// <summary>
    /// Runs the AutoML experiment that matches the configured task kind and returns the
    /// pipeline and model of the best iteration.
    /// </summary>
    /// <param name="context">ML context used to reach the AutoML API.</param>
    /// <param name="trainData">Data the experiment trains on.</param>
    /// <param name="validationData">Validation data handed to the experiment as-is.</param>
    /// <param name="labelName">Name of the label column.</param>
    /// <returns>The best iteration's pipeline and its model, in that order.</returns>
    /// <exception cref="System.NotSupportedException">
    /// The configured task kind is not binary classification, regression or multiclass classification.
    /// </exception>
    (Pipeline, ITransformer) IAutoMLEngine.ExploreModels(MLContext context, IDataView trainData, IDataView validationData, string labelName)
    {
        var columnInformation = new ColumnInformation() { LabelColumn = labelName };

        switch (taskKind)
        {
            case TaskKind.BinaryClassification:
                return ExploreBinaryClassification(context, trainData, validationData, columnInformation);

            case TaskKind.Regression:
                return ExploreRegression(context, trainData, validationData, columnInformation);

            case TaskKind.MulticlassClassification:
                return ExploreMulticlassClassification(context, trainData, validationData, columnInformation);

            default:
                // Fail loudly instead of handing (null, null) back to the caller, which would
                // only surface later as a null dereference far away from the real cause.
                throw new System.NotSupportedException($"Model exploration is not supported for task kind '{taskKind}'.");
        }
    }

    // Runs a binary classification experiment and returns its best pipeline/model.
    private (Pipeline, ITransformer) ExploreBinaryClassification(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation)
    {
        var experimentSettings = new BinaryExperimentSettings()
        {
            MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
            ProgressHandler = new ProgressHandlers.BinaryClassificationHandler()
        };

        var result = context.Auto()
            .CreateBinaryClassificationExperiment(experimentSettings)
            .Execute(trainData, validationData, columnInformation);

        logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
        var bestIteration = result.Best();
        return (bestIteration.Pipeline, bestIteration.Model);
    }

    // Runs a regression experiment and returns its best pipeline/model.
    private (Pipeline, ITransformer) ExploreRegression(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation)
    {
        var experimentSettings = new RegressionExperimentSettings()
        {
            MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
            ProgressHandler = new ProgressHandlers.RegressionHandler()
        };

        var result = context.Auto()
            .CreateRegressionExperiment(experimentSettings)
            .Execute(trainData, validationData, columnInformation);

        logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
        var bestIteration = result.Best();
        return (bestIteration.Pipeline, bestIteration.Model);
    }

    // Runs a multiclass classification experiment and returns its best pipeline/model.
    private (Pipeline, ITransformer) ExploreMulticlassClassification(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation)
    {
        var experimentSettings = new MulticlassExperimentSettings()
        {
            MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
            ProgressHandler = new ProgressHandlers.MulticlassClassificationHandler()
        };

        // Inclusion list for currently supported learners. Need to remove once we have codegen support for all other learners.
        experimentSettings.Trainers.Clear();
        experimentSettings.Trainers.Add(MulticlassClassificationTrainer.LightGbm);
        experimentSettings.Trainers.Add(MulticlassClassificationTrainer.LogisticRegression);
        experimentSettings.Trainers.Add(MulticlassClassificationTrainer.StochasticDualCoordinateAscent);

        var result = context.Auto()
            .CreateMulticlassClassificationExperiment(experimentSettings)
            .Execute(trainData, validationData, columnInformation);

        logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
        var bestIteration = result.Best();
        return (bestIteration.Pipeline, bestIteration.Model);
    }
}
108+
}

src/mlnet/AutoML/IAutoMLEngine.cs

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+

2+
// Licensed to the .NET Foundation under one or more agreements.
3+
// The .NET Foundation licenses this file to you under the MIT license.
4+
// See the LICENSE file in the project root for more information.
5+
6+
using Microsoft.Data.DataView;
7+
using Microsoft.ML.Auto;
8+
9+
namespace Microsoft.ML.CLI.CodeGenerator
10+
{
11+
/// <summary>
/// Abstraction over the AutoML operations the CLI needs, so the code generation flow
/// can be driven by any implementation of these two steps.
/// </summary>
internal interface IAutoMLEngine
{
    /// <summary>
    /// Infers the columns of the dataset the implementation is configured with.
    /// </summary>
    /// <param name="context">ML context the implementation works with.</param>
    /// <returns>The column inference results.</returns>
    ColumnInferenceResults InferColumns(MLContext context);

    /// <summary>
    /// Explores candidate models for the given data.
    /// </summary>
    /// <param name="context">ML context the implementation works with.</param>
    /// <param name="trainData">Training data.</param>
    /// <param name="validationData">Validation data.</param>
    /// <param name="labelName">Name of the label column.</param>
    /// <returns>A tuple holding the selected pipeline and its model, in that order.</returns>
    (Pipeline, ITransformer) ExploreModels(MLContext context, IDataView trainData, IDataView validationData, string labelName);
}
18+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.IO;
7+
using Microsoft.Data.DataView;
8+
using Microsoft.ML.Auto;
9+
using Microsoft.ML.CLI.CodeGenerator.CSharp;
10+
using Microsoft.ML.CLI.Data;
11+
using Microsoft.ML.CLI.Utilities;
12+
using Microsoft.ML.Data;
13+
using NLog;
14+
15+
namespace Microsoft.ML.CLI.CodeGenerator
16+
{
17+
/// <summary>
/// Orchestrates the code generation flow: infer columns, load the data, explore models
/// through an <see cref="IAutoMLEngine"/>, save the best model and generate the project.
/// </summary>
internal class CodeGenerationHelper
{
    private static readonly Logger logger = LogManager.GetCurrentClassLogger();
    private readonly IAutoMLEngine automlEngine;
    private readonly NewCommandSettings settings;
    private readonly TaskKind taskKind;

    /// <summary>
    /// Creates a helper that uses the given engine and command settings.
    /// </summary>
    /// <param name="automlEngine">Engine used for column inference and model exploration.</param>
    /// <param name="settings">Settings of the command being executed; the task kind is derived from <c>settings.MlTask</c>.</param>
    public CodeGenerationHelper(IAutoMLEngine automlEngine, NewCommandSettings settings)
    {
        this.automlEngine = automlEngine;
        this.settings = settings;
        this.taskKind = Utils.GetTaskKind(settings.MlTask);
    }

    /// <summary>
    /// Runs the whole flow. Failures during column inference or model exploration are
    /// logged and end the run early instead of propagating to the caller.
    /// </summary>
    public void GenerateCode()
    {
        var context = new MLContext();

        // Infer columns
        ColumnInferenceResults columnInference;
        try
        {
            columnInference = automlEngine.InferColumns(context);
        }
        catch (Exception e)
        {
            LogFailure(Strings.InferColumnError, e);
            return;
        }

        // Sanitize columns
        Array.ForEach(columnInference.TextLoaderOptions.Columns, t => t.Name = Utils.Sanitize(t.Name));

        var sanitizedLabelName = Utils.Sanitize(columnInference.ColumnInformation.LabelColumn);

        // Load data
        var (trainData, validationData) = LoadData(context, columnInference.TextLoaderOptions);

        // Explore the models
        Pipeline pipeline;
        ITransformer model;
        Console.WriteLine($"{Strings.ExplorePipeline}: {settings.MlTask}");
        try
        {
            (pipeline, model) = automlEngine.ExploreModels(context, trainData, validationData, sanitizedLabelName);
        }
        catch (Exception e)
        {
            LogFailure($"{Strings.ExplorePipelineException}:", e);
            return;
        }

        // An engine may hand back an empty result (for example when it does not handle the
        // requested task). Stop here rather than failing later while saving a null model.
        if (pipeline == null || model == null)
        {
            logger.Log(LogLevel.Error, $"{Strings.ExplorePipelineException}:");
            logger.Log(LogLevel.Error, $"No pipeline or model was produced for task: {settings.MlTask}");
            logger.Log(LogLevel.Error, Strings.Exiting);
            return;
        }

        // Save the model
        logger.Log(LogLevel.Info, Strings.SavingBestModel);
        var modelPath = new FileInfo(Path.Combine(settings.OutputPath.FullName, "model.zip"));
        Utils.SaveModel(model, modelPath, context);

        // Generate the Project
        GenerateProject(columnInference, pipeline, sanitizedLabelName, modelPath);
    }

    /// <summary>
    /// Generates the output project for the given pipeline under the configured output path.
    /// </summary>
    /// <param name="columnInference">Column inference results handed to the code generator.</param>
    /// <param name="pipeline">Pipeline to generate code for.</param>
    /// <param name="labelName">Name of the label column.</param>
    /// <param name="modelPath">Location of the saved model file.</param>
    internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline pipeline, string labelName, FileInfo modelPath)
    {
        logger.Log(LogLevel.Info, $"{Strings.GenerateProject} : {settings.OutputPath.FullName}");

        var generatorSettings = new CodeGeneratorSettings()
        {
            TrainDataset = settings.Dataset.FullName,
            MlTask = taskKind,
            // The test dataset is optional; its path stays null when none was supplied.
            TestDataset = settings.TestDataset?.FullName,
            OutputName = settings.Name,
            OutputBaseDir = settings.OutputPath.FullName,
            LabelName = labelName,
            ModelPath = modelPath.FullName
        };

        var codeGenerator = new CodeGenerator.CSharp.CodeGenerator(pipeline, columnInference, generatorSettings);
        codeGenerator.GenerateOutput();
    }

    /// <summary>
    /// Loads the training data and, when a validation dataset is configured, the validation data.
    /// </summary>
    /// <param name="context">ML context used to create the text loader.</param>
    /// <param name="textLoaderOptions">Options the text loader is created with.</param>
    /// <returns>The training data and the validation data; the latter is null when no validation dataset is configured.</returns>
    internal (IDataView, IDataView) LoadData(MLContext context, TextLoader.Options textLoaderOptions)
    {
        logger.Log(LogLevel.Info, Strings.CreateDataLoader);
        var textLoader = context.Data.CreateTextLoader(textLoaderOptions);

        logger.Log(LogLevel.Info, Strings.LoadData);
        var trainData = textLoader.Load(settings.Dataset.FullName);

        IDataView validationData = null;
        if (settings.ValidationDataset != null)
        {
            validationData = textLoader.Load(settings.ValidationDataset.FullName);
        }

        return (trainData, validationData);
    }

    // Logs a failed step: headline and exception message as errors, the full exception at
    // debug level, followed by the exiting notice.
    private static void LogFailure(string headline, Exception e)
    {
        logger.Log(LogLevel.Error, headline);
        logger.Log(LogLevel.Error, e.Message);
        logger.Log(LogLevel.Debug, e.ToString());
        logger.Log(LogLevel.Error, Strings.Exiting);
    }
}
120+
}

0 commit comments

Comments
 (0)