Skip to content

Commit 6162944

Browse files
srsaggamDmitry-A
authored andcommitted
Console helper bug in generated code for multiclass (dotnet#323)
* fix * fix test * looping perlogclass * fix test
1 parent 14f9d17 commit 6162944

File tree

3 files changed

+59
-59
lines changed

3 files changed

+59
-59
lines changed

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleHelperFileContentTest.approved.txt

+4-3
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ namespace TestNamespace.Train
8484
Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better");
8585
Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better");
8686
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
87-
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
88-
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
89-
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
87+
for (int i = 0; i < metrics.PerClassLogLoss.Length; i++)
88+
{
89+
Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
90+
}
9091
Console.WriteLine($"************************************************************");
9192
}
9293

src/mlnet/Templates/Console/ConsoleHelper.cs

+51-53
Original file line numberDiff line numberDiff line change
@@ -104,61 +104,59 @@ namespace ");
104104
"ole.WriteLine($\" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value betw" +
105105
"een 0 and 1, the closer to 1, the better\");\r\n Console.WriteLine($\" " +
106106
" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better\");\r\n " +
107-
" Console.WriteLine($\" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.###" +
108-
"#}, the closer to 0, the better\");\r\n Console.WriteLine($\" LogLoss " +
109-
"for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better\")" +
110-
";\r\n Console.WriteLine($\" LogLoss for class 3 = {metrics.PerClassLo" +
111-
"gLoss[2]:0.####}, the closer to 0, the better\");\r\n Console.WriteLine(" +
112-
"$\"************************************************************\");\r\n }\r\n\r\n" +
113-
" public static void PrintMulticlassClassificationFoldsAverageMetrics(Trai" +
114-
"nCatalogBase.CrossValidationResult<MultiClassClassifierMetrics>[] crossValResult" +
115-
"s)\r\n {\r\n var metricsInMultipleFolds = crossValResults.Select(r" +
116-
" => r.Metrics);\r\n\r\n var microAccuracyValues = metricsInMultipleFolds." +
117-
"Select(m => m.AccuracyMicro);\r\n var microAccuracyAverage = microAccur" +
118-
"acyValues.Average();\r\n var microAccuraciesStdDeviation = CalculateSta" +
119-
"ndardDeviation(microAccuracyValues);\r\n var microAccuraciesConfidenceI" +
120-
"nterval95 = CalculateConfidenceInterval95(microAccuracyValues);\r\n\r\n v" +
121-
"ar macroAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMacro);\r\n " +
122-
" var macroAccuracyAverage = macroAccuracyValues.Average();\r\n " +
123-
" var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValu" +
124-
"es);\r\n var macroAccuraciesConfidenceInterval95 = CalculateConfidenceI" +
125-
"nterval95(macroAccuracyValues);\r\n\r\n var logLossValues = metricsInMult" +
126-
"ipleFolds.Select(m => m.LogLoss);\r\n var logLossAverage = logLossValue" +
127-
"s.Average();\r\n var logLossStdDeviation = CalculateStandardDeviation(l" +
128-
"ogLossValues);\r\n var logLossConfidenceInterval95 = CalculateConfidenc" +
129-
"eInterval95(logLossValues);\r\n\r\n var logLossReductionValues = metricsI" +
130-
"nMultipleFolds.Select(m => m.LogLossReduction);\r\n var logLossReductio" +
131-
"nAverage = logLossReductionValues.Average();\r\n var logLossReductionSt" +
132-
"dDeviation = CalculateStandardDeviation(logLossReductionValues);\r\n va" +
133-
"r logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossRe" +
134-
"ductionValues);\r\n\r\n Console.WriteLine($\"*****************************" +
107+
" for (int i = 0; i < metrics.PerClassLogLoss.Length; i++)\r\n {\r\n " +
108+
" Console.WriteLine($\" LogLoss for class {i + 1} = {metrics.PerClassL" +
109+
"ogLoss[i]:0.####}, the closer to 0, the better\");\r\n }\r\n Co" +
110+
"nsole.WriteLine($\"************************************************************\")" +
111+
";\r\n }\r\n\r\n public static void PrintMulticlassClassificationFoldsAve" +
112+
"rageMetrics(TrainCatalogBase.CrossValidationResult<MultiClassClassifierMetrics>[" +
113+
"] crossValResults)\r\n {\r\n var metricsInMultipleFolds = crossVal" +
114+
"Results.Select(r => r.Metrics);\r\n\r\n var microAccuracyValues = metrics" +
115+
"InMultipleFolds.Select(m => m.AccuracyMicro);\r\n var microAccuracyAver" +
116+
"age = microAccuracyValues.Average();\r\n var microAccuraciesStdDeviatio" +
117+
"n = CalculateStandardDeviation(microAccuracyValues);\r\n var microAccur" +
118+
"aciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);\r" +
119+
"\n\r\n var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.Ac" +
120+
"curacyMacro);\r\n var macroAccuracyAverage = macroAccuracyValues.Averag" +
121+
"e();\r\n var macroAccuraciesStdDeviation = CalculateStandardDeviation(m" +
122+
"acroAccuracyValues);\r\n var macroAccuraciesConfidenceInterval95 = Calc" +
123+
"ulateConfidenceInterval95(macroAccuracyValues);\r\n\r\n var logLossValues" +
124+
" = metricsInMultipleFolds.Select(m => m.LogLoss);\r\n var logLossAverag" +
125+
"e = logLossValues.Average();\r\n var logLossStdDeviation = CalculateSta" +
126+
"ndardDeviation(logLossValues);\r\n var logLossConfidenceInterval95 = Ca" +
127+
"lculateConfidenceInterval95(logLossValues);\r\n\r\n var logLossReductionV" +
128+
"alues = metricsInMultipleFolds.Select(m => m.LogLossReduction);\r\n var" +
129+
" logLossReductionAverage = logLossReductionValues.Average();\r\n var lo" +
130+
"gLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);" +
131+
"\r\n var logLossReductionConfidenceInterval95 = CalculateConfidenceInte" +
132+
"rval95(logLossReductionValues);\r\n\r\n Console.WriteLine($\"*************" +
135133
"********************************************************************************" +
136-
"\");\r\n Console.WriteLine($\"* Metrics for Multi-class Classificat" +
137-
"ion model \");\r\n Console.WriteLine($\"*---------------------------" +
134+
"****************\");\r\n Console.WriteLine($\"* Metrics for Multi-c" +
135+
"lass Classification model \");\r\n Console.WriteLine($\"*-----------" +
138136
"--------------------------------------------------------------------------------" +
139-
"-\");\r\n Console.WriteLine($\"* Average MicroAccuracy: {microAc" +
140-
"curacyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}" +
141-
") - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})\");\r\n" +
142-
" Console.WriteLine($\"* Average MacroAccuracy: {macroAccuracy" +
143-
"Average:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - C" +
144-
"onfidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})\");\r\n " +
145-
" Console.WriteLine($\"* Average LogLoss: {logLossAverage:#.##" +
146-
"#} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 9" +
147-
"5%: ({logLossConfidenceInterval95:#.###})\");\r\n Console.WriteLine($\"* " +
148-
" Average LogLossReduction: {logLossReductionAverage:#.###} - Standard devi" +
149-
"ation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logL" +
150-
"ossReductionConfidenceInterval95:#.###})\");\r\n Console.WriteLine($\"***" +
151-
"********************************************************************************" +
152-
"**************************\");\r\n\r\n }\r\n\r\n public static double Calcu" +
153-
"lateStandardDeviation(IEnumerable<double> values)\r\n {\r\n double" +
154-
" average = values.Average();\r\n double sumOfSquaresOfDifferences = val" +
155-
"ues.Select(val => (val - average) * (val - average)).Sum();\r\n double " +
156-
"standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1));" +
157-
"\r\n return standardDeviation;\r\n }\r\n\r\n public static doub" +
158-
"le CalculateConfidenceInterval95(IEnumerable<double> values)\r\n {\r\n " +
159-
" double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / M" +
160-
"ath.Sqrt((values.Count() - 1));\r\n return confidenceInterval95;\r\n " +
161-
" }\r\n }\r\n}\r\n");
137+
"-----------------\");\r\n Console.WriteLine($\"* Average MicroAccur" +
138+
"acy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStd" +
139+
"Deviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterva" +
140+
"l95:#.###})\");\r\n Console.WriteLine($\"* Average MacroAccuracy: " +
141+
" {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviat" +
142+
"ion:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#." +
143+
"###})\");\r\n Console.WriteLine($\"* Average LogLoss: {log" +
144+
"LossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confi" +
145+
"dence Interval 95%: ({logLossConfidenceInterval95:#.###})\");\r\n Consol" +
146+
"e.WriteLine($\"* Average LogLossReduction: {logLossReductionAverage:#.###} " +
147+
" - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Inte" +
148+
"rval 95%: ({logLossReductionConfidenceInterval95:#.###})\");\r\n Console" +
149+
".WriteLine($\"*******************************************************************" +
150+
"******************************************\");\r\n\r\n }\r\n\r\n public sta" +
151+
"tic double CalculateStandardDeviation(IEnumerable<double> values)\r\n {\r\n " +
152+
" double average = values.Average();\r\n double sumOfSquaresOfD" +
153+
"ifferences = values.Select(val => (val - average) * (val - average)).Sum();\r\n " +
154+
" double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (value" +
155+
"s.Count() - 1));\r\n return standardDeviation;\r\n }\r\n\r\n pu" +
156+
"blic static double CalculateConfidenceInterval95(IEnumerable<double> values)\r\n " +
157+
" {\r\n double confidenceInterval95 = 1.96 * CalculateStandardDevia" +
158+
"tion(values) / Math.Sqrt((values.Count() - 1));\r\n return confidenceIn" +
159+
"terval95;\r\n }\r\n }\r\n}\r\n");
162160
return this.GenerationEnvironment.ToString();
163161
}
164162

src/mlnet/Templates/Console/ConsoleHelper.tt

+4-3
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ namespace <#= Namespace #>.Train
8989
Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better");
9090
Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better");
9191
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
92-
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
93-
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
94-
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
92+
for (int i = 0; i < metrics.PerClassLogLoss.Length; i++)
93+
{
94+
Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
95+
}
9596
Console.WriteLine($"************************************************************");
9697
}
9798

0 commit comments

Comments
 (0)