Skip to content

Commit 4a86939

Browse files
abgoswameerhardt
authored andcommitted
PipelineSweeperMacro for Multi-Class Classification (dotnet#539)
* failing test case for multiclass * Refactored PipelineSweeperSupportedMetrics Class; added unit test for MultiClassClassification; refactored out unit tests for the PipelineSweeper * take care of review comments; display transforms/learners + metrics in pipeline * taking care of PR comments + refactor PipelineSweeperRunSummary * taking care of review comments
1 parent 598b05f commit 4a86939

File tree

10 files changed

+895
-703
lines changed

10 files changed

+895
-703
lines changed

src/Microsoft.ML.PipelineInference/AutoInference.cs

Lines changed: 20 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -51,67 +51,6 @@ public class LevelDependencyMap : Dictionary<ColumnInfo, List<TransformInference
5151
/// </summary>
5252
public class DependencyMap : Dictionary<int, LevelDependencyMap> { }
5353

54-
/// <summary>
55-
/// AutoInference will support metrics as they are added here.
56-
/// </summary>
57-
public sealed class SupportedMetric
58-
{
59-
public static readonly SupportedMetric Auc = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Auc, true);
60-
public static readonly SupportedMetric AccuracyMicro = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AccuracyMicro, true);
61-
public static readonly SupportedMetric AccuracyMacro = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AccuracyMacro, true);
62-
public static readonly SupportedMetric L1 = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.L1, false);
63-
public static readonly SupportedMetric L2 = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.L2, false);
64-
public static readonly SupportedMetric F1 = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.F1, true);
65-
public static readonly SupportedMetric AuPrc = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AuPrc, true);
66-
public static readonly SupportedMetric TopKAccuracy = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.TopKAccuracy, true);
67-
public static readonly SupportedMetric Rms = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Rms, false);
68-
public static readonly SupportedMetric LossFn = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.LossFn, false);
69-
public static readonly SupportedMetric RSquared = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.RSquared, false);
70-
public static readonly SupportedMetric LogLoss = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.LogLoss, false);
71-
public static readonly SupportedMetric LogLossReduction = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.LogLossReduction, true);
72-
public static readonly SupportedMetric Ndcg = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Ndcg, true);
73-
public static readonly SupportedMetric Dcg = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Dcg, true);
74-
public static readonly SupportedMetric PositivePrecision = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.PositivePrecision, true);
75-
public static readonly SupportedMetric PositiveRecall = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.PositiveRecall, true);
76-
public static readonly SupportedMetric NegativePrecision = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.NegativePrecision, true);
77-
public static readonly SupportedMetric NegativeRecall = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.NegativeRecall, true);
78-
public static readonly SupportedMetric DrAtK = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.DrAtK, true);
79-
public static readonly SupportedMetric DrAtPFpr = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.DrAtPFpr, true);
80-
public static readonly SupportedMetric DrAtNumPos = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.DrAtNumPos, true);
81-
public static readonly SupportedMetric NumAnomalies = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.NumAnomalies, true);
82-
public static readonly SupportedMetric ThreshAtK = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.ThreshAtK, false);
83-
public static readonly SupportedMetric ThreshAtP = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.ThreshAtP, false);
84-
public static readonly SupportedMetric ThreshAtNumPos = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.ThreshAtNumPos, false);
85-
public static readonly SupportedMetric Nmi = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Nmi, true);
86-
public static readonly SupportedMetric AvgMinScore = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.AvgMinScore, false);
87-
public static readonly SupportedMetric Dbi = new SupportedMetric(FieldNames.PipelineSweeperSupportedMetrics.Dbi, false);
88-
89-
public string Name { get; }
90-
public bool IsMaximizing { get; }
91-
92-
private SupportedMetric(string name, bool isMaximizing)
93-
{
94-
Name = name;
95-
IsMaximizing = isMaximizing;
96-
}
97-
98-
public static SupportedMetric ByName(string name)
99-
{
100-
var fields =
101-
typeof(SupportedMetric).GetFields(BindingFlags.Static | BindingFlags.Public);
102-
103-
foreach (var field in fields)
104-
{
105-
var metric = (SupportedMetric)field.GetValue(Auc);
106-
if (name.Equals(metric.Name, StringComparison.OrdinalIgnoreCase))
107-
return metric;
108-
}
109-
throw new NotSupportedException($"Metric '{name}' not supported.");
110-
}
111-
112-
public override string ToString() => Name;
113-
}
114-
11554
/// <summary>
11655
/// Class for encapsulating an entrypoint experiment graph
11756
/// and keeping track of the input and output nodes.
@@ -167,26 +106,6 @@ private bool GetDataVariableName(IExceptionContext ectx, string nameOfData, JTok
167106
}
168107
}
169108

