Skip to content

Commit 2ede226

Browse files
Add AutoZero tuner to BinaryClassification (#6615)
* add autozero tuner * update portfolios.json * add tests
1 parent ebb5789 commit 2ede226

File tree

5 files changed

+2011
-7
lines changed

5 files changed

+2011
-7
lines changed

src/Microsoft.ML.AutoML/API/BinaryClassificationExperiment.cs

+23-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Threading;
1010
using System.Threading.Tasks;
1111
using Microsoft.Extensions.DependencyInjection;
12+
using Microsoft.ML.AutoML.Tuner;
1213
using Microsoft.ML.Data;
1314
using Microsoft.ML.Runtime;
1415
using Microsoft.ML.Trainers;
@@ -35,13 +36,19 @@ public sealed class BinaryExperimentSettings : ExperimentSettings
3536
/// <value>The default value is a collection auto-populated with all possible trainers (all values of <see cref="BinaryClassificationTrainer" />).</value>
3637
public ICollection<BinaryClassificationTrainer> Trainers { get; }
3738

39+
/// <summary>
40+
/// Set if use <see cref="AutoZeroTuner"/> for hyper-parameter optimization, default to false.
41+
/// </summary>
42+
public bool UseAutoZeroTuner { get; set; }
43+
3844
/// <summary>
3945
/// Initializes a new instance of <see cref="BinaryExperimentSettings"/>.
4046
/// </summary>
4147
public BinaryExperimentSettings()
4248
{
4349
OptimizingMetric = BinaryClassificationMetric.Accuracy;
4450
Trainers = Enum.GetValues(typeof(BinaryClassificationTrainer)).OfType<BinaryClassificationTrainer>().ToList();
51+
UseAutoZeroTuner = false;
4552
}
4653
}
4754

@@ -133,7 +140,7 @@ public enum BinaryClassificationTrainer
133140
/// </example>
134141
public sealed class BinaryClassificationExperiment : ExperimentBase<BinaryClassificationMetrics, BinaryExperimentSettings>
135142
{
136-
private readonly AutoMLExperiment _experiment;
143+
private AutoMLExperiment _experiment;
137144
private const string Features = "__Features__";
138145
private SweepablePipeline _pipeline;
139146

@@ -151,13 +158,13 @@ internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSetti
151158
_experiment.SetMaximumMemoryUsageInMegaByte(d);
152159
}
153160
_experiment.SetMaxModelToExplore(settings.MaxModels);
161+
_experiment.SetTrainingTimeInSeconds(settings.MaxExperimentTimeInSeconds);
154162
}
155163

156164
public override ExperimentResult<BinaryClassificationMetrics> Execute(IDataView trainData, ColumnInformation columnInformation, IEstimator<ITransformer> preFeaturizer = null, IProgress<RunDetail<BinaryClassificationMetrics>> progressHandler = null)
157165
{
158166
var label = columnInformation.LabelColumnName;
159167
_experiment.SetBinaryClassificationMetric(Settings.OptimizingMetric, label);
160-
_experiment.SetTrainingTimeInSeconds(Settings.MaxExperimentTimeInSeconds);
161168

162169
// Cross val threshold for # of dataset rows --
163170
// If dataset has < threshold # of rows, use cross val.
@@ -194,7 +201,7 @@ public override ExperimentResult<BinaryClassificationMetrics> Execute(IDataView
194201

195202
return monitor;
196203
});
197-
_experiment.SetTrialRunner<BinaryClassificationRunner>();
204+
_experiment = PostConfigureAutoMLExperiment(_experiment);
198205
_experiment.Run();
199206

200207
var runDetails = monitor.RunDetails.Select(e => BestResultUtil.ToRunDetail(Context, e, _pipeline));
@@ -208,7 +215,6 @@ public override ExperimentResult<BinaryClassificationMetrics> Execute(IDataView
208215
{
209216
var label = columnInformation.LabelColumnName;
210217
_experiment.SetBinaryClassificationMetric(Settings.OptimizingMetric, label);
211-
_experiment.SetTrainingTimeInSeconds(Settings.MaxExperimentTimeInSeconds);
212218
_experiment.SetDataset(trainData, validationData);
213219
_pipeline = CreateBinaryClassificationPipeline(trainData, columnInformation, preFeaturizer);
214220
_experiment.SetPipeline(_pipeline);
@@ -228,7 +234,7 @@ public override ExperimentResult<BinaryClassificationMetrics> Execute(IDataView
228234

229235
return monitor;
230236
});
231-
_experiment.SetTrialRunner<BinaryClassificationRunner>();
237+
_experiment = PostConfigureAutoMLExperiment(_experiment);
232238
_experiment.Run();
233239

234240
var runDetails = monitor.RunDetails.Select(e => BestResultUtil.ToRunDetail(Context, e, _pipeline));
@@ -263,7 +269,6 @@ public override CrossValidationExperimentResult<BinaryClassificationMetrics> Exe
263269
{
264270
var label = columnInformation.LabelColumnName;
265271
_experiment.SetBinaryClassificationMetric(Settings.OptimizingMetric, label);
266-
_experiment.SetTrainingTimeInSeconds(Settings.MaxExperimentTimeInSeconds);
267272
_experiment.SetDataset(trainData, (int)numberOfCVFolds);
268273
_pipeline = CreateBinaryClassificationPipeline(trainData, columnInformation, preFeaturizer);
269274
_experiment.SetPipeline(_pipeline);
@@ -284,7 +289,7 @@ public override CrossValidationExperimentResult<BinaryClassificationMetrics> Exe
284289
return monitor;
285290
});
286291

