Skip to content

Commit b3f980b

Browse files
authored
Fix bug for regression and sanitize input label from user (dotnet#198)
* removed dummy command * sanitize label and fix template * fix tests
1 parent 9cc2910 commit b3f980b

File tree

6 files changed

+29
-14
lines changed

6 files changed

+29
-14
lines changed

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedProjectCodeTest.approved.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<EnableDefaultCompileItems>False</EnableDefaultCompileItems>
77
</PropertyGroup>
88
<ItemGroup>
9-
<Compile Include="Train.cs" />
9+
<Compile Include="Program.cs" />
1010
<Compile Include="ConsoleHelper.cs" />
1111
</ItemGroup>
1212
<ItemGroup>

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ public void GeneratedHelperCodeTest()
124124

125125
this.columnInference = new ColumnInferenceResults()
126126
{
127-
TextLoaderArgs = textLoaderArgs
127+
TextLoaderArgs = textLoaderArgs,
128+
ColumnInformation = new ColumnInformation() { LabelColumn = "Label" }
128129
};
129130
}
130131
return (pipeline, columnInference);

src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs

+16-7
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
using Microsoft.CodeAnalysis.Formatting;
1313
using Microsoft.ML.Auto;
1414
using Microsoft.ML.CLI.Templates.Console;
15+
using Microsoft.ML.CLI.Utilities;
1516
using static Microsoft.ML.Data.TextLoader;
1617

1718
namespace Microsoft.ML.CLI.CodeGenerator.CSharp
@@ -69,14 +70,13 @@ public void GenerateOutput()
6970

7071
internal void WriteOutputToFiles(string trainScoreCode, string projectSourceCode, string consoleHelperCode)
7172
{
72-
var outputFolder = Path.Combine(options.OutputBaseDir, options.OutputName);
73-
if (!Directory.Exists(outputFolder))
73+
if (!Directory.Exists(options.OutputBaseDir))
7474
{
75-
Directory.CreateDirectory(outputFolder);
75+
Directory.CreateDirectory(options.OutputBaseDir);
7676
}
77-
File.WriteAllText($"{outputFolder}/Train.cs", trainScoreCode);
78-
File.WriteAllText($"{outputFolder}/{options.OutputName}.csproj", projectSourceCode);
79-
File.WriteAllText($"{outputFolder}/ConsoleHelper.cs", consoleHelperCode);
77+
File.WriteAllText($"{options.OutputBaseDir}/Program.cs", trainScoreCode);
78+
File.WriteAllText($"{options.OutputBaseDir}/{options.OutputName}.csproj", projectSourceCode);
79+
File.WriteAllText($"{options.OutputBaseDir}/ConsoleHelper.cs", consoleHelperCode);
8080
}
8181

8282
internal static string GenerateConsoleHelper(string namespaceValue)
@@ -165,6 +165,7 @@ internal string GenerateTrainCode(string usings, string trainer, List<string> tr
165165
internal IList<string> GenerateClassLabels()
166166
{
167167
IList<string> result = new List<string>();
168+
var label_column = Utils.Sanitize(columnInferenceResult.ColumnInformation.LabelColumn);
168169
foreach (var column in columnInferenceResult.TextLoaderArgs.Column)
169170
{
170171
StringBuilder sb = new StringBuilder();
@@ -213,7 +214,15 @@ internal IList<string> GenerateClassLabels()
213214
result.Add($"[ColumnName(\"{column.Name}\"), LoadColumn({column.Source[0].Min})]");
214215
}
215216
sb.Append(" ");
216-
sb.Append(Normalize(column.Name));
217+
if (column.Name.Equals(label_column))
218+
{
219+
sb.Append("Label");
220+
}
221+
else
222+
{
223+
sb.Append(Normalize(column.Name));
224+
}
225+
217226
sb.Append("{get; set;}");
218227
result.Add(sb.ToString());
219228
result.Add("\r\n");

src/mlnet/Commands/New/NewCommandHandler.cs

+8-3
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ public void Execute()
7474

7575
// Save the model
7676
logger.Log(LogLevel.Info, Strings.SavingBestModel);
77-
var modelPath = Path.Combine(@options.OutputPath.FullName, options.Name);
78-
Utils.SaveModel(model, modelPath, $"{options.Name}_model.zip", context);
77+
Utils.SaveModel(model, options.OutputPath.FullName, $"{options.Name}_model.zip", context);
7978

8079
// Generate the Project
8180
GenerateProject(columnInference, pipeline);
@@ -120,7 +119,13 @@ internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline p
120119
internal (Pipeline, ITransformer) ExploreModels(MLContext context, IDataView trainData, IDataView validationData)
121120
{
122121
ITransformer model = null;
123-
string label = options.LabelColumnName ?? "Label"; // It is guaranteed training dataview to have Label column
122+
string label = "Label";
123+
124+
if (options.LabelColumnName != null)
125+
{
126+
label = Utils.Sanitize(options.LabelColumnName);
127+
}
128+
124129
Pipeline pipeline = null;
125130

126131
if (taskKind == TaskKind.BinaryClassification)

src/mlnet/Templates/Console/MLProjectGen.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public virtual string TransformText()
3333
<EnableDefaultCompileItems>False</EnableDefaultCompileItems>
3434
</PropertyGroup>
3535
<ItemGroup>
36-
<Compile Include=""Train.cs"" />
36+
<Compile Include=""Program.cs"" />
3737
<Compile Include=""ConsoleHelper.cs"" />
3838
</ItemGroup>
3939
<ItemGroup>

src/mlnet/Templates/Console/MLProjectGen.tt

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
<EnableDefaultCompileItems>False</EnableDefaultCompileItems>
1212
</PropertyGroup>
1313
<ItemGroup>
14-
<Compile Include="Train.cs" />
14+
<Compile Include="Program.cs" />
1515
<Compile Include="ConsoleHelper.cs" />
1616
</ItemGroup>
1717
<ItemGroup>

0 commit comments

Comments
 (0)