170-
/// <summary>
171-
/// Class containing some information about an exectuted pipeline.
172-
/// These are analogous to IRunResult for smart sweepers.
173-
/// </summary>
174-
public sealed class RunSummary
175-
{
176-
public double MetricValue { get; }
177-
public double TrainingMetricValue { get; }
178-
public int NumRowsInTraining { get; }
179-
public long RunTimeMilliseconds { get; }
180-
181-
public RunSummary(double metricValue, int numRows, long runTimeMilliseconds, double trainingMetricValue)
182-
{
183-
MetricValue = metricValue;
184-
TrainingMetricValue = trainingMetricValue;
185-
NumRowsInTraining = numRows;
186-
RunTimeMilliseconds = runTimeMilliseconds;
187-
}
188-
}
189-
190109
[TlcModule.ComponentKind("AutoMlStateBase")]
191110
public interface ISupportAutoMlStateFactory : IComponentFactory<IMlState>
192111
{ }
@@ -218,42 +137,8 @@ public sealed class AutoMlMlState : IMlState
218137
Desc = "State of an AutoML search and search space.")]
219138
public sealed class Arguments : ISupportAutoMlStateFactory
220139
{
221-
// REVIEW: These should be the same as SupportedMetrics above. Not sure how to reference that class,
222-
// without the C# API generator trying to create a version of that class in the API as well.
223-
public enum Metrics
224-
{
225-
Auc,
226-
AccuracyMicro,
227-
AccuracyMacro,
228-
L2,
229-
F1,
230-
AuPrc,
231-
TopKAccuracy,
232-
Rms,
233-
LossFn,
234-
RSquared,
235-
LogLoss,
236-
LogLossReduction,
237-
Ndcg,
238-
Dcg,
239-
PositivePrecision,
240-
PositiveRecall,
241-
NegativePrecision,
242-
NegativeRecall,
243-
DrAtK,
244-
DrAtPFpr,
245-
DrAtNumPos,
246-
NumAnomalies,
247-
ThreshAtK,
248-
ThreshAtP,
249-
ThreshAtNumPos,
250-
Nmi,
251-
AvgMinScore,
252-
Dbi
253-
};
254-
255140
[Argument(ArgumentType.Required, HelpText = "Supported metric for evaluator.", ShortName = "metric")]
256-
public Metrics Metric;
141+
public PipelineSweeperSupportedMetrics.Metrics Metric;
257142

258143
[Argument(ArgumentType.Required, HelpText = "AutoML engine (pipeline optimizer) that generates next candidates.", ShortName = "engine")]
259144
public ISupportIPipelineOptimizerFactory Engine;
@@ -271,7 +156,9 @@ public enum Metrics
271156
}
272157

273158
public AutoMlMlState(IHostEnvironment env, Arguments args)
274-
: this(env, SupportedMetric.ByName(Enum.GetName(typeof(Arguments.Metrics), args.Metric)), args.Engine.CreateComponent(env),
159+
: this(env,
160+
PipelineSweeperSupportedMetrics.GetSupportedMetric(args.Metric),
161+
args.Engine.CreateComponent(env),
275162
args.TerminatorArgs.CreateComponent(env), args.TrainerKind, requestedLearners: args.RequestedLearners)
276163
{
277164
}
@@ -355,8 +242,7 @@ private void ProcessPipeline(Sweeper.Algorithms.SweeperProbabilityUtils utils, S
355242
testMetricVal += 1e-10;
356243

357244
// Save performance score
358-
candidate.PerformanceSummary =
359-
new RunSummary(testMetricVal, randomizedNumberOfRows, stopwatch.ElapsedMilliseconds, trainMetricVal);
245+
candidate.PerformanceSummary = new PipelineSweeperRunSummary(testMetricVal, randomizedNumberOfRows, stopwatch.ElapsedMilliseconds, trainMetricVal);
360246
_sortedSampledElements.Add(candidate.PerformanceSummary.MetricValue, candidate);
361247
_history.Add(candidate);
362248
}
@@ -524,6 +410,21 @@ public void AddEvaluated(PipelinePattern pipeline)
524410
d += 1e-3;
525411
_sortedSampledElements.Add(d, pipeline);
526412
_history.Add(pipeline);
413+
414+
using (var ch = _host.Start("Suggested Pipeline"))
415+
{
416+
ch.Info($"PipelineSweeper Iteration Number : {_history.Count}");
417+
ch.Info($"PipelineSweeper Pipeline Id : {pipeline.UniqueId}");
418+
419+
foreach (var transform in pipeline.Transforms)
420+
{
421+
ch.Info($"PipelineSweeper Transform : {transform.Transform}");
422+
}
423+
424+
ch.Info($"PipelineSweeper Learner : {pipeline.Learner}");
425+
ch.Info($"PipelineSweeper Train Metric Value : {pipeline.PerformanceSummary.TrainingMetricValue}");
426+
ch.Info($"PipelineSweeper Test Metric Value : {pipeline.PerformanceSummary.MetricValue}");
427+
}
527428
}
528429

