Skip to content

Commit c5ec302

Browse files
authored
Change in project structure (dotnet#385)
* initial changes * Change in project structure * correcting test * change variable name * fix tests * fix tests * fix more tests * fix codegen errors * added log file message * changed name of args * change variable names * fix test
1 parent e257648 commit c5ec302

28 files changed

+752
-2088
lines changed

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.TrainProgramCSFileContentTest.approved.txt renamed to src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt

+52-11
Original file line number | Diff line number | Diff line change
@@ -5,25 +5,27 @@
55
//*****************************************************************************************
66

77
using System;
8+
using System.Collections.Generic;
89
using System.IO;
910
using System.Linq;
1011
using Microsoft.ML;
12+
using Microsoft.ML.Data;
1113
using TestNamespace.Model.DataModels;
1214

13-
namespace TestNamespace.Train
15+
namespace TestNamespace.ConsoleApp
1416
{
15-
class Program
17+
public static class ModelBuilder
1618
{
1719
private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv";
1820
private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv";
1921
private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip";
2022

21-
static void Main(string[] args)
22-
{
23-
// Create MLContext to be shared across the model creation workflow objects
24-
// Set a random seed for repeatable/deterministic results across multiple trainings.
25-
MLContext mlContext = new MLContext(seed: 1);
23+
// Create MLContext to be shared across the model creation workflow objects
24+
// Set a random seed for repeatable/deterministic results across multiple trainings.
25+
private static MLContext mlContext = new MLContext(seed: 1);
2626

27+
public static void CreateModel()
28+
{
2729
// Load Data
2830
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
2931
path: TRAIN_DATA_FILEPATH,
@@ -49,9 +51,6 @@ namespace TestNamespace.Train
4951

5052
// Save model
5153
SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
52-
53-
Console.WriteLine("=============== End of process, hit any key to finish ===============");
54-
Console.ReadKey();
5554
}
5655

5756
public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
@@ -83,7 +82,7 @@ namespace TestNamespace.Train
8382
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
8483
IDataView predictions = mlModel.Transform(testDataView);
8584
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(predictions, "Label", "Score");
86-
ConsoleHelper.PrintBinaryClassificationMetrics(metrics);
85+
PrintBinaryClassificationMetrics(metrics);
8786
}
8887
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
8988
{
@@ -104,5 +103,47 @@ namespace TestNamespace.Train
104103

105104
return fullPath;
106105
}
106+
107+
public static void PrintBinaryClassificationMetrics(BinaryClassificationMetrics metrics)
108+
{
109+
Console.WriteLine($"************************************************************");
110+
Console.WriteLine($"* Metrics for binary classification model ");
111+
Console.WriteLine($"*-----------------------------------------------------------");
112+
Console.WriteLine($"* Accuracy: {metrics.Accuracy:P2}");
113+
Console.WriteLine($"* Auc: {metrics.AreaUnderRocCurve:P2}");
114+
Console.WriteLine($"************************************************************");
115+
}
116+
117+
118+
public static void PrintBinaryClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>> crossValResults)
119+
{
120+
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
121+
122+
var AccuracyValues = metricsInMultipleFolds.Select(m => m.Accuracy);
123+
var AccuracyAverage = AccuracyValues.Average();
124+
var AccuraciesStdDeviation = CalculateStandardDeviation(AccuracyValues);
125+
var AccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(AccuracyValues);
126+
127+
128+
Console.WriteLine($"*************************************************************************************************************");
129+
Console.WriteLine($"* Metrics for Binary Classification model ");
130+
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
131+
Console.WriteLine($"* Average Accuracy: {AccuracyAverage:0.###} - Standard deviation: ({AccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({AccuraciesConfidenceInterval95:#.###})");
132+
Console.WriteLine($"*************************************************************************************************************");
133+
}
134+
135+
public static double CalculateStandardDeviation(IEnumerable<double> values)
136+
{
137+
double average = values.Average();
138+
double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum();
139+
double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1));
140+
return standardDeviation;
141+
}
142+
143+
public static double CalculateConfidenceInterval95(IEnumerable<double> values)
144+
{
145+
double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1));
146+
return confidenceInterval95;
147+
}
107148
}
108149
}

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleHelperFileContentTest.approved.txt renamed to src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt

+74-46
Original file line number | Diff line number | Diff line change
@@ -6,74 +6,102 @@
66

77
using System;
88
using System.Collections.Generic;
9+
using System.IO;
910
using System.Linq;
1011
using Microsoft.ML;
1112
using Microsoft.ML.Data;
13+
using TestNamespace.Model.DataModels;
1214

