Skip to content

Commit 974b2d7

Browse files
authored
Upgrading CLI to produce ML.NET V.10 APIs and bunch of Refactoring tasks (dotnet#65)
* Added sequential grouping of columns * reverted the file * upgrade to v .10 and refactoring * added null check * fixed unit tests * review comments * removed the settings change * added regions * fixed unit tests
1 parent 4748f03 commit 974b2d7

15 files changed

+248
-168
lines changed

src/AutoML/TrainerExtensions/TrainerExtensionUtil.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public static Action<T> CreateArgsFunc<T>(IEnumerable<SweepableParam> sweepParam
5757
}
5858

5959
private static string[] _lightGbmTreeBoosterParamNames = new[] { "RegLambda", "RegAlpha" };
60-
private const string LightGbmTreeBoosterPropName = "TreeBooster";
60+
private const string LightGbmTreeBoosterPropName = "Booster";
6161

6262
public static Action<LightGbmArguments> CreateLightGbmArgsFunc(IEnumerable<SweepableParam> sweepParams)
6363
{
@@ -92,7 +92,7 @@ private static IDictionary<string, object> BuildLightGbmPipelineNodeProps(IEnume
9292
var parentArgParams = sweepParams.Except(treeBoosterParams);
9393

9494
var treeBoosterProps = treeBoosterParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue());
95-
var treeBoosterCustomProp = new CustomProperty("LightGbmArguments.TreeBooster.Arguments", treeBoosterProps);
95+
var treeBoosterCustomProp = new CustomProperty("Options.TreeBooster.Arguments", treeBoosterProps);
9696

9797
var props = parentArgParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue());
9898
props[LightGbmTreeBoosterPropName] = treeBoosterCustomProp;

src/Test/TrainerExtensionsTests.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ public void BuildPipelineNodePropsLightGbm()
7373
""MaxCatThreshold"": 16,
7474
""CatSmooth"": 10,
7575
""CatL2"": 0.5,
76-
""TreeBooster"": {
77-
""Name"": ""LightGbmArguments.TreeBooster.Arguments"",
76+
""Booster"": {
77+
""Name"": ""Options.TreeBooster.Arguments"",
7878
""Properties"": {
7979
""RegLambda"": 0.5,
8080
""RegAlpha"": 0.5
@@ -114,8 +114,8 @@ public void BuildParameterSetLightGbm()
114114
{
115115
{"NumBoostRound", 1 },
116116
{"LearningRate", 1 },
117-
{"TreeBooster", new CustomProperty() {
118-
Name = "Microsoft.ML.LightGBM.TreeBooster",
117+
{"Booster", new CustomProperty() {
118+
Name = "Options.TreeBooster.Arguments",
119119
Properties = new Dictionary<string, object>()
120120
{
121121
{"RegLambda", 1 },

src/mlnet.Test/CodeGenTests.cs

+32-24
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ public void TrainerGeneratorBasicNamedParameterTest()
2424
PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, new string[] { "Label" }, default(string), elementProperties);
2525
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
2626
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
27-
var actual = codeGenerator.GenerateTrainer();
27+
var actual = codeGenerator.GenerateTrainerAndUsings();
2828
string expected = "LightGbm(learningRate:0.1f,numLeaves:1,labelColumn:\"Label\",featureColumn:\"Features\");";
29-
Assert.AreEqual(expected, actual);
29+
Assert.AreEqual(expected, actual.Item1);
30+
Assert.IsNull(actual.Item2);
3031
}
3132

3233
[TestMethod]
@@ -43,9 +44,11 @@ public void TrainerGeneratorBasicAdvancedParameterTest()
4344
PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, new string[] { "Label" }, default(string), elementProperties);
4445
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
4546
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
46-
var actual = codeGenerator.GenerateTrainer();
47-
string expected = "LightGbm(new LightGbm.Options(){LearningRate=0.1f,NumLeaves=1,UseSoftmax=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"});";
48-
Assert.AreEqual(expected, actual);
47+
var actual = codeGenerator.GenerateTrainerAndUsings();
48+
string expectedTrainer = "LightGbm(new Options(){LearningRate=0.1f,NumLeaves=1,UseSoftmax=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"});";
49+
string expectedUsing = "using Microsoft.ML.LightGBM;\r\n";
50+
Assert.AreEqual(expectedTrainer, actual.Item1);
51+
Assert.AreEqual(expectedUsing, actual.Item2);
4952
}
5053

5154
[TestMethod]
@@ -56,9 +59,25 @@ public void TransformGeneratorBasicTest()
5659
PipelineNode node = new PipelineNode("Normalizing", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties);
5760
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
5861
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
59-
var actual = codeGenerator.GenerateTransforms();
62+
var actual = codeGenerator.GenerateTransformsAndUsings();
6063
string expected = "Normalize(\"Label\",\"Label\")";
61-
Assert.AreEqual(expected, actual[0]);
64+
Assert.AreEqual(expected, actual[0].Item1);
65+
Assert.IsNull(actual[0].Item2);
66+
}
67+
68+
[TestMethod]
69+
public void TransformGeneratorUsingTest()
70+
{
71+
var context = new MLContext();
72+
var elementProperties = new Dictionary<string, object>();
73+
PipelineNode node = new PipelineNode("OneHotEncoding", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties);
74+
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
75+
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
76+
var actual = codeGenerator.GenerateTransformsAndUsings();
77+
string expectedTransform = "Categorical.OneHotEncoding(new []{new OneHotEncodingEstimator.ColumnInfo(\"Label\",\"Label\")})";
78+
var expectedUsings = "using Microsoft.ML.Transforms.Categorical;\r\n";
79+
Assert.AreEqual(expectedTransform, actual[0].Item1);
80+
Assert.AreEqual(expectedUsings, actual[0].Item2);
6281
}
6382

6483
[TestMethod]
@@ -79,19 +98,6 @@ public void ClassLabelGenerationBasicTest()
7998
Assert.AreEqual(expected2, actual[1]);
8099
}
81100

