@@ -12,6 +12,7 @@ namespace Microsoft.ML.CLI.Templates.Console
12
12
using System . Linq ;
13
13
using System . Text ;
14
14
using System . Collections . Generic ;
15
+ using Microsoft . ML . CLI . Utilities ;
15
16
using System ;
16
17
17
18
/// <summary>
@@ -135,16 +136,20 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
135
136
this . Write ( this . ToStringHelper . ToStringWithCulture ( TaskType ) ) ;
136
137
this . Write ( ".CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numFolds: " ) ;
137
138
this . Write ( this . ToStringHelper . ToStringWithCulture ( Kfolds ) ) ;
138
- this . Write ( ", labelColumn:\" Label\" );\r \n ConsoleHelper.PrintBinaryClassificationFolds" +
139
- "AverageMetrics(trainer.ToString(), crossValidationResults);\r \n " ) ;
139
+ this . Write ( ", labelColumn:\" " ) ;
140
+ this . Write ( this . ToStringHelper . ToStringWithCulture ( LabelName ) ) ;
141
+ this . Write ( "\" );\r \n ConsoleHelper.PrintBinaryClassificationFoldsAverageMetrics(train" +
142
+ "er.ToString(), crossValidationResults);\r \n " ) ;
140
143
}
141
144
if ( "Regression" . Equals ( TaskType ) ) {
142
145
this . Write ( " var crossValidationResults = mlContext." ) ;
143
146
this . Write ( this . ToStringHelper . ToStringWithCulture ( TaskType ) ) ;
144
147
this . Write ( ".CrossValidate(trainingDataView, trainingPipeline, numFolds: " ) ;
145
148
this . Write ( this . ToStringHelper . ToStringWithCulture ( Kfolds ) ) ;
146
- this . Write ( ", labelColumn:\" Label\" );\r \n ConsoleHelper.PrintRegressionFoldsAverageMet" +
147
- "rics(trainer.ToString(), crossValidationResults);\r \n " ) ;
149
+ this . Write ( ", labelColumn:\" " ) ;
150
+ this . Write ( this . ToStringHelper . ToStringWithCulture ( LabelName ) ) ;
151
+ this . Write ( "\" );\r \n ConsoleHelper.PrintRegressionFoldsAverageMetrics(trainer.ToStrin" +
152
+ "g(), crossValidationResults);\r \n " ) ;
148
153
}
149
154
}
150
155
this . Write ( "\r \n // Train the model fitting to the DataSet\r \n Console.Writ" +
@@ -157,14 +162,18 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
157
162
if ( "BinaryClassification" . Equals ( TaskType ) ) {
158
163
this . Write ( " var metrics = mlContext." ) ;
159
164
this . Write ( this . ToStringHelper . ToStringWithCulture ( TaskType ) ) ;
160
- this . Write ( ".EvaluateNonCalibrated(predictions, \" Label\" , \" Score\" );\r \n ConsoleHelper" +
161
- ".PrintBinaryClassificationMetrics(trainer.ToString(), metrics);\r \n " ) ;
165
+ this . Write ( ".EvaluateNonCalibrated(predictions, \" " ) ;
166
+ this . Write ( this . ToStringHelper . ToStringWithCulture ( LabelName ) ) ;
167
+ this . Write ( "\" , \" Score\" );\r \n ConsoleHelper.PrintBinaryClassificationMetrics(trainer." +
168
+ "ToString(), metrics);\r \n " ) ;
162
169
}
163
170
if ( "Regression" . Equals ( TaskType ) ) {
164
171
this . Write ( " var metrics = mlContext." ) ;
165
172
this . Write ( this . ToStringHelper . ToStringWithCulture ( TaskType ) ) ;
166
- this . Write ( ".Evaluate(predictions, \" Label\" , \" Score\" );\r \n ConsoleHelper.PrintRegress" +
167
- "ionMetrics(trainer.ToString(), metrics);\r \n " ) ;
173
+ this . Write ( ".Evaluate(predictions, \" " ) ;
174
+ this . Write ( this . ToStringHelper . ToStringWithCulture ( LabelName ) ) ;
175
+ this . Write ( "\" , \" Score\" );\r \n ConsoleHelper.PrintRegressionMetrics(trainer.ToString()" +
176
+ ", metrics);\r \n " ) ;
168
177
}
169
178
}
170
179
this . Write ( @"
@@ -211,7 +220,9 @@ private static void TestSinglePrediction(MLContext mlContext)
211
220
var resultprediction = predEngine.Predict(sample);
212
221
213
222
Console.WriteLine($""=============== Single Prediction ==============="");
214
- Console.WriteLine($""Actual value: {sample.Label} | Predicted value: {resultprediction." ) ;
223
+ Console.WriteLine($""Actual value: {sample." ) ;
224
+ this . Write ( this . ToStringHelper . ToStringWithCulture ( Utils . Normalize ( LabelName ) ) ) ;
225
+ this . Write ( "} | Predicted value: {resultprediction." ) ;
215
226
if ( "BinaryClassification" . Equals ( TaskType ) ) {
216
227
this . Write ( "Prediction" ) ;
217
228
} else {
@@ -258,6 +269,7 @@ private static void TestSinglePrediction(MLContext mlContext)
258
269
public bool TrimWhiteSpace { get ; set ; }
259
270
public int Kfolds { get ; set ; } = 5 ;
260
271
public string Namespace { get ; set ; }
272
+ public string LabelName { get ; set ; }
261
273
262
274
}
263
275
#region Base class
0 commit comments