Skip to content

Commit fe503d3

Browse files
authored
Initial work for multi-class classification support for CLI (dotnet#226)
* Initial work for multi-class classification support for CLI
* String updates
* more strings
* Whitelist non-OVA multi-class learners
1 parent 3ad0798 commit fe503d3

18 files changed

+689
-738
lines changed

src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Foundation under one or more agreements.
1+
// Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Foundation under one or more agreements.
1+
// Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

@@ -98,7 +98,7 @@ internal IEnumerable<RunResult<MultiClassClassifierMetrics>> Execute(MLContext c
9898
columnInfo, validationData, preFeaturizers, new OptimizingMetricInfo(_settings.OptimizingMetric),
9999
_settings.ProgressHandler, _settings, new MultiMetricsAgent(_settings.OptimizingMetric),
100100
TrainerExtensionUtil.GetTrainerNames(_settings.Trainers));
101-
101+
102102
return experiment.Execute();
103103
}
104104
}

src/Microsoft.ML.Auto/Experiment/MetricsAgents/MultiMetricsAgent.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

src/Samples/AutoTrainMulticlassClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

src/Test/AutoFitTests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Licensed to the .NET Foundation under one or more agreements.
1+
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedHelperCodeTest.approved.txt

+46
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,42 @@ namespace MyNamespace
9393

9494
}
9595

96+
public static void PrintMulticlassClassificationFoldsAverageMetrics(
97+
string algorithmName,
98+
TrainCatalogBase.CrossValidationResult<MultiClassClassifierMetrics>[] crossValResults)
99+
{
100+
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
101+
102+
var microAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMicro);
103+
var microAccuracyAverage = microAccuracyValues.Average();
104+
var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
105+
var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);
106+
107+
var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMacro);
108+
var macroAccuracyAverage = macroAccuracyValues.Average();
109+
var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
110+
var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);
111+
112+
var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
113+
var logLossAverage = logLossValues.Average();
114+
var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
115+
var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);
116+
117+
var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
118+
var logLossReductionAverage = logLossReductionValues.Average();
119+
var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
120+
var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);
121+
122+
Console.WriteLine($"*************************************************************************************************************");
123+
Console.WriteLine($"* Metrics for {algorithmName} Multi-class Classification model ");
124+
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
125+
Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
126+
Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
127+
Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
128+
Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
129+
Console.WriteLine($"*************************************************************************************************************");
130+
131+
}
96132

97133
public static double CalculateStandardDeviation(IEnumerable<double> values)
98134
{
@@ -108,6 +144,16 @@ namespace MyNamespace
108144
return confidenceInterval95;
109145
}
110146

147+
public static void PrintClusteringMetrics(string name, ClusteringMetrics metrics)
148+
{
149+
Console.WriteLine($"*************************************************");
150+
Console.WriteLine($"* Metrics for {name} clustering model ");
151+
Console.WriteLine($"*------------------------------------------------");
152+
Console.WriteLine($"* AvgMinScore: {metrics.AvgMinScore}");
153+
Console.WriteLine($"* DBI is: {metrics.Dbi}");
154+
Console.WriteLine($"*************************************************");
155+
}
156+
111157
public static void ConsoleWriteHeader(params string[] lines)
112158
{
113159
var defaultColor = Console.ForegroundColor;

src/mlnet/CodeGenerator/CSharp/TrainerGeneratorFactory.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ internal static ITrainerGenerator GetInstance(Pipeline pipeline)
2323
throw new ArgumentNullException(nameof(pipeline));
2424
var node = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Trainer).First();
2525
if (node == null)
26-
return null;
26+
throw new ArgumentException($"The trainer was not found.");
2727
if (Enum.TryParse(node.Name, out TrainerName trainer))
2828
{
2929
switch (trainer)
@@ -67,10 +67,10 @@ internal static ITrainerGenerator GetInstance(Pipeline pipeline)
6767
case TrainerName.SymSgdBinary:
6868
return new SymbolicStochasticGradientDescent(node);
6969
default:
70-
return null;
70+
throw new ArgumentException($"The trainer '{trainer}' is not handled currently.");
7171
}
7272
}
73-
return null;
73+
throw new ArgumentException($"The trainer '{node.Name}' is not handled currently.");
7474
}
7575
}
7676
}

src/mlnet/Commands/CommandDefinitions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ Option HasHeader() =>
103103

104104
private static string[] GetMlTaskSuggestions()
105105
{
106-
return new[] { "binary-classification", "regression" };
106+
return new[] { "binary-classification", "multiclass-classification", "regression" };
107107
}
108108

109109
private static string[] GetVerbositySuggestions()

src/mlnet/Commands/New/NewCommandHandler.cs

+20-2
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,27 @@ internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline p
160160

161161
if (taskKind == TaskKind.MulticlassClassification)
162162
{
163-
throw new NotImplementedException();
163+
var progressReporter = new ProgressHandlers.MulticlassClassificationHandler();
164+
165+
var experimentSettings = new MulticlassExperimentSettings()
166+
{
167+
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
168+
ProgressHandler = progressReporter
169+
};
170+
171+
experimentSettings.Trainers.Clear();
172+
experimentSettings.Trainers.Add(MulticlassClassificationTrainer.LightGbm);
173+
experimentSettings.Trainers.Add(MulticlassClassificationTrainer.LogisticRegression);
174+
experimentSettings.Trainers.Add(MulticlassClassificationTrainer.StochasticDualCoordinateAscent);
175+
176+
var result = context.Auto()
177+
.CreateMulticlassClassificationExperiment(experimentSettings)
178+
.Execute(trainData, validationData, new ColumnInformation() { LabelColumn = labelName });
179+
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
180+
var bestIteration = result.Best();
181+
pipeline = bestIteration.Pipeline;
182+
model = bestIteration.Model;
164183
}
165-
//Multi-class exploration here
166184

167185
return (pipeline, model);
168186
}

src/mlnet/Strings.resx

+3
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@
150150
<data name="MetricsForRegressionModels" xml:space="preserve">
151151
<value>Metrics for regression models</value>
152152
</data>
153+
<data name="MetricsForMulticlassModels" xml:space="preserve">
154+
<value>Metrics for multi-class models</value>
155+
</data>
153156
<data name="RetrieveBestPipeline" xml:space="preserve">
154157
<value>Retrieving best pipeline ...</value>
155158
</data>

0 commit comments

Comments
 (0)