287-
_experiment.SetTrialRunner<BinaryClassificationRunner>();
292+
_experiment = PostConfigureAutoMLExperiment(_experiment);
288293
_experiment.Run();
289294

290295
var runDetails = monitor.RunDetails.Select(e => BestResultUtil.ToCrossValidationRunDetail(Context, e, _pipeline));
@@ -335,6 +340,17 @@ private SweepablePipeline CreateBinaryClassificationPipeline(IDataView trainData
335340
.Append(Context.Auto().BinaryClassification(labelColumnName: columnInformation.LabelColumnName, useSdcaLogisticRegression: useSdca, useFastTree: useFastTree, useLgbm: useLgbm, useLbfgsLogisticRegression: uselbfgs, useFastForest: useFastForest, featureColumnName: Features));
336341
}
337342
}
343+
344+
private AutoMLExperiment PostConfigureAutoMLExperiment(AutoMLExperiment experiment)
345+
{
346+
experiment.SetTrialRunner<BinaryClassificationRunner>();
347+
if (Settings.UseAutoZeroTuner)
348+
{
349+
experiment.SetTuner<AutoZeroTuner>();
350+
}
351+
352+
return experiment;
353+
}
338354
}
339355

340356
internal class BinaryClassificationRunner : ITrialRunner

src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj

+6
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,12 @@
6868
<AdditionalFiles Include="CodeGen\*-estimators.json" />
6969
</ItemGroup>
7070