529430
public void AddEvaluated(PipelinePattern[] pipelines)
@@ -541,19 +442,6 @@ public PipelinePattern[] GetNextCandidates(int numberOfCandidates)
541442
currentBatchSize = Math.Min(itr.RemainingIterations(_history), numberOfCandidates);
542443
BatchCandidates = AutoMlEngine.GetNextCandidates(_sortedSampledElements.Select(kvp => kvp.Value), currentBatchSize, _dataRoles);
543444

544-
using (var ch = _host.Start("Suggested Pipeline"))
545-
{
546-
foreach (var pipeline in BatchCandidates)
547-
{
548-
ch.Info($"AutoInference Pipeline Id : {pipeline.UniqueId}");
549-
foreach (var transform in pipeline.Transforms)
550-
{
551-
ch.Info($"AutoInference Transform : {transform.Transform}");
552-
}
553-
ch.Info($"AutoInference Learner : {pipeline.Learner}");
554-
}
555-
}
556-
557445
return BatchCandidates;
558446
}
559447

src/Microsoft.ML.PipelineInference/AutoMlUtils.cs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,15 @@ public static double ExtractValueFromIdv(IHostEnvironment env, IDataView result,
3838
return outputValue;
3939
}
4040

41-
public static AutoInference.RunSummary ExtractRunSummary(IHostEnvironment env, IDataView result, string metricColumnName, IDataView trainResult = null)
41+
public static PipelineSweeperRunSummary ExtractRunSummary(IHostEnvironment env, IDataView result, string metricColumnName, IDataView trainResult = null)
4242
{
43+
Contracts.CheckValue(env, nameof(env));
44+
env.CheckValue(result, nameof(result));
45+
env.CheckNonEmpty(metricColumnName, nameof(metricColumnName));
46+
4347
double testingMetricValue = ExtractValueFromIdv(env, result, metricColumnName);
4448
double trainingMetricValue = trainResult != null ? ExtractValueFromIdv(env, trainResult, metricColumnName) : double.MinValue;
45-
return new AutoInference.RunSummary(testingMetricValue, 0, 0, trainingMetricValue);
49+
return new PipelineSweeperRunSummary(testingMetricValue, 0, 0, trainingMetricValue);
4650
}
4751

4852
public static CommonInputs.IEvaluatorInput CloneEvaluatorInstance(CommonInputs.IEvaluatorInput evalInput) =>
@@ -566,14 +570,15 @@ private static ParameterSet ConvertToParameterSet(TlcModule.SweepableParamAttrib
566570
return learner.PipelineNode.HyperSweeperParamSet;
567571
}
568572

569-
public static IRunResult ConvertToRunResult(RecipeInference.SuggestedRecipe.SuggestedLearner learner,
570-
AutoInference.RunSummary rs, bool isMetricMaximizing) =>
571-
new RunResult(ConvertToParameterSet(learner.PipelineNode.SweepParams, learner), rs.MetricValue, isMetricMaximizing);
572-
573-
public static IRunResult[] ConvertToRunResults(PipelinePattern[] history, bool isMetricMaximizing) =>
574-
history.Select(h =>
575-
ConvertToRunResult(h.Learner, h.PerformanceSummary, isMetricMaximizing)).ToArray();
573+
public static IRunResult ConvertToRunResult(RecipeInference.SuggestedRecipe.SuggestedLearner learner, PipelineSweeperRunSummary rs, bool isMetricMaximizing)
574+
{
575+
return new RunResult(ConvertToParameterSet(learner.PipelineNode.SweepParams, learner), rs.MetricValue, isMetricMaximizing);
576+
}
576577

