Skip to content

Commit fe6f88b

Browse files
srsaggamDmitry-A
authored andcommitted
Codegen for multiclass non-ova (dotnet#303)
* changes to template * multicalss codegen * test cases * fix test cases
1 parent a8a9240 commit fe6f88b

11 files changed

+526
-58
lines changed

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedTrainCodeTest.approved.txt renamed to src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedTrainCodeBinaryClassificationTest.approved.txt

+6-6
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ namespace MyNamespace
122122
var resultprediction = predEngine.Predict(sample);
123123

124124
Console.WriteLine($"=============== Single Prediction ===============");
125-
Console.WriteLine($"Actual value: {sample.Label} | Predicted value: {resultprediction.Prediction}");
125+
Console.WriteLine($"Actual value: {sample.Label} | Predicted value: {resultprediction.Prediction} ");
126126
Console.WriteLine($"==================================================");
127127
}
128128

@@ -135,23 +135,23 @@ namespace MyNamespace
135135

136136

137137
[ColumnName("col1"), LoadColumn(1)]
138-
public float col1 { get; set; }
138+
public float Col1 { get; set; }
139139

140140

141141
[ColumnName("col2"), LoadColumn(0)]
142-
public float col2 { get; set; }
142+
public float Col2 { get; set; }
143143

144144

145145
[ColumnName("col3"), LoadColumn(0)]
146-
public string col3 { get; set; }
146+
public string Col3 { get; set; }
147147

148148

149149
[ColumnName("col4"), LoadColumn(0)]
150-
public int col4 { get; set; }
150+
public int Col4 { get; set; }
151151

152152

153153
[ColumnName("col5"), LoadColumn(0)]
154-
public uint col5 { get; set; }
154+
public uint Col5 { get; set; }
155155

156156

