Skip to content

Commit b92039a

Browse files
srsaggamDmitry-A
authored andcommitted
Format the generated code + bunch of misc tasks (dotnet#152)
* added formatting and minor changes for reordering cv * fixing the template * minor changes * formatting changes * fixed approval test * removed unused nuget * added missing value replacing * added test for new transform * fix test * Update src/mlnet/Templates/Console/MLCodeGen.cs Co-Authored-By: srsaggam <[email protected]>
1 parent 0aeb75f commit b92039a

File tree

8 files changed

+123
-69
lines changed

8 files changed

+123
-69
lines changed

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.GeneratedTrainCodeTest.approved.txt

+39-41
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ using Microsoft.Data.DataView;
99
using Microsoft.ML.LightGBM;
1010

1111

12-
1312
namespace MyNamespace
1413
{
1514
class Program
@@ -41,30 +40,29 @@ namespace MyNamespace
4140
// Data loading
4241
IDataView trainingDataView = mlContext.Data.ReadFromTextFile<SampleObservation>(
4342
path: TrainDataPath,
44-
hasHeader : true,
45-
separatorChar : ',',
46-
allowQuotedStrings : true,
47-
trimWhitespace : false ,
48-
supportSparse : true);
43+
hasHeader: true,
44+
separatorChar: ',',
45+
allowQuotedStrings: true,
46+
trimWhitespace: false,
47+
supportSparse: true);
4948
IDataView testDataView = mlContext.Data.ReadFromTextFile<SampleObservation>(
5049
path: TestDataPath,
51-
hasHeader : true,
52-
separatorChar : ',',
53-
allowQuotedStrings : true,
54-
trimWhitespace : false ,
55-
supportSparse : true);
56-
57-
// Common data process configuration with pipeline data transformations
50+
hasHeader: true,
51+
separatorChar: ',',
52+
allowQuotedStrings: true,
53+
trimWhitespace: false,
54+
supportSparse: true);
5855

59-
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out",new []{"In"});
56+
// Common data process configuration with pipeline data transformations
57+
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" });
6058

6159
// Set the training algorithm, then create and config the modelBuilder
62-
var trainer = mlContext.BinaryClassification.Trainers.LightGbm(new Options(){NumLeaves=2,Booster=new Options.TreeBooster.Arguments(){},LabelColumn="Label",FeatureColumn="Features"});
63-
60+
var trainer = mlContext.BinaryClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Arguments() { }, LabelColumn = "Label", FeatureColumn = "Features" });
61+
var trainingPipeline = dataProcessPipeline.Append(trainer);
6462

6563
// Train the model fitting to the DataSet
66-
var trainingPipeline = dataProcessPipeline.Append(trainer);
6764
var trainedModel = trainingPipeline.Fit(trainingDataView);
65+
6866
// Evaluate the model and show accuracy stats
6967
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
7068
var predictions = trainedModel.Transform(testDataView);
@@ -86,11 +84,11 @@ namespace MyNamespace
8684
//Load data to test. Could be any test data. For demonstration purpose train data is used here.
8785
IDataView trainingDataView = mlContext.Data.ReadFromTextFile<SampleObservation>(
8886
path: TrainDataPath,
89-
hasHeader : true,
90-
separatorChar : ',',
91-
allowQuotedStrings : true,
92-
trimWhitespace : false ,
93-
supportSparse : true);
87+
hasHeader: true,
88+
separatorChar: ',',
89+
allowQuotedStrings: true,
90+
trimWhitespace: false,
91+
supportSparse: true);
9492

9593
var sample = mlContext.CreateEnumerable<SampleObservation>(trainingDataView, false).First();
9694

@@ -101,7 +99,7 @@ namespace MyNamespace
10199
}
102100

103101
// Create prediction engine related to the loaded trained model
104-
var predEngine= trainedModel.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlContext);
102+
var predEngine = trainedModel.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlContext);
105103

