Skip to content

Commit fcdb09a

Browse files
srsaggamDmitry-A
authored andcommitted
remove unused methods in consolehelper and nit picks in generated code (dotnet#261)
* nit picks * change in console helper * fix tests * add space * fix tests
1 parent 125237a commit fcdb09a

6 files changed

+147
-485
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
using System;
1+
//*****************************************************************************************
2+
//* *
3+
//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. *
4+
//* *
5+
//*****************************************************************************************
6+
7+
using System;
28
using System.Collections.Generic;
39
using System.Linq;
4-
using Microsoft.Data.DataView;
5-
using Microsoft.ML.Core.Data;
10+
using Microsoft.ML;
611
using Microsoft.ML.Data;
712

813
namespace MyNamespace
@@ -47,32 +52,15 @@ namespace MyNamespace
4752
Console.WriteLine($"************************************************************");
4853
}
4954

50-
public static void PrintMultiClassClassificationMetrics(string name, MultiClassClassifierMetrics metrics)
51-
{
52-
Console.WriteLine($"************************************************************");
53-
Console.WriteLine($"* Metrics for {name} multi-class classification model ");
54-
Console.WriteLine($"*-----------------------------------------------------------");
55-
Console.WriteLine($" AccuracyMacro = {metrics.AccuracyMacro:0.####}, a value between 0 and 1, the closer to 1, the better");
56-
Console.WriteLine($" AccuracyMicro = {metrics.AccuracyMicro:0.####}, a value between 0 and 1, the closer to 1, the better");
57-
Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
58-
Console.WriteLine($" LogLoss for class 1 = {metrics.PerClassLogLoss[0]:0.####}, the closer to 0, the better");
59-
Console.WriteLine($" LogLoss for class 2 = {metrics.PerClassLogLoss[1]:0.####}, the closer to 0, the better");
60-
Console.WriteLine($" LogLoss for class 3 = {metrics.PerClassLogLoss[2]:0.####}, the closer to 0, the better");
61-
Console.WriteLine($"************************************************************");
62-
}
63-
64-
6555
public static void PrintRegressionFoldsAverageMetrics(string algorithmName,
66-
(RegressionMetrics metrics,
67-
ITransformer model,
68-
IDataView scoredTestData)[] crossValidationResults
56+
TrainCatalogBase.CrossValidationResult<RegressionMetrics>[] crossValidationResults
6957
)
7058
{
71-
var L1 = crossValidationResults.Select(r => r.metrics.L1);
72-
var L2 = crossValidationResults.Select(r => r.metrics.L2);
73-
var RMS = crossValidationResults.Select(r => r.metrics.L1);
74-
var lossFunction = crossValidationResults.Select(r => r.metrics.LossFn);
75-
var R2 = crossValidationResults.Select(r => r.metrics.RSquared);
59+
var L1 = crossValidationResults.Select(r => r.Metrics.L1);
60+
var L2 = crossValidationResults.Select(r => r.Metrics.L2);
61+
var RMS = crossValidationResults.Select(r => r.Metrics.L1);
62+
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFn);
63+
var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);
7664

7765
Console.WriteLine($"*************************************************************************************************************");
7866
Console.WriteLine($"* Metrics for {algorithmName} Regression model ");
@@ -87,12 +75,9 @@ namespace MyNamespace
8775

8876
public static void PrintBinaryClassificationFoldsAverageMetrics(
8977
string algorithmName,
90-
(BinaryClassificationMetrics metrics,
91-
ITransformer model,
92-
IDataView scoredTestData)[] crossValResults
93-
)
78+
TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>[] crossValResults)
9479
{
95-
var metricsInMultipleFolds = crossValResults.Select(r => r.metrics);
80+
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
9681

9782
var AccuracyValues = metricsInMultipleFolds.Select(m => m.Accuracy);
9883
var AccuracyAverage = AccuracyValues.Average();
@@ -108,45 +93,6 @@ namespace MyNamespace
10893

10994
}
11095

111-
public static void PrintMulticlassClassificationFoldsAverageMetrics(
112-
string algorithmName,
113-
(MultiClassClassifierMetrics metrics,
114-
ITransformer model,
115-
IDataView scoredTestData)[] crossValResults
116-
)
117-
{
118-
var metricsInMultipleFolds = crossValResults.Select(r => r.metrics);
119-
120-
var microAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMicro);
121-
var microAccuracyAverage = microAccuracyValues.Average();
122-
var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
123-
var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);
124-
125-
var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.AccuracyMacro);
126-
var macroAccuracyAverage = macroAccuracyValues.Average();
127-
var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
128-
var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);
129-
130-
var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
131-
var logLossAverage = logLossValues.Average();
132-
var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
133-
var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);
134-
135-
var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
136-
var logLossReductionAverage = logLossReductionValues.Average();
137-
var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
138-
var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);
139-
140-
Console.WriteLine($"*************************************************************************************************************");
141-
Console.WriteLine($"* Metrics for {algorithmName} Multi-class Classification model ");
142-
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
143-
Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
144-
Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
145-
Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
146-
Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
147-
Console.WriteLine($"*************************************************************************************************************");
148-
149-
}
15096

15197
public static double CalculateStandardDeviation(IEnumerable<double> values)
15298
{
@@ -162,16 +108,6 @@ namespace MyNamespace
162108
return confidenceInterval95;
163109
}
164110