82-
[TestMethod]
83-
public void GenerateUsingsBasicTest()
84-
{
85-
var context = new MLContext();
86-
var elementProperties = new Dictionary<string, object>();
87-
PipelineNode node = new PipelineNode("TypeConverting", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties);
88-
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
89-
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
90-
var actual = codeGenerator.GenerateUsings();
91-
string expected = "using Microsoft.ML.Transforms.Conversions;\r\n";
92-
Assert.AreEqual(expected, actual);
93-
}
94-
95101
[TestMethod]
96102
public void ColumnGenerationTest()
97103
{
@@ -122,14 +128,16 @@ public void TrainerComplexParameterTest()
122128

123129
var elementProperties = new Dictionary<string, object>()
124130
{
125-
{"TreeBooster", new CustomProperty(){Properties= new Dictionary<string, object>(), Name = "TreeBooster"} },
131+
{"Booster", new CustomProperty(){Properties= new Dictionary<string, object>(), Name = "TreeBooster"} },
126132
};
127133
PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, new string[] { "Label" }, default(string), elementProperties);
128134
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
129135
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
130-
var actual = codeGenerator.GenerateTrainer();
131-
string expected = "LightGbm(new LightGbm.Options(){TreeBooster=new TreeBooster(){},LabelColumn=\"Label\",FeatureColumn=\"Features\"});";
132-
Assert.AreEqual(expected, actual);
136+
var actual = codeGenerator.GenerateTrainerAndUsings();
137+
string expectedTrainer = "LightGbm(new Options(){Booster=new TreeBooster(){},LabelColumn=\"Label\",FeatureColumn=\"Features\"});";
138+
var expectedUsings = "using Microsoft.ML.LightGBM;\r\n";
139+
Assert.AreEqual(expectedTrainer, actual.Item1);
140+
Assert.AreEqual(expectedUsings, actual.Item2);
133141

134142
}
135143

src/mlnet/CodeGenerator/CodeGenerator.cs

+6-43
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,25 @@ public CodeGenerator(Pipeline pipelineToDeconstruct, ColumnInferenceResult colum
2121
this.pipeline = pipelineToDeconstruct;
2222
this.columnInferenceResult = columnInferenceResult;
2323
}
24-
internal IList<string> GenerateTransforms()
24+
internal IList<(string, string)> GenerateTransformsAndUsings()
2525
{
2626
var nodes = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Transform);
27-
var results = new List<string>();
27+
var results = new List<(string, string)>();
2828
foreach (var node in nodes)
2929
{
3030
ITransformGenerator generator = TransformGeneratorFactory.GetInstance(node);
31-
results.Add(generator.GenerateTransformer());
31+
results.Add((generator.GenerateTransformer(), generator.GenerateUsings()));
3232
}
3333

3434
return results;
3535
}
3636

37-
internal string GenerateTrainer()
37+
internal (string, string) GenerateTrainerAndUsings()
3838
{
3939
ITrainerGenerator generator = TrainerGeneratorFactory.GetInstance(pipeline);
4040
var trainerString = generator.GenerateTrainer();
41-
return trainerString;
41+
var trainerUsings = generator.GenerateUsings();
42+
return (trainerString, trainerUsings);
4243
}
4344

4445
internal IList<string> GenerateClassLabels()
@@ -149,44 +150,6 @@ private static string ConstructColumnDefinition(Column column)
149150
return def;
150151
}
151152

