Skip to content

Commit cf8c1f4

Browse files
authored
Fix multiclass code gen (dotnet#314)
* compile error in codegen * removes scores printing * fix bugs * fix test
1 parent e152288 commit cf8c1f4

File tree

7 files changed

+78
-26
lines changed

7 files changed

+78
-26
lines changed

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.PredictProgramCSFileContentTest.approved.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ namespace TestNamespace.Predict
7070
IDataView dataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
7171
path: dataFilePath,
7272
hasHeader: true,
73-
separatorChar: ',');
73+
separatorChar: ',',
74+
allowQuoting: true,
75+
allowSparse: true);
7476

7577
// Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file.
7678
SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable<SampleObservation>(dataView, false)

src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs

+12-1
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,18 @@ private static string GeneratPredictProjectFileContent(string namespaceValue, bo
299299

300300
private string GeneratePredictProgramCSFileContent(string namespaceValue)
301301
{
302-
PredictProgram predictProgram = new PredictProgram() { TaskType = settings.MlTask.ToString(), LabelName = settings.LabelName, Namespace = namespaceValue, TestDataPath = settings.TestDataset, TrainDataPath = settings.TrainDataset };
302+
PredictProgram predictProgram = new PredictProgram()
303+
{
304+
TaskType = settings.MlTask.ToString(),
305+
LabelName = settings.LabelName,
306+
Namespace = namespaceValue,
307+
TestDataPath = settings.TestDataset,
308+
TrainDataPath = settings.TrainDataset,
309+
HasHeader = columnInferenceResult.TextLoaderOptions.HasHeader,
310+
Separator = columnInferenceResult.TextLoaderOptions.Separators.FirstOrDefault(),
311+
AllowQuoting = columnInferenceResult.TextLoaderOptions.AllowQuoting,
312+
AllowSparse = columnInferenceResult.TextLoaderOptions.AllowSparse,
313+
};
303314
return predictProgram.TransformText();
304315
}
305316
#endregion

src/mlnet/Templates/Console/PredictProgram.cs

+50-18
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace Microsoft.ML.CLI.Templates.Console
1111
{
1212
using System.Linq;
1313
using System.Text;
14+
using System.Text.RegularExpressions;
1415
using System.Collections.Generic;
1516
using Microsoft.ML.CLI.Utilities;
1617
using System;
@@ -44,14 +45,14 @@ public virtual string TransformText()
4445
using Microsoft.Data.DataView;
4546
using ");
4647

47-
#line 20 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
48+
#line 21 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
4849
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
4950

5051
#line default
5152
#line hidden
5253
this.Write(".Model.DataModels;\r\n\r\n\r\nnamespace ");
5354

54-
#line 23 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
55+
#line 24 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
5556
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
5657

5758
#line default
@@ -60,35 +61,35 @@ public virtual string TransformText()
6061
"nd use for predictions\r\n private const string MODEL_FILEPATH = @\"MLModel." +
6162
"zip\";\r\n\r\n //Dataset to use for predictions \r\n");
6263

63-
#line 31 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
64+
#line 32 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
6465
if(string.IsNullOrEmpty(TestDataPath)){
6566

6667
#line default
6768
#line hidden
6869
this.Write(" private const string DATA_FILEPATH = @\"");
6970

70-
#line 32 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
71+
#line 33 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
7172
this.Write(this.ToStringHelper.ToStringWithCulture(TrainDataPath));
7273

7374
#line default
7475
#line hidden
7576
this.Write("\";\r\n");
7677

77-
#line 33 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
78+
#line 34 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
7879
} else{
7980

8081
#line default
8182
#line hidden
8283
this.Write(" private const string DATA_FILEPATH = @\"");
8384

84-
#line 34 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
85+
#line 35 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
8586
this.Write(this.ToStringHelper.ToStringWithCulture(TestDataPath));
8687

8788
#line default
8889
#line hidden
8990
this.Write("\";\r\n");
9091

91-
#line 35 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
92+
#line 36 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
9293
}
9394

9495
#line default
@@ -120,50 +121,50 @@ private static void Predict(MLContext mlContext, ITransformer mlModel, SampleObs
120121
var predictionResult = predEngine.Predict(sampleData);
121122
");
122123

123-
#line 61 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
124+
#line 62 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
124125
if("BinaryClassification".Equals(TaskType)){
125126

126127
#line default
127128
#line hidden
128129
this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData.");
129130

130-
#line 62 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
131+
#line 63 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
131132
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
132133

133134
#line default
134135
#line hidden
135136
this.Write("} | Predicted value: {predictionResult.Prediction}\");\r\n");
136137

137-
#line 63 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
138+
#line 64 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
138139
}else if("Regression".Equals(TaskType)){
139140

140141
#line default
141142
#line hidden
142143
this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData.");
143144

144-
#line 64 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
145+
#line 65 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
145146
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
146147

147148
#line default
148149
#line hidden
149150
this.Write("} | Predicted value: {predictionResult.Score}\");\r\n");
150151

151-
#line 65 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
152+
#line 66 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
152153
} else if("MulticlassClassification".Equals(TaskType)){
153154

154155
#line default
155156
#line hidden
156157
this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData.");
157158

158-
#line 66 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
159+
#line 67 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
159160
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName)));
160161

161162
#line default
162163
#line hidden
163164
this.Write("} | Predicted value: {predictionResult.Prediction} | Predicted scores: [{String.J" +
164-
"oin(\\\", \\\", resultprediction.Scores)}]\");\r\n");
165+
"oin(\",\", predictionResult.Score)}]\");\r\n");
165166