71+
<ItemGroup>
72+
<EmbeddedResource Include="Tuner\Portfolios.json">
73+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
74+
</EmbeddedResource>
75+
</ItemGroup>
76+
7177
<Target DependsOnTargets="ResolveReferences" Name="CopyProjectReferencesToPackage">
7278
<ItemGroup>
7379
<!--Include DLLs of Project References-->
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Globalization;
8+
using System.IO;
9+
using System.Linq;
10+
using System.Reflection;
11+
using System.Text;
12+
using System.Text.Json;
13+
using Microsoft.ML.AutoML.CodeGen;
14+
using Microsoft.ML.SearchSpace;
15+
16+
namespace Microsoft.ML.AutoML.Tuner
17+
{
18+
internal class AutoZeroTuner : ITuner
19+
{
20+
private readonly List<Config> _configs = new List<Config>();
21+
private readonly IEnumerator<Config> _configsEnumerator;
22+
private readonly Dictionary<string, string> _pipelineStrings;
23+
private readonly SweepablePipeline _sweepablePipeline;
24+
private readonly Dictionary<int, Config> _configLookBook = new Dictionary<int, Config>();
25+
private readonly string _metricName;
26+
27+
public AutoZeroTuner(SweepablePipeline pipeline, AggregateTrainingStopManager aggregateTrainingStopManager, IEvaluateMetricManager evaluateMetricManager, AutoMLExperiment.AutoMLExperimentSettings settings)
28+
{
29+
_configs = LoadConfigsFromJson();
30+
_sweepablePipeline = pipeline;
31+
_pipelineStrings = _sweepablePipeline.Schema.ToTerms().Select(t => new
32+
{
33+
schema = t.ToString(),
34+
pipelineString = string.Join("=>", t.ValueEntities().Select(e => _sweepablePipeline.Estimators[e.ToString()].EstimatorType)),
35+
}).ToDictionary(kv => kv.schema, kv => kv.pipelineString);
36+
37+
// todo
38+
// filter configs on trainers
39+
var trainerEstimators = _sweepablePipeline.Estimators.Where(e => e.Value.EstimatorType.IsTrainer()).Select(e => e.Value.EstimatorType.ToString()).ToList();
40+
_configs = evaluateMetricManager switch
41+
{
42+
BinaryMetricManager => _configs.Where(c => c.Task == "binary-classification" && trainerEstimators.Contains(c.Trainer)).ToList(),
43+
MultiClassMetricManager => _configs.Where(c => c.Task == "multi-classification" && trainerEstimators.Contains(c.Trainer)).ToList(),
44+
RegressionMetricManager => _configs.Where(c => c.Task == "regression" && trainerEstimators.Contains(c.Trainer)).ToList(),
45+
_ => throw new Exception(),
46+
};
47+
_metricName = evaluateMetricManager switch
48+
{
49+
BinaryMetricManager bm => bm.Metric.ToString(),
50+
MultiClassMetricManager mm => mm.Metric.ToString(),
51+
RegressionMetricManager rm => rm.Metric.ToString(),
52+
_ => throw new Exception(),
53+
};
54+
55+
if (_configs.Count == 0)
56+
{
57+
throw new ArgumentException($"Fail to find available configs for given trainers: {string.Join(",", trainerEstimators)}");
58+
}
59+
60+
_configsEnumerator = _configs.GetEnumerator();
61+
aggregateTrainingStopManager.AddTrainingStopManager(new MaxModelStopManager(_configs.Count, null));
62+
}
63+
64+
private List<Config> LoadConfigsFromJson()
65+
{
66+
var assembly = Assembly.GetExecutingAssembly();
67+
var resourceName = "Microsoft.ML.AutoML.Tuner.Portfolios.json";
68+
69+
using (Stream stream = assembly.GetManifestResourceStream(resourceName))
70+
using (StreamReader reader = new StreamReader(stream))
71+
{
72+
var json = reader.ReadToEnd();
73+
var res = JsonSerializer.Deserialize<List<Config>>(json);
74+
75+
return res;
76+
}
77+
}
78+
79+
public Parameter Propose(TrialSettings settings)
80+
{
81+
if (_configsEnumerator.MoveNext())
82+
{
83+
var config = _configsEnumerator.Current;
84+
IEnumerable<KeyValuePair<string, string>> pipelineSchemas = default;
85+
if (_pipelineStrings.Any(kv => kv.Value.Contains("OneHotHashEncoding") || kv.Value.Contains("OneHotEncoding")))
86+
{
87+
pipelineSchemas = _pipelineStrings.Where(kv => kv.Value.Contains(config.CatalogTransformer));
88+
}
89+
else
90+
{
91+
pipelineSchemas = _pipelineStrings;
92+
}
93+
94+
pipelineSchemas = pipelineSchemas.Where(kv => kv.Value.Contains(config.Trainer));
95+
var pipelineSchema = pipelineSchemas.First().Key;
96+
var pipeline = _sweepablePipeline.BuildSweepableEstimatorPipeline(pipelineSchema);
97+
var parameter = pipeline.SearchSpace.SampleFromFeatureSpace(pipeline.SearchSpace.Default);
98+
var trainerEstimatorName = pipeline.Estimators.Where(kv => kv.Value.EstimatorType.IsTrainer()).First().Key;
99+
var label = parameter[trainerEstimatorName]["LabelColumnName"].AsType<string>();
100+
var feature = parameter[trainerEstimatorName]["FeatureColumnName"].AsType<string>();
101+
parameter[trainerEstimatorName] = config.TrainerParameter;
102+
parameter[trainerEstimatorName]["LabelColumnName"] = Parameter.FromString(label);
103+
parameter[trainerEstimatorName]["FeatureColumnName"] = Parameter.FromString(feature);
104+
settings.Parameter[AutoMLExperiment.PipelineSearchspaceName] = parameter;
105+
_configLookBook[settings.TrialId] = config;
106+
return settings.Parameter;
107+
}
108+
109+
throw new OperationCanceledException();
110+
}
111+
112+
public void Update(TrialResult result)
113+
{
114+
}
115+
116+
class Config
117+
{
118+
/// <summary>
119+
/// one of OneHot, HashEncoding
120+
/// </summary>
121+
public string CatalogTransformer { get; set; }
122+
123+
/// <summary>
124+
/// One of Lgbm, Sdca, FastTree,,,
125+
/// </summary>
126+
public string Trainer { get; set; }
127+
128+
public Parameter TrainerParameter { get; set; }
129+
130+
public string Task { get; set; }
131+
}
132+
133+
class Rows
134+
{
135+
public string CustomDimensionsBestPipeline { get; set; }
136+
137+
public string CustomDimensionsOptionsTask { get; set; }
138+
139+
public Parameter CustomDimensionsParameter { get; set; }
140+
}
141+
}
142+
}

0 commit comments

Comments
 (0)