157157
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
//*****************************************************************************************
2+
//* *
3+
//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. *
4+
//* *
5+
//*****************************************************************************************
6+
7+
using System;
8+
using System.IO;
9+
using System.Linq;
10+
using Microsoft.ML;
11+
using Microsoft.ML.Data;
12+
using Microsoft.Data.DataView;
13+
using Microsoft.ML.LightGBM;
14+
15+
16+
namespace MyNamespace
17+
{
18+
class Program
19+
{
20+
private static string TrainDataPath = @"x:\dummypath\dummy_train.csv";
21+
private static string TestDataPath = @"x:\dummypath\dummy_test.csv";
22+
private static string ModelPath = @"x:\models\model.zip";
23+
24+
static void Main(string[] args)
25+
{
26+
// Create MLContext to be shared across the model creation workflow objects
27+
var mlContext = new MLContext();
28+
29+
var command = Command.Predict; // Your desired action here
30+
31+
if (command == Command.Predict)
32+
{
33+
Predict(mlContext);
34+
ConsoleHelper.ConsoleWriteHeader("=============== If you also want to train a model use Command.TrainAndPredict ===============");
35+
}
36+
37+
if (command == Command.TrainAndPredict)
38+
{
39+
TrainEvaluateAndSaveModel(mlContext);
40+
Predict(mlContext);
41+
}
42+
43+
Console.WriteLine("=============== End of process, hit any key to finish ===============");
44+
Console.ReadKey();
45+
}
46+
47+
private enum Command
48+
{
49+
Predict,
50+
TrainAndPredict
51+
}
52+
53+
private static ITransformer TrainEvaluateAndSaveModel(MLContext mlContext)
54+
{
55+
// Load data
56+
Console.WriteLine("=============== Loading data ===============");
57+
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
58+
path: TrainDataPath,
59+
hasHeader: true,
60+
separatorChar: ',',
61+
allowQuoting: true,
62+
allowSparse: true);
63+
IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
64+
path: TestDataPath,
65+
hasHeader: true,
66+
separatorChar: ',',
67+
allowQuoting: true,
68+
allowSparse: true);
69+
70+
// Common data process configuration with pipeline data transformations
71+
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" })
72+
.AppendCacheCheckpoint(mlContext);
73+
74+
// Set the training algorithm, then create and config the modelBuilder
75+
var trainer = mlContext.MulticlassClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumn = "Label", FeatureColumn = "Features" });
76+
var trainingPipeline = dataProcessPipeline.Append(trainer);
77+
78+
// Train the model fitting to the DataSet
79+
Console.WriteLine("=============== Training the model ===============");
80+
var trainedModel = trainingPipeline.Fit(trainingDataView);
81+
82+
// Evaluate the model and show accuracy stats
83+
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
84+
var predictions = trainedModel.Transform(testDataView);
85+
86+
// Save/persist the trained model to a .ZIP file
87+
Console.WriteLine($"=============== Saving the model ===============");
88+
using (var fs = new FileStream(ModelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
89+
mlContext.Model.Save(trainedModel, fs);
90+
91+
Console.WriteLine("The model is saved to {0}", ModelPath);
92+
Console.WriteLine("=============== End of training process ===============");
93+
94+
return trainedModel;
95+
}
96+
97+
// Try/test a single prediction by loading the model from the file, first.
98+
private static void Predict(MLContext mlContext)
99+
{
100+
//Load data to test. Could be any test data. For demonstration purpose train data is used here.
101+
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
102+
path: TestDataPath,
103+
hasHeader: true,
104+
separatorChar: ',',
105+
allowQuoting: true,
106+
allowSparse: true);
107+
108+
var sample = mlContext.Data.CreateEnumerable<SampleObservation>(trainingDataView, false).First();
109+
110+
ITransformer trainedModel;
111+
using (var stream = new FileStream(ModelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
112+
{
113+
trainedModel = mlContext.Model.Load(stream);
114+
}
115+
116+
// Create prediction engine related to the loaded trained model
117+
var predEngine = trainedModel.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlContext);
118+
119+
//Score
120+
var resultprediction = predEngine.Predict(sample);
121+
122+
Console.WriteLine($"=============== Single Prediction ===============");
123+
Console.WriteLine($"Actual value: {sample.Label} | Predicted value: {resultprediction.Prediction} | Predicted scores: [{String.Join(", ", resultprediction.Score)}]");
124+
Console.WriteLine($"==================================================");
125+
}
126+
127+
}
128+
129+
public class SampleObservation
130+
{
131+
[ColumnName("Label"), LoadColumn(0)]
132+
public bool Label { get; set; }
133+
134+
135+
[ColumnName("col1"), LoadColumn(1)]
136+
public float Col1 { get; set; }
137+
138+
139+
[ColumnName("col2"), LoadColumn(0)]
140+
public float Col2 { get; set; }
141+
142+
143+
[ColumnName("col3"), LoadColumn(0)]
144+
public string Col3 { get; set; }
145+
146+
147+
[ColumnName("col4"), LoadColumn(0)]
148+
public int Col4 { get; set; }
149+
150+
151+
[ColumnName("col5"), LoadColumn(0)]
152+
public uint Col5 { get; set; }
153+
154+
155+
}
156+
157+
public class SamplePrediction
158+
{
159+
// ColumnName attribute is used to change the column name from
160+
// its default value, which is the name of the field.
161+
[ColumnName("PredictedLabel")]
162+
public Boolean Prediction { get; set; }
163+
public float[] Score { get; set; }
164+
}
165+
166+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
//*****************************************************************************************
2+
//* *
3+
//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. *
4+
//* *
5+
//*****************************************************************************************
6+
7+
using System;
8+
using System.IO;
9+
using System.Linq;
10+
using Microsoft.ML;
11+
using Microsoft.ML.Data;
12+
using Microsoft.Data.DataView;
13+
using Microsoft.ML.LightGBM;
14+
15+
16+
namespace MyNamespace
17+
{
18+
class Program
19+
{
20+
private static string TrainDataPath = @"x:\dummypath\dummy_train.csv";
21+
private static string TestDataPath = @"x:\dummypath\dummy_test.csv";
22+
private static string ModelPath = @"x:\models\model.zip";
23+
24+
static void Main(string[] args)
25+
{
26+
// Create MLContext to be shared across the model creation workflow objects
27+
var mlContext = new MLContext();
28+
29+
var command = Command.Predict; // Your desired action here
30+
31+
if (command == Command.Predict)
32+
{
33+
Predict(mlContext);
34+
ConsoleHelper.ConsoleWriteHeader("=============== If you also want to train a model use Command.TrainAndPredict ===============");
35+
}
36+
37+
if (command == Command.TrainAndPredict)
38+
{
39+
TrainEvaluateAndSaveModel(mlContext);
40+
Predict(mlContext);
41+
}
42+
43+
Console.WriteLine("=============== End of process, hit any key to finish ===============");
44+
Console.ReadKey();
45+
}
46+
47+
private enum Command
48+
{
49+
Predict,
50+
TrainAndPredict
51+
}
52+
53+
private static ITransformer TrainEvaluateAndSaveModel(MLContext mlContext)
54+
{
55+
// Load data
56+
Console.WriteLine("=============== Loading data ===============");
57+
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
58+
path: TrainDataPath,
59+
hasHeader: true,
60+
separatorChar: ',',
61+
allowQuoting: true,
62+
allowSparse: true);
63+
IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
64+
path: TestDataPath,
65+
hasHeader: true,
66+
separatorChar: ',',
67+
allowQuoting: true,
68+
allowSparse: true);
69+
70+
// Common data process configuration with pipeline data transformations
71+
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" })
72+
.AppendCacheCheckpoint(mlContext);
73+
74+
// Set the training algorithm, then create and config the modelBuilder
75+
var trainer = mlContext.Regression.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumn = "Label", FeatureColumn = "Features" });
76+
var trainingPipeline = dataProcessPipeline.Append(trainer);
77+
78+
// Train the model fitting to the DataSet
79+
Console.WriteLine("=============== Training the model ===============");
80+
var trainedModel = trainingPipeline.Fit(trainingDataView);
81+
82+
// Evaluate the model and show accuracy stats
83+
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
84+
var predictions = trainedModel.Transform(testDataView);
85+
var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
86+
ConsoleHelper.PrintRegressionMetrics(trainer.ToString(), metrics);
87+
88+
// Save/persist the trained model to a .ZIP file
89+
Console.WriteLine($"=============== Saving the model ===============");
90+
using (var fs = new FileStream(ModelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
91+
mlContext.Model.Save(trainedModel, fs);
92+
93+
Console.WriteLine("The model is saved to {0}", ModelPath);
94+
Console.WriteLine("=============== End of training process ===============");
95+
96+
return trainedModel;
97+
}
98+
99+
// Try/test a single prediction by loading the model from the file, first.
100+
private static void Predict(MLContext mlContext)
101+
{
102+
//Load data to test. Could be any test data. For demonstration purpose train data is used here.
103+
IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
104+
path: TestDataPath,
105+
hasHeader: true,
106+
separatorChar: ',',
107+
allowQuoting: true,
108+
allowSparse: true);
109+
110+
var sample = mlContext.Data.CreateEnumerable<SampleObservation>(trainingDataView, false).First();
111+
112+
ITransformer trainedModel;
113+
using (var stream = new FileStream(ModelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
114+
{
115+
trainedModel = mlContext.Model.Load(stream);
116+
}
117+
118+
// Create prediction engine related to the loaded trained model
119+
var predEngine = trainedModel.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlContext);
120+
121+
//Score
122+
var resultprediction = predEngine.Predict(sample);
123+
124+
Console.WriteLine($"=============== Single Prediction ===============");
125+
Console.WriteLine($"Actual value: {sample.Label} | Predicted value: {resultprediction.Score} ");
126+
Console.WriteLine($"==================================================");
127+
}
128+
129+
}
130+
131+
public class SampleObservation
132+
{
133+
[ColumnName("Label"), LoadColumn(0)]
134+
public bool Label { get; set; }
135+
136+
137+
[ColumnName("col1"), LoadColumn(1)]
138+
public float Col1 { get; set; }
139+
140+
141+
[ColumnName("col2"), LoadColumn(0)]
142+
public float Col2 { get; set; }
143+
144+
145+
[ColumnName("col3"), LoadColumn(0)]
146+
public string Col3 { get; set; }
147+
148+
149+
[ColumnName("col4"), LoadColumn(0)]
150+
public int Col4 { get; set; }
151+
152+
153+
[ColumnName("col5"), LoadColumn(0)]
154+
public uint Col5 { get; set; }
155+
156+
157+
}
158+
159+
public class SamplePrediction
160+
{
161+
public float Score { get; set; }
162+
}
163+
164+
}

0 commit comments

Comments
 (0)