165-
public static void PrintClusteringMetrics(string name, ClusteringMetrics metrics)
166-
{
167-
Console.WriteLine($"*************************************************");
168-
Console.WriteLine($"* Metrics for {name} clustering model ");
169-
Console.WriteLine($"*------------------------------------------------");
170-
Console.WriteLine($"* AvgMinScore: {metrics.AvgMinScore}");
171-
Console.WriteLine($"* DBI is: {metrics.Dbi}");
172-
Console.WriteLine($"*************************************************");
173-
}
174-
175111
public static void ConsoleWriteHeader(params string[] lines)
176112
{
177113
var defaultColor = Console.ForegroundColor;
@@ -185,59 +121,5 @@ namespace MyNamespace
185121
Console.WriteLine(new string('#', maxLength));
186122
Console.ForegroundColor = defaultColor;
187123
}
188-
189-
public static void ConsoleWriterSection(params string[] lines)
190-
{
191-
var defaultColor = Console.ForegroundColor;
192-
Console.ForegroundColor = ConsoleColor.Blue;
193-
Console.WriteLine(" ");
194-
foreach (var line in lines)
195-
{
196-
Console.WriteLine(line);
197-
}
198-
var maxLength = lines.Select(x => x.Length).Max();
199-
Console.WriteLine(new string('-', maxLength));
200-
Console.ForegroundColor = defaultColor;
201-
}
202-
203-
public static void ConsolePressAnyKey()
204-
{
205-
var defaultColor = Console.ForegroundColor;
206-
Console.ForegroundColor = ConsoleColor.Green;
207-
Console.WriteLine(" ");
208-
Console.WriteLine("Press any key to finish.");
209-
Console.ReadKey();
210-
}
211-
212-
public static void ConsoleWriteException(params string[] lines)
213-
{
214-
var defaultColor = Console.ForegroundColor;
215-
Console.ForegroundColor = ConsoleColor.Red;
216-
const string exceptionTitle = "EXCEPTION";
217-
Console.WriteLine(" ");
218-
Console.WriteLine(exceptionTitle);
219-
Console.WriteLine(new string('#', exceptionTitle.Length));
220-
Console.ForegroundColor = defaultColor;
221-
foreach (var line in lines)
222-
{
223-
Console.WriteLine(line);
224-
}
225-
}
226-
227-
public static void ConsoleWriteWarning(params string[] lines)
228-
{
229-
var defaultColor = Console.ForegroundColor;
230-
Console.ForegroundColor = ConsoleColor.DarkMagenta;
231-
const string warningTitle = "WARNING";
232-
Console.WriteLine(" ");
233-
Console.WriteLine(warningTitle);
234-
Console.WriteLine(new string('#', warningTitle.Length));
235-
Console.ForegroundColor = defaultColor;
236-
foreach (var line in lines)
237-
{
238-
Console.WriteLine(line);
239-
}
240-
}
241-
242124
}
243125
}

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedTrainCodeTest.approved.txt

+11-21
Original file line numberDiff line numberDiff line change
@@ -21,37 +21,26 @@ namespace MyNamespace
2121
private static string TestDataPath = @"x:\dummypath\dummy_test.csv";
2222
private static string ModelPath = @"x:\models\model.zip";
2323

24-
// Set this flag to enable the training process.
25-
private static bool EnableTraining = false;
26-
2724
static void Main(string[] args)
2825
{
2926
// Create MLContext to be shared across the model creation workflow objects
30-
// Set a random seed for repeatable/deterministic results across multiple trainings.
31-
var mlContext = new MLContext(seed: 1);
27+
var mlContext = new MLContext();
3228

33-
if (EnableTraining)
34-
{
35-
// Create, Train, Evaluate and Save a model
36-
BuildTrainEvaluateAndSaveModel(mlContext);
37-
ConsoleHelper.ConsoleWriteHeader("=============== End of training process ===============");
38-
}
39-
else
40-
{
41-
ConsoleHelper.ConsoleWriteHeader("Skipping the training process. Please set the flag : 'EnableTraining' to 'true' to enable the training process.");
42-
}
29+
// (Optional step) Create, Train, Evaluate and Save the model.zip file
30+
TrainEvaluateAndSaveModel(mlContext);
4331

44-
// Make a single test prediction loading the model from .ZIP file
45-
TestSinglePrediction(mlContext);
32+
// Make a single test prediction loading the model from model.zip file
33+
Predict(mlContext);
4634

4735
ConsoleHelper.ConsoleWriteHeader("=============== End of process, hit any key to finish ===============");
4836
Console.ReadKey();
4937

5038
}
5139

52-
private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
40+
private static ITransformer TrainEvaluateAndSaveModel(MLContext mlContext)
5341
{
54-
// Data loading
42+
// Load data
43+
Console.WriteLine("=============== Loading data ===============");
5544
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
5645
path: TrainDataPath,
5746
hasHeader: true,
@@ -88,12 +77,13 @@ namespace MyNamespace
8877
mlContext.Model.Save(trainedModel, fs);
8978

9079
Console.WriteLine("The model is saved to {0}", ModelPath);
80+
ConsoleHelper.ConsoleWriteHeader("=============== End of training process ===============");
9181

9282
return trainedModel;
9383
}
9484

95-
// (OPTIONAL) Try/test a single prediction by loading the model from the file, first.
96-
private static void TestSinglePrediction(MLContext mlContext)
85+
// Try/test a single prediction by loading the model from the file, first.
86+
private static void Predict(MLContext mlContext)
9787
{
9888
//Load data to test. Could be any test data. For demonstration purpose train data is used here.
9989
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(

0 commit comments

Comments
 (0)