13-
namespace TestNamespace.Train
15+
namespace TestNamespace.ConsoleApp
1416
{
15-
public static class ConsoleHelper
17+
public static class ModelBuilder
1618
{
19+
private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv";
20+
private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv";
21+
private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip";
1722

18-
public static void PrintRegressionMetrics(RegressionMetrics metrics)
23+
// Create MLContext to be shared across the model creation workflow objects
24+
// Set a random seed for repeatable/deterministic results across multiple trainings.
25+
private static MLContext mlContext = new MLContext(seed: 1);
26+
27+
public static void CreateModel()
1928
{
20-
Console.WriteLine($"*************************************************");
21-
Console.WriteLine($"* Metrics for regression model ");
22-
Console.WriteLine($"*------------------------------------------------");
23-
Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}");
24-
Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
25-
Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}");
26-
Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}");
27-
Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}");
28-
Console.WriteLine($"*************************************************");
29+
// Load Data
30+
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
31+
path: TRAIN_DATA_FILEPATH,
32+
hasHeader: true,
33+
separatorChar: ',',
34+
allowQuoting: true,
35+
allowSparse: true);
36+
37+
IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
38+
path: TEST_DATA_FILEPATH,
39+
hasHeader: true,
40+
separatorChar: ',',
41+
allowQuoting: true,
42+
allowSparse: true);
43+
// Build training pipeline
44+
IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);
45+
46+
// Train Model
47+
ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);
48+
49+
// Evaluate quality of Model
50+
EvaluateModel(mlContext, mlModel, testDataView);
51+
52+
// Save model
53+
SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
2954
}
3055

31-
public static void PrintRegressionFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<RegressionMetrics>> crossValidationResults)
56+
public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
3257
{
33-
var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
34-
var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError);
35-
var RMS = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError);
36-
var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction);
37-
var R2 = crossValidationResults.Select(r => r.Metrics.RSquared);
58+
// Data process configuration with pipeline data transformations
59+
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" })
60+
.AppendCacheCheckpoint(mlContext);
3861

39-
Console.WriteLine($"*************************************************************************************************************");
40-
Console.WriteLine($"* Metrics for Regression model ");
41-
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
42-
Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} ");
43-
Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} ");
44-
Console.WriteLine($"* Average RMS: {RMS.Average():0.###} ");
45-
Console.WriteLine($"* Average Loss Function: {lossFunction.Average():0.###} ");
46-
Console.WriteLine($"* Average R-squared: {R2.Average():0.###} ");
47-
Console.WriteLine($"*************************************************************************************************************");
62+
// Set the training algorithm
63+
var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastForest(labelColumnName: "Label", featureColumnName: "Features"), labelColumnName: "Label");
64+
var trainingPipeline = dataProcessPipeline.Append(trainer);
65+
66+
return trainingPipeline;
4867
}
4968

50-
public static void PrintBinaryClassificationMetrics(BinaryClassificationMetrics metrics)
69+
public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
5170
{
52-
Console.WriteLine($"************************************************************");
53-
Console.WriteLine($"* Metrics for binary classification model ");
54-
Console.WriteLine($"*-----------------------------------------------------------");
55-
Console.WriteLine($"* Accuracy: {metrics.Accuracy:P2}");
56-
Console.WriteLine($"* Auc: {metrics.AreaUnderRocCurve:P2}");
57-
Console.WriteLine($"************************************************************");
58-
}
71+
Console.WriteLine("=============== Training model ===============");
5972

73+
ITransformer model = trainingPipeline.Fit(trainingDataView);
6074

61-
public static void PrintBinaryClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<BinaryClassificationMetrics>> crossValResults)
75+
Console.WriteLine("=============== End of training process ===============");
76+
return model;
77+
}
78+
79+
private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView)
6280
{
63-
var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);
81+
// Evaluate the model and show accuracy stats
82+
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
83+
IDataView predictions = mlModel.Transform(testDataView);
84+
var metrics = mlContext.MulticlassClassification.Evaluate(predictions, "Label", "Score");
85+
PrintMulticlassClassificationMetrics(metrics);
86+
}
87+
private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
88+
{
89+
// Save/persist the trained model to a .ZIP file
90+
Console.WriteLine($"=============== Saving the model ===============");
91+
using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write))
92+
mlContext.Model.Save(mlModel, modelInputSchema, fs);
6493

65-
var AccuracyValues = metricsInMultipleFolds.Select(m => m.Accuracy);
66-
var AccuracyAverage = AccuracyValues.Average();
67-
var AccuraciesStdDeviation = CalculateStandardDeviation(AccuracyValues);
68-
var AccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(AccuracyValues);
94+
Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
95+
}
6996

97+
public static string GetAbsolutePath(string relativePath)
98+
{
99+
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
100+
string assemblyFolderPath = _dataRoot.Directory.FullName;
70101

71-
Console.WriteLine($"*************************************************************************************************************");
72-
Console.WriteLine($"* Metrics for Binary Classification model ");
73-
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
74-
Console.WriteLine($"* Average Accuracy: {AccuracyAverage:0.###} - Standard deviation: ({AccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({AccuraciesConfidenceInterval95:#.###})");
75-
Console.WriteLine($"*************************************************************************************************************");
102+
string fullPath = Path.Combine(assemblyFolderPath, relativePath);
76103

104+
return fullPath;
77105
}
78106

79107
public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)

0 commit comments

Comments
 (0)