106104
//Score
107105
var resultprediction = predEngine.Predict(sample);
@@ -115,29 +113,29 @@ namespace MyNamespace
115113

116114
public class SampleObservation
117115
{
118-
[ColumnName("Label"), LoadColumn(0)]
119-
public bool Label{get; set;}
120-
116+
[ColumnName("Label"), LoadColumn(0)]
117+
public bool Label { get; set; }
118+
119+
120+
[ColumnName("col1"), LoadColumn(1)]
121+
public float Col1 { get; set; }
122+
123+
124+
[ColumnName("col2"), LoadColumn(0)]
125+
public float Col2 { get; set; }
126+
127+
128+
[ColumnName("col3"), LoadColumn(0)]
129+
public string Col3 { get; set; }
121130

122-
[ColumnName("col1"), LoadColumn(1)]
123-
public float Col1{get; set;}
124-
125131

126-
[ColumnName("col2"), LoadColumn(0)]
127-
public float Col2{get; set;}
128-
132+
[ColumnName("col4"), LoadColumn(0)]
133+
public int Col4 { get; set; }
129134

130-
[ColumnName("col3"), LoadColumn(0)]
131-
public string Col3{get; set;}
132-
133135

134-
[ColumnName("col4"), LoadColumn(0)]
135-
public int Col4{get; set;}
136-
136+
[ColumnName("col5"), LoadColumn(0)]
137+
public uint Col5 { get; set; }
137138

138-
[ColumnName("col5"), LoadColumn(0)]
139-
public uint Col5{get; set;}
140-
141139

142140
}
143141

src/mlnet.Test/CodeGenTests.cs

+15
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,21 @@ public void TrainerComplexParameterTest()
170170
}
171171

172172
#region Transform Tests
173+
[TestMethod]
174+
public void MissingValueReplacingTest()
175+
{
176+
var context = new MLContext();
177+
var elementProperties = new Dictionary<string, object>();//categorical
178+
PipelineNode node = new PipelineNode("MissingValueReplacing", PipelineNodeType.Transform, new string[] { "categorical_column_1" }, new string[] { "categorical_column_1" }, elementProperties);
179+
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
180+
CodeGenerator codeGenerator = new CodeGenerator(pipeline, (null, null), null);
181+
var actual = codeGenerator.GenerateTransformsAndUsings();
182+
var expectedTransform = "ReplaceMissingValues(new []{new MissingValueReplacingTransformer.ColumnInfo(\"categorical_column_1\",\"categorical_column_1\")})";
183+
string expectedUsings = "using Microsoft.ML.Transforms;\r\n";
184+
Assert.AreEqual(expectedTransform, actual[0].Item1);
185+
Assert.AreEqual(expectedUsings, actual[0].Item2);
186+
}
187+
173188
[TestMethod]
174189
public void OneHotEncodingTest()
175190
{

src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs

+10-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
using System.IO;
88
using System.Linq;
99
using System.Text;
10+
using Microsoft.CodeAnalysis;
11+
using Microsoft.CodeAnalysis.CSharp;
12+
using Microsoft.CodeAnalysis.Formatting;
1013
using Microsoft.ML.Auto;
1114
using Microsoft.ML.CLI.Templates.Console;
1215
using static Microsoft.ML.Data.TextLoader;
@@ -50,15 +53,18 @@ public void GenerateOutput()
5053
var namespaceValue = Normalize(options.OutputName);
5154

5255
// Generate code for training and scoring
53-
var trainScoreCode = GenerateTrainCode(usings, trainer, transforms, columns, classLabels, namespaceValue);
56+
var trainFileContent = GenerateTrainCode(usings, trainer, transforms, columns, classLabels, namespaceValue);
57+
var tree = CSharpSyntaxTree.ParseText(trainFileContent);
58+
var syntaxNode = tree.GetRoot();
59+
trainFileContent = Formatter.Format(syntaxNode, new AdhocWorkspace()).ToFullString();
5460

5561
// Generate csproj
56-
var projectSourceCode = GeneratProjectCode();
62+
var projectFileContent = GeneratProjectCode();
5763

5864
// Generate Helper class
59-
var consoleHelperCode = GenerateConsoleHelper(namespaceValue);
65+
var consoleHelperFileContent = GenerateConsoleHelper(namespaceValue);
6066

61-
return (trainScoreCode, projectSourceCode, consoleHelperCode);
67+
return (trainFileContent, projectFileContent, consoleHelperFileContent);
6268
}
6369

