
Commit 42b2d04

srsaggam authored and Dmitry-A committed
accept label from user input and provide in generated code (dotnet#205)
1 parent 6fa307a commit 42b2d04

8 files changed: +66 −61 lines
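
In short, the generated training code stops hard-coding a column named "Label" and instead uses the label column supplied by the user (sanitized and normalized along the way). A minimal before/after sketch of one generated line, assuming a hypothetical regression label column named "fare_amount":

    // Before: the generated program always evaluated against a column literally named "Label".
    var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");

    // After: the user-supplied label name is stamped into the template instead.
    var metrics = mlContext.Regression.Evaluate(predictions, "fare_amount", "Score");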

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs

+6 −6
@@ -34,8 +34,8 @@ public void GeneratedTrainCodeTest()
                 OutputBaseDir = null,
                 OutputName = "MyNamespace",
                 TrainDataset = new FileInfo("x:\\dummypath\\dummy_train.csv"),
-                TestDataset = new FileInfo("x:\\dummypath\\dummy_test.csv")
-
+                TestDataset = new FileInfo("x:\\dummypath\\dummy_test.csv"),
+                LabelName = "Label"
             });

             (string trainCode, string projectCode, string helperCode) = consoleCodeGen.GenerateCode();

@@ -57,8 +57,8 @@ public void GeneratedProjectCodeTest()
                 OutputBaseDir = null,
                 OutputName = "MyNamespace",
                 TrainDataset = new FileInfo("x:\\dummypath\\dummy_train.csv"),
-                TestDataset = new FileInfo("x:\\dummypath\\dummy_test.csv")
-
+                TestDataset = new FileInfo("x:\\dummypath\\dummy_test.csv"),
+                LabelName = "Label"
             });

             (string trainCode, string projectCode, string helperCode) = consoleCodeGen.GenerateCode();

@@ -80,8 +80,8 @@ public void GeneratedHelperCodeTest()
                 OutputBaseDir = null,
                 OutputName = "MyNamespace",
                 TrainDataset = new FileInfo("x:\\dummypath\\dummy_train.csv"),
-                TestDataset = new FileInfo("x:\\dummypath\\dummy_test.csv")
-
+                TestDataset = new FileInfo("x:\\dummypath\\dummy_test.csv"),
+                LabelName = "Label"
             });

             (string trainCode, string projectCode, string helperCode) = consoleCodeGen.GenerateCode();

src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs

+4 −28
@@ -51,7 +51,7 @@ public void GenerateOutput()
             var classLabels = this.GenerateClassLabels();

             // Get Namespace
-            var namespaceValue = Normalize(options.OutputName);
+            var namespaceValue = Utils.Normalize(options.OutputName);

             // Generate code for training and scoring
             var trainFileContent = GenerateTrainCode(usings, trainer, transforms, columns, classLabels, namespaceValue);

@@ -108,7 +108,8 @@ internal string GenerateTrainCode(string usings, string trainer, List<string> tr
                 Path = options.TrainDataset.FullName,
                 TestPath = options.TestDataset?.FullName,
                 TaskType = options.MlTask.ToString(),
-                Namespace = namespaceValue
+                Namespace = namespaceValue,
+                LabelName = options.LabelName
             };

             return trainingAndScoringCodeGen.TransformText();

@@ -214,15 +215,7 @@ internal IList<string> GenerateClassLabels()
                     result.Add($"[ColumnName(\"{column.Name}\"), LoadColumn({column.Source[0].Min})]");
                 }
                 sb.Append(" ");
-                if (column.Name.Equals(label_column))
-                {
-                    sb.Append("Label");
-                }
-                else
-                {
-                    sb.Append(Normalize(column.Name));
-                }
-
+                sb.Append(Utils.Normalize(column.Name));
                 sb.Append("{get; set;}");
                 result.Add(sb.ToString());
                 result.Add("\r\n");

@@ -277,22 +270,5 @@ private static string ConstructColumnDefinition(Column column)
             var def = $"new Column(\"{column.Name}\",DataKind.{column.Type},{rangeBuilder.ToString()}),";
             return def;
         }
-
-        private static string Normalize(string inputColumn)
-        {
-            //check if first character is int
-            if (!string.IsNullOrEmpty(inputColumn) && int.TryParse(inputColumn.Substring(0, 1), out int val))
-            {
-                inputColumn = "Col" + inputColumn;
-                return inputColumn;
-            }
-            switch (inputColumn)
-            {
-                case null: throw new ArgumentNullException(nameof(inputColumn));
-                case "": throw new ArgumentException($"{nameof(inputColumn)} cannot be empty", nameof(inputColumn));
-                default: return inputColumn.First().ToString().ToUpper() + inputColumn.Substring(1);
-            }
-        }
-
     }
 }
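
With the label_column special case gone, GenerateClassLabels now emits every property, including the label, under its normalized original name. An illustrative fragment of the generated observation class, assuming a hypothetical dataset whose label column is "fare_amount" at index 5 and float-typed (index and type are assumptions, not taken from this diff):

    // Hypothetical output of GenerateClassLabels for a label column "fare_amount" (index and type assumed).
    [ColumnName("fare_amount"), LoadColumn(5)]
    public float Fare_amount {get; set;}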

src/mlnet/CodeGenerator/CSharp/CodeGeneratorOptions.cs

+1 −0
@@ -5,6 +5,7 @@ namespace Microsoft.ML.CLI.CodeGenerator.CSharp
 {
     internal class CodeGeneratorOptions
     {
+        public string LabelName { get; internal set; }
         internal string OutputName { get; set; }

         internal string OutputBaseDir { get; set; }

src/mlnet/Commands/New/NewCommandHandler.cs

+10 −13
@@ -48,6 +48,8 @@ public void Execute()
             // Sanitize columns
             Array.ForEach(columnInference.TextLoaderArgs.Column, t => t.Name = Utils.Sanitize(t.Name));

+            var sanitized_Label_Name = Utils.Sanitize(columnInference.ColumnInformation.LabelColumn);
+
             // Load data
             (IDataView trainData, IDataView validationData) = LoadData(context, columnInference.TextLoaderArgs);

@@ -56,7 +58,7 @@ public void Execute()
             Console.WriteLine($"{Strings.ExplorePipeline}: {options.MlTask}");
             try
             {
-                result = ExploreModels(context, trainData, validationData);
+                result = ExploreModels(context, trainData, validationData, sanitized_Label_Name);
             }
             catch (Exception e)
             {

@@ -77,7 +79,7 @@ public void Execute()
             Utils.SaveModel(model, options.OutputPath.FullName, $"{options.Name}_model.zip", context);

             // Generate the Project
-            GenerateProject(columnInference, pipeline);
+            GenerateProject(columnInference, pipeline, sanitized_Label_Name);
         }

         internal ColumnInferenceResults InferColumns(MLContext context)

@@ -98,7 +100,7 @@ internal ColumnInferenceResults InferColumns(MLContext context)
             return columnInference;
         }

-        internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline pipeline)
+        internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline pipeline, string labelName)
         {
             //Generate code
             logger.Log(LogLevel.Info, $"{Strings.GenerateProject} : {options.OutputPath.FullName}");

@@ -111,20 +113,15 @@ internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline p
                 MlTask = taskKind,
                 TestDataset = options.TestDataset,
                 OutputName = options.Name,
-                OutputBaseDir = options.OutputPath.FullName
+                OutputBaseDir = options.OutputPath.FullName,
+                LabelName = labelName
             });
             codeGenerator.GenerateOutput();
         }

-        internal (Pipeline, ITransformer) ExploreModels(MLContext context, IDataView trainData, IDataView validationData)
+        internal (Pipeline, ITransformer) ExploreModels(MLContext context, IDataView trainData, IDataView validationData, string labelName)
         {
             ITransformer model = null;
-            string label = "Label";
-
-            if (options.LabelColumnName != null)
-            {
-                label = Utils.Sanitize(options.LabelColumnName);
-            }

             Pipeline pipeline = null;

@@ -137,7 +134,7 @@ internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline p
                         MaxInferenceTimeInSeconds = options.MaxExplorationTime,
                         ProgressCallback = progressReporter
                     })
-                    .Execute(trainData, validationData, new ColumnInformation() { LabelColumn = label });
+                    .Execute(trainData, validationData, new ColumnInformation() { LabelColumn = labelName });
                 logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
                 var bestIteration = result.Best();
                 pipeline = bestIteration.Pipeline;

@@ -152,7 +149,7 @@ internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline p
                 {
                     MaxInferenceTimeInSeconds = options.MaxExplorationTime,
                     ProgressCallback = progressReporter
-                }).Execute(trainData, validationData, new ColumnInformation() { LabelColumn = label });
+                }).Execute(trainData, validationData, new ColumnInformation() { LabelColumn = labelName });
                 logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
                 var bestIteration = result.Best();
                 pipeline = bestIteration.Pipeline;
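
Taken together, the handler now sanitizes the inferred label column once and threads that single value through both AutoML exploration and project generation, replacing the old per-method fallback to "Label". A condensed sketch of the flow inside Execute (unrelated steps omitted):

    // Condensed from the Execute method above; surrounding code elided.
    var sanitized_Label_Name = Utils.Sanitize(columnInference.ColumnInformation.LabelColumn);

    // The sanitized name drives model exploration...
    result = ExploreModels(context, trainData, validationData, sanitized_Label_Name);

    // ...and is also handed to the code generator so the emitted project uses the same label.
    GenerateProject(columnInference, pipeline, sanitized_Label_Name);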

src/mlnet/Program.cs

+1 −0
@@ -47,6 +47,7 @@ public static void Main(string[] args)


             parser.InvokeAsync(args).Wait();
+            Console.ReadKey();
         }
     }
 }

src/mlnet/Templates/Console/MLCodeGen.cs

+21 −9
@@ -12,6 +12,7 @@ namespace Microsoft.ML.CLI.Templates.Console
     using System.Linq;
     using System.Text;
     using System.Collections.Generic;
+    using Microsoft.ML.CLI.Utilities;
     using System;

     /// <summary>

@@ -135,16 +136,20 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
             this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
             this.Write(".CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numFolds: ");
             this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds));
-            this.Write(", labelColumn:\"Label\");\r\n ConsoleHelper.PrintBinaryClassificationFolds" +
-                    "AverageMetrics(trainer.ToString(), crossValidationResults);\r\n");
+            this.Write(", labelColumn:\"");
+            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
+            this.Write("\");\r\n ConsoleHelper.PrintBinaryClassificationFoldsAverageMetrics(train" +
+                    "er.ToString(), crossValidationResults);\r\n");
             }
             if("Regression".Equals(TaskType)){
             this.Write(" var crossValidationResults = mlContext.");
             this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
             this.Write(".CrossValidate(trainingDataView, trainingPipeline, numFolds: ");
             this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds));
-            this.Write(", labelColumn:\"Label\");\r\n ConsoleHelper.PrintRegressionFoldsAverageMet" +
-                    "rics(trainer.ToString(), crossValidationResults);\r\n");
+            this.Write(", labelColumn:\"");
+            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
+            this.Write("\");\r\n ConsoleHelper.PrintRegressionFoldsAverageMetrics(trainer.ToStrin" +
+                    "g(), crossValidationResults);\r\n");
             }
             }
             this.Write("\r\n // Train the model fitting to the DataSet\r\n Console.Writ" +

@@ -157,14 +162,18 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
             if("BinaryClassification".Equals(TaskType)){
             this.Write(" var metrics = mlContext.");
             this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
-            this.Write(".EvaluateNonCalibrated(predictions, \"Label\", \"Score\");\r\n ConsoleHelper" +
-                    ".PrintBinaryClassificationMetrics(trainer.ToString(), metrics);\r\n");
+            this.Write(".EvaluateNonCalibrated(predictions, \"");
+            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
+            this.Write("\", \"Score\");\r\n ConsoleHelper.PrintBinaryClassificationMetrics(trainer." +
+                    "ToString(), metrics);\r\n");
             }
             if("Regression".Equals(TaskType)){
             this.Write(" var metrics = mlContext.");
             this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
-            this.Write(".Evaluate(predictions, \"Label\", \"Score\");\r\n ConsoleHelper.PrintRegress" +
-                    "ionMetrics(trainer.ToString(), metrics);\r\n");
+            this.Write(".Evaluate(predictions, \"");
+            this.Write(this.ToStringHelper.ToStringWithCulture(LabelName));
+            this.Write("\", \"Score\");\r\n ConsoleHelper.PrintRegressionMetrics(trainer.ToString()" +
+                    ", metrics);\r\n");
             }
             }
             this.Write(@"

@@ -211,7 +220,9 @@ private static void TestSinglePrediction(MLContext mlContext)
            var resultprediction = predEngine.Predict(sample);

            Console.WriteLine($""=============== Single Prediction ==============="");
-            Console.WriteLine($""Actual value: {sample.Label} | Predicted value: {resultprediction.");
+            Console.WriteLine($""Actual value: {sample.");
+            this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
+            this.Write("} | Predicted value: {resultprediction.");
             if("BinaryClassification".Equals(TaskType)){
             this.Write("Prediction");
             }else{

@@ -258,6 +269,7 @@ private static void TestSinglePrediction(MLContext mlContext)
         public bool TrimWhiteSpace {get;set;}
         public int Kfolds {get;set;} = 5;
         public string Namespace {get;set;}
+        public string LabelName {get;set;}

     }
     #region Base class

src/mlnet/Templates/Console/MLCodeGen.tt

+7 −5
@@ -3,6 +3,7 @@
 <#@ import namespace="System.Linq" #>
 <#@ import namespace="System.Text" #>
 <#@ import namespace="System.Collections.Generic" #>
+<#@ import namespace="Microsoft.ML.CLI.Utilities" #>
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.

@@ -89,10 +90,10 @@ else{#>
             // in order to evaluate and get the model's accuracy metrics
             Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
 <#if("BinaryClassification".Equals(TaskType)){ #>
-            var crossValidationResults = mlContext.<#= TaskType #>.CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numFolds: <#= Kfolds #>, labelColumn:"Label");
+            var crossValidationResults = mlContext.<#= TaskType #>.CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numFolds: <#= Kfolds #>, labelColumn:"<#= LabelName #>");
             ConsoleHelper.PrintBinaryClassificationFoldsAverageMetrics(trainer.ToString(), crossValidationResults);
 <#}#><#if("Regression".Equals(TaskType)){ #>
-            var crossValidationResults = mlContext.<#= TaskType #>.CrossValidate(trainingDataView, trainingPipeline, numFolds: <#= Kfolds #>, labelColumn:"Label");
+            var crossValidationResults = mlContext.<#= TaskType #>.CrossValidate(trainingDataView, trainingPipeline, numFolds: <#= Kfolds #>, labelColumn:"<#= LabelName #>");
             ConsoleHelper.PrintRegressionFoldsAverageMetrics(trainer.ToString(), crossValidationResults);
 <#}
 } #>

@@ -106,10 +107,10 @@ else{#>
             Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
             var predictions = trainedModel.Transform(testDataView);
 <#if("BinaryClassification".Equals(TaskType)){ #>
-            var metrics = mlContext.<#= TaskType #>.EvaluateNonCalibrated(predictions, "Label", "Score");
+            var metrics = mlContext.<#= TaskType #>.EvaluateNonCalibrated(predictions, "<#= LabelName #>", "Score");
             ConsoleHelper.PrintBinaryClassificationMetrics(trainer.ToString(), metrics);
 <#}#><#if("Regression".Equals(TaskType)){ #>
-            var metrics = mlContext.<#= TaskType #>.Evaluate(predictions, "Label", "Score");
+            var metrics = mlContext.<#= TaskType #>.Evaluate(predictions, "<#= LabelName #>", "Score");
             ConsoleHelper.PrintRegressionMetrics(trainer.ToString(), metrics);
 <#}#>
 <# } #>

@@ -151,7 +152,7 @@ else{#>
             var resultprediction = predEngine.Predict(sample);

             Console.WriteLine($"=============== Single Prediction ===============");
-            Console.WriteLine($"Actual value: {sample.Label} | Predicted value: {resultprediction.<#if("BinaryClassification".Equals(TaskType)){ #>Prediction<#}else{#>Score<#}#>}");
+            Console.WriteLine($"Actual value: {sample.<#= Utils.Normalize(LabelName) #>} | Predicted value: {resultprediction.<#if("BinaryClassification".Equals(TaskType)){ #>Prediction<#}else{#>Score<#}#>}");
             Console.WriteLine($"==================================================");
         }

@@ -201,4 +202,5 @@ public bool AllowSparse {get;set;}
 public bool TrimWhiteSpace {get;set;}
 public int Kfolds {get;set;} = 5;
 public string Namespace {get;set;}
+public string LabelName {get;set;}
 #>
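
Because LabelName now feeds both the evaluation calls and the single-prediction line, the rendered program prints the real label property instead of a fixed Label field. An illustrative rendering of that line for a regression task, assuming a hypothetical label column named "fare_amount" (Utils.Normalize turns it into the property name Fare_amount):

    // Illustrative template output for LabelName = "fare_amount" on a Regression task.
    Console.WriteLine($"Actual value: {sample.Fare_amount} | Predicted value: {resultprediction.Score}");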

src/mlnet/Utilities/Utils.cs

+16 −0
@@ -59,5 +59,21 @@ internal static TaskKind GetTaskKind(string mlTask)
             }
         }

+        internal static string Normalize(string inputColumn)
+        {
+            //check if first character is int
+            if (!string.IsNullOrEmpty(inputColumn) && int.TryParse(inputColumn.Substring(0, 1), out int val))
+            {
+                inputColumn = "Col" + inputColumn;
+                return inputColumn;
+            }
+            switch (inputColumn)
+            {
+                case null: throw new ArgumentNullException(nameof(inputColumn));
+                case "": throw new ArgumentException($"{nameof(inputColumn)} cannot be empty", nameof(inputColumn));
+                default: return inputColumn.First().ToString().ToUpper() + inputColumn.Substring(1);
+            }
+        }
+
     }
 }
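
For reference, Normalize (moved here from CodeGenerator.cs) prefixes "Col" when a column name starts with a digit, upper-cases the first character otherwise, and rejects null or empty input. A few hypothetical calls illustrating that behaviour:

    // Hypothetical examples, inferred from the method body above.
    Utils.Normalize("fare_amount");  // "Fare_amount"  (first character upper-cased)
    Utils.Normalize("4thColumn");    // "Col4thColumn" (digit-leading names get a "Col" prefix)
    Utils.Normalize("");             // throws ArgumentException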

0 commit comments