578+
public static IRunResult[] ConvertToRunResults(PipelinePattern[] history, bool isMetricMaximizing)
579+
{
580+
return history.Select(h => ConvertToRunResult(h.Learner, h.PerformanceSummary, isMetricMaximizing)).ToArray();
581+
}
577582
/// <summary>
578583
/// Method to convert set of sweepable hyperparameters into strings of a format understood
579584
/// by the current smart hyperparameter sweepers.

src/Microsoft.ML.PipelineInference/Macros/PipelineSweeperMacro.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,7 @@ public static CommonOutputs.MacroOutput<Output> PipelineSweep(
239239
if (node.Context.TryGetVariable(ExperimentUtils.GenerateOverallMetricVarName(pipeline.UniqueId), out var v) &&
240240
node.Context.TryGetVariable(AutoMlUtils.GenerateOverallTrainingMetricVarName(pipeline.UniqueId), out var v2))
241241
{
242-
pipeline.PerformanceSummary =
243-
AutoMlUtils.ExtractRunSummary(env, (IDataView)v.Value, autoMlState.Metric.Name, (IDataView)v2.Value);
242+
pipeline.PerformanceSummary = AutoMlUtils.ExtractRunSummary(env, (IDataView)v.Value, autoMlState.Metric.Name, (IDataView)v2.Value);
244243
autoMlState.AddEvaluated(pipeline);
245244
}
246245
}

src/Microsoft.ML.PipelineInference/PipelinePattern.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ public PipelineResultRow(string graphJson, double metricValue,
5555
private readonly IHostEnvironment _env;
5656
public readonly TransformInference.SuggestedTransform[] Transforms;
5757
public readonly RecipeInference.SuggestedRecipe.SuggestedLearner Learner;
58-
public AutoInference.RunSummary PerformanceSummary { get; set; }
58+
public PipelineSweeperRunSummary PerformanceSummary { get; set; }
5959
public string LoaderSettings { get; set; }
6060
public Guid UniqueId { get; }
6161

6262
public PipelinePattern(TransformInference.SuggestedTransform[] transforms,
6363
RecipeInference.SuggestedRecipe.SuggestedLearner learner,
64-
string loaderSettings, IHostEnvironment env, AutoInference.RunSummary summary = null)
64+
string loaderSettings, IHostEnvironment env, PipelineSweeperRunSummary summary = null)
6565
{
6666
// Make sure internal pipeline nodes and sweep params are cloned, not shared.
6767
// Cloning the transforms and learner rather than assigning outright
@@ -205,7 +205,7 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
205205
/// Runs a train-test experiment on the current pipeline, through entrypoints.
206206
/// </summary>
207207
public void RunTrainTestExperiment(IDataView trainData, IDataView testData,
208-
AutoInference.SupportedMetric metric, MacroUtils.TrainerKinds trainerKind, out double testMetricValue,
208+
SupportedMetric metric, MacroUtils.TrainerKinds trainerKind, out double testMetricValue,
209209
out double trainMetricValue)
210210
{
211211
var experiment = CreateTrainTestExperiment(trainData, testData, trainerKind, true, out var trainTestOutput);
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Diagnostics;
8+
using System.Linq;
9+
using System.Reflection;
10+
using Microsoft.ML.Runtime.CommandLine;
11+
using Microsoft.ML.Runtime.EntryPoints;
12+
using Microsoft.ML.Runtime.Data;
13+
using Microsoft.ML.Runtime.PipelineInference;
14+
using Microsoft.ML.Runtime.EntryPoints.JsonUtils;
15+
using Newtonsoft.Json.Linq;
16+
17+
namespace Microsoft.ML.Runtime.PipelineInference
18+
{
19+
/// <summary>
20+
/// Class containing some information about an exectuted pipeline.
21+
/// These are analogous to IRunResult for smart sweepers.
22+
/// </summary>
23+
public sealed class PipelineSweeperRunSummary
24+
{
25+
public double MetricValue { get; }
26+
public double TrainingMetricValue { get; }
27+
public int NumRowsInTraining { get; }
28+
public long RunTimeMilliseconds { get; }
29+
30+
public PipelineSweeperRunSummary(double metricValue, int numRows, long runTimeMilliseconds, double trainingMetricValue)
31+
{
32+
MetricValue = metricValue;
33+
TrainingMetricValue = trainingMetricValue;
34+
NumRowsInTraining = numRows;
35+
RunTimeMilliseconds = runTimeMilliseconds;
36+
}
37+
}
38+
}

0 commit comments

Comments
 (0)