6470
internal void WriteOutputToFiles(string trainScoreCode, string projectSourceCode, string consoleHelperCode)

src/mlnet/CodeGenerator/CSharp/TransformGeneratorFactory.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ internal static ITransformGenerator GetInstance(PipelineNode node)
3939
case EstimatorName.MissingValueIndicating:
4040
result = new MissingValueIndicator(node);
4141
break;
42-
//todo : add missing value replacing too.
42+
case EstimatorName.MissingValueReplacing:
43+
result = new MissingValueReplacer(node);
44+
break;
4345
case EstimatorName.OneHotHashEncoding:
4446
result = new OneHotHashEncoding(node);
4547
break;

src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs

+36
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,42 @@ public override string GenerateTransformer()
162162
}
163163
}
164164

165+
internal class MissingValueReplacer : TransformGeneratorBase
166+
{
167+
public MissingValueReplacer(PipelineNode node) : base(node)
168+
{
169+
}
170+
171+
internal override string MethodName => "ReplaceMissingValues";
172+
173+
private string ArgumentsName = "MissingValueReplacingTransformer.ColumnInfo";
174+
internal override string Usings => "using Microsoft.ML.Transforms;\r\n";
175+
176+
public override string GenerateTransformer()
177+
{
178+
StringBuilder sb = new StringBuilder();
179+
sb.Append(MethodName);
180+
sb.Append("(");
181+
sb.Append("new []{");
182+
for (int i = 0; i < inputColumns.Length; i++)
183+
{
184+
sb.Append("new ");
185+
sb.Append(ArgumentsName);
186+
sb.Append("(");
187+
sb.Append(outputColumns[i]);
188+
sb.Append(",");
189+
sb.Append(inputColumns[i]);
190+
sb.Append(")");
191+
sb.Append(",");
192+
}
193+
sb.Remove(sb.Length - 1, 1); // remove extra ,
194+
195+
sb.Append("}");
196+
sb.Append(")");
197+
return sb.ToString();
198+
}
199+
}
200+
165201
internal class OneHotHashEncoding : TransformGeneratorBase
166202
{
167203
public OneHotHashEncoding(PipelineNode node) : base(node)

src/mlnet/Templates/Console/MLCodeGen.cs

+11-12
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public virtual string TransformText()
2929
" Microsoft.ML.Core.Data;\r\nusing Microsoft.ML.Data;\r\nusing Microsoft.Data.DataVie" +
3030
"w;\r\n");
3131
this.Write(this.ToStringHelper.ToStringWithCulture(GeneratedUsings));
32-
this.Write("\r\n\r\n\r\nnamespace ");
32+
this.Write("\r\n\r\nnamespace ");
3333
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
3434
this.Write("\r\n{\r\n class Program\r\n {\r\n private static string TrainDataPath = @\"");
3535
this.Write(this.ToStringHelper.ToStringWithCulture(Path));
@@ -93,7 +93,7 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
9393
this.Write("\r\n");
9494
if(Transforms.Count >0 ) {
9595
this.Write(" // Common data process configuration with pipeline data transformatio" +
96-
"ns \r\n\r\n var dataProcessPipeline = ");
96+
"ns\r\n var dataProcessPipeline = ");
9797
for(int i=0;i<Transforms.Count;i++)
9898
{
9999
if(i>0)
@@ -111,7 +111,13 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
111111
this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
112112
this.Write(".Trainers.");
113113
this.Write(this.ToStringHelper.ToStringWithCulture(Trainer));
114-
this.Write(";\r\n\r\n");
114+
this.Write(";\r\n");
115+
if (Transforms.Count > 0) {
116+
this.Write(" var trainingPipeline = dataProcessPipeline.Append(trainer);\r\n");
117+
}
118+
else{
119+
this.Write(" var trainingPipeline = trainer;\r\n");
120+
}
115121
if(string.IsNullOrEmpty(TestPath)){
116122
this.Write(@"
117123
// Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
@@ -135,15 +141,8 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
135141
"rics(trainer.ToString(), crossValidationResults);\r\n");
136142
}
137143
}
138-
this.Write("\r\n // Train the model fitting to the DataSet\r\n");
139-
if(Transforms.Count >0 ) {
140-
this.Write(" var trainingPipeline = dataProcessPipeline.Append(trainer);\r\n " +
141-
" var trainedModel = trainingPipeline.Fit(trainingDataView);\r\n");
142-
}
143-
else{
144-
this.Write(" var trainingPipeline = trainer;\r\n var trainedModel = train" +
145-
"ingPipeline.Fit(trainingDataView);\r\n");
146-
}
144+
this.Write("\r\n // Train the model fitting to the DataSet\r\n var trainedM" +
145+
"odel = trainingPipeline.Fit(trainingDataView);\r\n\r\n");
147146
if(!string.IsNullOrEmpty(TestPath)){
148147
this.Write(" // Evaluate the model and show accuracy stats\r\n Console.Wr" +
149148
"iteLine(\"===== Evaluating Model\'s accuracy with Test data =====\");\r\n " +

src/mlnet/Templates/Console/MLCodeGen.tt

+8-11
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using Microsoft.ML.Data;
1313
using Microsoft.Data.DataView;
1414
<#= GeneratedUsings #>
1515

16-
1716
namespace <#= Namespace #>
1817
{
1918
class Program
@@ -63,8 +62,7 @@ namespace <#= Namespace #>
6362
<# } #>
6463

6564
<# if(Transforms.Count >0 ) {#>
66-
// Common data process configuration with pipeline data transformations
67-
65+
// Common data process configuration with pipeline data transformations
6866
var dataProcessPipeline = <# for(int i=0;i<Transforms.Count;i++)
6967
{
7068
if(i>0)
@@ -79,7 +77,12 @@ namespace <#= Namespace #>
7977

8078
// Set the training algorithm, then create and config the modelBuilder
8179
var trainer = mlContext.<#= TaskType #>.Trainers.<#= Trainer #>;
82-
80+
<# if(Transforms.Count >0 ) {#>
81+
var trainingPipeline = dataProcessPipeline.Append(trainer);
82+
<# }
83+
else{#>
84+
var trainingPipeline = trainer;
85+
<#}#>
8386
<# if(string.IsNullOrEmpty(TestPath)){ #>
8487

8588
// Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
@@ -95,14 +98,8 @@ namespace <#= Namespace #>
9598
} #>
9699

97100
// Train the model fitting to the DataSet
98-
<# if(Transforms.Count >0 ) {#>
99-
var trainingPipeline = dataProcessPipeline.Append(trainer);
100-
var trainedModel = trainingPipeline.Fit(trainingDataView);
101-
<# }
102-
else{#>
103-
var trainingPipeline = trainer;
104101
var trainedModel = trainingPipeline.Fit(trainingDataView);
105-
<#}#>
102+
106103
<# if(!string.IsNullOrEmpty(TestPath)){ #>
107104
// Evaluate the model and show accuracy stats
108105
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");

src/mlnet/mlnet.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
</PropertyGroup>
1212

1313
<ItemGroup>
14+
<PackageReference Include="Microsoft.CodeAnalysis" Version="2.10.0" />
1415
<PackageReference Include="NLog" Version="4.5.11" />
1516
<PackageReference Include="NLog.Config" Version="4.5.11" />
1617
<PackageReference Include="System.CommandLine.Experimental" Version="0.1.0-alpha-63728-01" />

0 commit comments

Comments
 (0)