|
12 | 12 | using Microsoft.CodeAnalysis.Formatting;
|
13 | 13 | using Microsoft.ML.Auto;
|
14 | 14 | using Microsoft.ML.CLI.Templates.Console;
|
| 15 | +using Microsoft.ML.CLI.Utilities; |
15 | 16 | using static Microsoft.ML.Data.TextLoader;
|
16 | 17 |
|
17 | 18 | namespace Microsoft.ML.CLI.CodeGenerator.CSharp
|
@@ -69,14 +70,13 @@ public void GenerateOutput()
|
69 | 70 |
|
70 | 71 | internal void WriteOutputToFiles(string trainScoreCode, string projectSourceCode, string consoleHelperCode)
|
71 | 72 | {
|
72 |
| - var outputFolder = Path.Combine(options.OutputBaseDir, options.OutputName); |
73 |
| - if (!Directory.Exists(outputFolder)) |
| 73 | + if (!Directory.Exists(options.OutputBaseDir)) |
74 | 74 | {
|
75 |
| - Directory.CreateDirectory(outputFolder); |
| 75 | + Directory.CreateDirectory(options.OutputBaseDir); |
76 | 76 | }
|
77 |
| - File.WriteAllText($"{outputFolder}/Train.cs", trainScoreCode); |
78 |
| - File.WriteAllText($"{outputFolder}/{options.OutputName}.csproj", projectSourceCode); |
79 |
| - File.WriteAllText($"{outputFolder}/ConsoleHelper.cs", consoleHelperCode); |
| 77 | + File.WriteAllText($"{options.OutputBaseDir}/Program.cs", trainScoreCode); |
| 78 | + File.WriteAllText($"{options.OutputBaseDir}/{options.OutputName}.csproj", projectSourceCode); |
| 79 | + File.WriteAllText($"{options.OutputBaseDir}/ConsoleHelper.cs", consoleHelperCode); |
80 | 80 | }
|
81 | 81 |
|
82 | 82 | internal static string GenerateConsoleHelper(string namespaceValue)
|
@@ -165,6 +165,7 @@ internal string GenerateTrainCode(string usings, string trainer, List<string> tr
|
165 | 165 | internal IList<string> GenerateClassLabels()
|
166 | 166 | {
|
167 | 167 | IList<string> result = new List<string>();
|
| 168 | + var label_column = Utils.Sanitize(columnInferenceResult.ColumnInformation.LabelColumn); |
168 | 169 | foreach (var column in columnInferenceResult.TextLoaderArgs.Column)
|
169 | 170 | {
|
170 | 171 | StringBuilder sb = new StringBuilder();
|
@@ -213,7 +214,15 @@ internal IList<string> GenerateClassLabels()
|
213 | 214 | result.Add($"[ColumnName(\"{column.Name}\"), LoadColumn({column.Source[0].Min})]");
|
214 | 215 | }
|
215 | 216 | sb.Append(" ");
|
216 |
| - sb.Append(Normalize(column.Name)); |
| 217 | + if (column.Name.Equals(label_column)) |
| 218 | + { |
| 219 | + sb.Append("Label"); |
| 220 | + } |
| 221 | + else |
| 222 | + { |
| 223 | + sb.Append(Normalize(column.Name)); |
| 224 | + } |
| 225 | + |
217 | 226 | sb.Append("{get; set;}");
|
218 | 227 | result.Add(sb.ToString());
|
219 | 228 | result.Add("\r\n");
|
|
0 commit comments