152-
internal string GenerateUsings()
153-
{
154-
var trainerNodes = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Trainer);
155-
var transformNodes = pipeline.Nodes.Where(t => t.NodeType == PipelineNodeType.Transform);
156-
157-
StringBuilder sb = new StringBuilder();
158-
159-
foreach (var node in trainerNodes)
160-
{
161-
if (Enum.TryParse<TrainerName>(node.Name, out TrainerName nodeName))
162-
{
163-
if (nodeName == TrainerName.LightGbmBinary || nodeName == TrainerName.LightGbmMulti || nodeName == TrainerName.LightGbmRegression)
164-
{
165-
sb.Append("using Microsoft.ML.LightGBM;");
166-
sb.Append("\r\n");
167-
}
168-
}
169-
}
170-
171-
foreach (var node in transformNodes)
172-
{
173-
if (Enum.TryParse<EstimatorName>(node.Name, out EstimatorName nodeName))
174-
{
175-
if (nodeName == EstimatorName.OneHotEncoding || nodeName == EstimatorName.OneHotHashEncoding)
176-
{
177-
sb.Append("using Microsoft.ML.Transforms.Categorical;");
178-
sb.Append("\r\n");
179-
}
180-
if (nodeName == EstimatorName.TypeConverting)
181-
{
182-
sb.Append("using Microsoft.ML.Transforms.Conversions;");
183-
sb.Append("\r\n");
184-
}
185-
}
186-
}
187-
return sb.ToString();
188-
}
189-
190153
private static string Normalize(string inputColumn)
191154
{
192155
//check if first character is int

src/mlnet/CodeGenerator/TrainerGeneratorBase.cs

+8
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ internal abstract class TrainerGeneratorBase : ITrainerGenerator
2525
internal abstract string OptionsName { get; }
2626
internal abstract string MethodName { get; }
2727
internal abstract IDictionary<string, string> NamedParameters { get; }
28+
internal abstract string Usings { get; }
2829

2930
/// <summary>
3031
/// Generates an instance of TrainerGenerator
@@ -136,5 +137,12 @@ public string GenerateTrainer()
136137
return sb.ToString();
137138
}
138139

140+
public string GenerateUsings()
141+
{
142+
if (hasAdvancedSettings)
143+
return Usings;
144+
145+
return null;
146+
}
139147
}
140148
}

src/mlnet/CodeGenerator/TrainerGeneratorFactory.cs

+10-8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace Microsoft.ML.CLI
1212
internal interface ITrainerGenerator
1313
{
1414
string GenerateTrainer();
15+
string GenerateUsings();
1516
}
1617
internal static class TrainerGeneratorFactory
1718
{
@@ -33,30 +34,31 @@ internal static ITrainerGenerator GetInstance(Pipeline pipeline)
3334
case TrainerName.AveragedPerceptronBinary:
3435
return new AveragedPerceptron(node);
3536
case TrainerName.FastForestBinary:
37+
return new FastForestClassification(node);
3638
case TrainerName.FastForestRegression:
37-
return new FastForest(node);
39+
return new FastForestRegression(node);
3840
case TrainerName.FastTreeBinary:
41+
return new FastTreeClassification(node);
3942
case TrainerName.FastTreeRegression:
40-
return new FastTree(node);
43+
return new FastTreeRegression(node);
4144
case TrainerName.FastTreeTweedieRegression:
4245
return new FastTreeTweedie(node);
43-
4446
case TrainerName.LinearSvmBinary:
4547
return new LinearSvm(node);
4648
case TrainerName.LogisticRegressionBinary:
47-
case TrainerName.LogisticRegressionMulti:
48-
return new LogisticRegression(node);
49+
return new LogisticRegressionBinary(node);
4950
case TrainerName.OnlineGradientDescentRegression:
5051
return new OnlineGradientDescentRegression(node);
5152
case TrainerName.OrdinaryLeastSquaresRegression:
5253
return new OrdinaryLeastSquaresRegression(node);
5354
case TrainerName.PoissonRegression:
5455
return new PoissonRegression(node);
5556
case TrainerName.SdcaBinary:
56-
case TrainerName.SdcaMulti:
57-
return new StochasticDualCoordinateAscent(node);
57+
return new StochasticDualCoordinateAscentBinary(node);
58+
case TrainerName.SdcaRegression:
59+
return new StochasticDualCoordinateAscentRegression(node);
5860
case TrainerName.StochasticGradientDescentBinary:
59-
return new StochasticGradientDescent(node);
61+
return new StochasticGradientDescentClassification(node);
6062
case TrainerName.SymSgdBinary:
6163
return new SymbolicStochasticGradientDescent(node);
6264
default:

0 commit comments

Comments
 (0)