166-
#line 67 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
167+
#line 68 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
167168
}
168169

169170
#line default
@@ -188,8 +189,35 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str
188189
// Read dataset to get a single row for trying a prediction
189190
IDataView dataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
190191
path: dataFilePath,
191-
hasHeader: true,
192-
separatorChar: ',');
192+
hasHeader : ");
193+
194+
#line 89 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
195+
this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant()));
196+
197+
#line default
198+
#line hidden
199+
this.Write(",\r\n separatorChar : \'");
200+
201+
#line 90 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
202+
this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString())));
203+
204+
#line default
205+
#line hidden
206+
this.Write("\',\r\n allowQuoting : ");
207+
208+
#line 91 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
209+
this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant()));
210+
211+
#line default
212+
#line hidden
213+
this.Write(",\r\n allowSparse: ");
214+
215+
#line 92 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
216+
this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant()));
217+
218+
#line default
219+
#line hidden
220+
this.Write(@");
193221
194222
// Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file.
195223
SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable<SampleObservation>(dataView, false)
@@ -202,13 +230,17 @@ private static SampleObservation CreateSingleDataSample(MLContext mlContext, str
202230
return this.GenerationEnvironment.ToString();
203231
}
204232

205-
#line 98 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
233+
#line 101 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt"
206234

207235
public string TaskType {get;set;}
208236
public string Namespace {get;set;}
209237
public string LabelName {get;set;}
210238
public string TestDataPath {get;set;}
211239
public string TrainDataPath {get;set;}
240+
public char Separator {get;set;}
241+
public bool AllowQuoting {get;set;}
242+
public bool AllowSparse {get;set;}
243+
public bool HasHeader {get;set;}
212244

213245

214246
#line default

src/mlnet/Templates/Console/PredictProgram.tt

+10-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
<#@ assembly name="System.Core" #>
33
<#@ import namespace="System.Linq" #>
44
<#@ import namespace="System.Text" #>
5+
<#@ import namespace="System.Text.RegularExpressions" #>
56
<#@ import namespace="System.Collections.Generic" #>
67
<#@ import namespace="Microsoft.ML.CLI.Utilities" #>
78
//*****************************************************************************************
@@ -63,7 +64,7 @@ namespace <#= Namespace #>.Predict
6364
<#}else if("Regression".Equals(TaskType)){#>
6465
Console.WriteLine($"Single Prediction --> Actual value: {sampleData.<#= Utils.Normalize(LabelName) #>} | Predicted value: {predictionResult.Score}");
6566
<#} else if("MulticlassClassification".Equals(TaskType)){#>
66-
Console.WriteLine($"Single Prediction --> Actual value: {sampleData.<#= Utils.Normalize(LabelName) #>} | Predicted value: {predictionResult.Prediction} | Predicted scores: [{String.Join(\", \", resultprediction.Scores)}]");
67+
Console.WriteLine($"Single Prediction --> Actual value: {sampleData.<#= Utils.Normalize(LabelName) #>} | Predicted value: {predictionResult.Prediction} | Predicted scores: [{String.Join(",", predictionResult.Score)}]");
6768
<#}#>
6869
}
6970

@@ -85,8 +86,10 @@ namespace <#= Namespace #>.Predict
8586
// Read dataset to get a single row for trying a prediction
8687
IDataView dataView = mlContext.Data.LoadFromTextFile<SampleObservation>(
8788
path: dataFilePath,
88-
hasHeader: true,
89-
separatorChar: ',');
89+
hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>,
90+
separatorChar : '<#= Regex.Escape(Separator.ToString()) #>',
91+
allowQuoting : <#= AllowQuoting.ToString().ToLowerInvariant() #>,
92+
allowSparse: <#= AllowSparse.ToString().ToLowerInvariant() #>);
9093

9194
// Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file.
9295
SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable<SampleObservation>(dataView, false)
@@ -101,4 +104,8 @@ public string Namespace {get;set;}
101104
public string LabelName {get;set;}
102105
public string TestDataPath {get;set;}
103106
public string TrainDataPath {get;set;}
107+
public char Separator {get;set;}
108+
public bool AllowQuoting {get;set;}
109+
public bool AllowSparse {get;set;}
110+
public bool HasHeader {get;set;}
104111
#>

src/mlnet/Templates/Console/PredictionClass.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ namespace ");
8282

8383
#line default
8484
#line hidden
85-
this.Write(" public float[] Scores { get; set; }\r\n");
85+
this.Write(" public float[] Score { get; set; }\r\n");
8686

8787
#line 33 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt"
8888
}else{

src/mlnet/Templates/Console/PredictionClass.tt

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace <#= Namespace #>.Model.DataModels
2929
public <#= PredictionLabelType#> Prediction { get; set; }
3030
<# }#>
3131
<#if("MulticlassClassification".Equals(TaskType)){ #>
32-
public float[] Scores { get; set; }
32+
public float[] Score { get; set; }
3333
<#}else{ #>
3434
public float Score { get; set; }
3535
<#}#>

src/mlnet/Utilities/Utils.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ internal static void CreateSolutionFile(string solutionFile, string outputPath)
200200
var proc = new System.Diagnostics.Process();
201201
proc.StartInfo.FileName = @"dotnet";
202202

203-
proc.StartInfo.Arguments = $"new sln --name {solutionFile} --output {outputPath}";
203+
proc.StartInfo.Arguments = $"new sln --name {solutionFile} --output {outputPath} --force";
204204
proc.StartInfo.UseShellExecute = false;
205205
proc.StartInfo.RedirectStandardOutput = true;
206206
proc.Start();

0 commit comments

Comments
 (0)