Skip to content

Commit 215b0c9

Browse files
srsaggamDmitry-A
authored andcommitted
Initial version of CLI tool for mlnet (dotnet#61)
* added global tool initial project * removed unneccesary files, renamed files * refactoring and added base abstract classes for trainer generator * removed unused class * Added classes for transforms * added transform generate dummy classes * more refactoring, added first transform * more refactoring and added classes * changed the project structure * restructing added options class * sln changes * refactored options to different class: * added more logic for code generation of class * misc changes * reverted file * added commandline api package * reverted sample * added new command line api parser * added normalization of column names * Added command defaults and error message * implementation of all trainers * changed auto to null * added all transform generators * added error handling when args is empty and minor changes due to change in AutoML api names * changed the name of param * added new command line options and restructuring code * renamed proj file and added solution * Added code to generate usings, Fixed few bugs in the code * added validation to the command line options * changed project name * Bug fixes due to API change in AutoML * changed directory structure * added test framework and basic tests * added more tests * added improvements to template and error handling * renamed the estimator name * fixed test case * added comments * added headers * changed namespace and removed unneccesary properties from project * Revert "changed namespace and removed unneccesary properties from project" This reverts commit 9edae033e9845e910f663f296e168f1182b84f5f. * fixed test cases and renamed namespaces * cleaned up proj file * added folder structure * added symbols/tokens for strings * added more tests * review comments * modified test cases * review comments * change in the exception message * normalized line endings * made method private static * simplified range building /optimization * minor fix * added header * added static methods in command where necessary * nit picks * made few methods static * review comments * nitpick * remove line pragmas * fix test case
1 parent 1c1828a commit 215b0c9

30 files changed

+3727
-18
lines changed

AutoML.sln

+28
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Samples", "src\Samples\Samp
99
EndProject
1010
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Test", "src\Test\Test.csproj", "{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}"
1111
EndProject
12+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "mlnet", "src\mlnet\mlnet.csproj", "{ED714FA5-6F89-401B-9E7F-CADF1373C553}"
13+
EndProject
14+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "mlnet.Test", "src\mlnet.Test\mlnet.Test.csproj", "{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}"
15+
EndProject
1216
Global
1317
GlobalSection(SolutionConfigurationPlatforms) = preSolution
1418
Debug|Any CPU = Debug|Any CPU
@@ -55,6 +59,30 @@ Global
5559
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
5660
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
5761
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
62+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
63+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug|Any CPU.Build.0 = Debug|Any CPU
64+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
65+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
66+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
67+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
68+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release|Any CPU.ActiveCfg = Release|Any CPU
69+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release|Any CPU.Build.0 = Release|Any CPU
70+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
71+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
72+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
73+
{ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
74+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
75+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug|Any CPU.Build.0 = Debug|Any CPU
76+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
77+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
78+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
79+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
80+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release|Any CPU.ActiveCfg = Release|Any CPU
81+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release|Any CPU.Build.0 = Release|Any CPU
82+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
83+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
84+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
85+
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
5886
EndGlobalSection
5987
GlobalSection(SolutionProperties) = preSolution
6088
HideSolutionNode = FALSE

NuGet.Config

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<configuration>
3+
<packageSources>
4+
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
5+
</packageSources>
6+
</configuration>

src/AutoML/Assembly.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,6 @@
44

55
using System.Runtime.CompilerServices;
66

7-
[assembly: InternalsVisibleTo("Test, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
7+
[assembly: InternalsVisibleTo("Test, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
8+
[assembly: InternalsVisibleTo("mlnet, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
9+
[assembly: InternalsVisibleTo("mlnet.Test, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]

src/AutoML/Sweepers/SmacSweeper.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ private FastForestRegressionModelParameters FitModel(IEnumerable<IRunResult> pre
111111

112112
IDataView data = dvBuilder.GetDataView();
113113
AutoMlUtils.Assert(data.GetRowCount() == targets.Length, "This data view will have as many rows as there have been evaluations");
114-
114+
115115
// Set relevant random forest arguments.
116116
// Train random forest.
117117
var trainer = new FastForestRegression(_context, DefaultColumnNames.Label, DefaultColumnNames.Features, advancedSettings: s =>
@@ -195,7 +195,7 @@ private ParameterSet[] GreedyPlusRandomSearch(ParameterSet[] parents, FastForest
195195
var retainedConfigs = new HashSet<ParameterSet>(bestConfigurations.Select(x => x.Item2));
196196

197197
// remove configurations matching previous run
198-
foreach(var previousRun in previousRuns)
198+
foreach (var previousRun in previousRuns)
199199
{
200200
retainedConfigs.Remove(previousRun.ParameterSet);
201201
}

src/AutoML/TrainerExtensions/TrainerExtensionUtil.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public static Action<LightGbmArguments> CreateLightGbmArgsFunc(IEnumerable<Sweep
7777

7878
public static IDictionary<string, object> BuildPipelineNodeProps(TrainerName trainerName, IEnumerable<SweepableParam> sweepParams)
7979
{
80-
if(trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti ||
80+
if (trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti ||
8181
trainerName == TrainerName.LightGbmRegression)
8282
{
8383
return BuildLightGbmPipelineNodeProps(sweepParams);
@@ -92,11 +92,11 @@ private static IDictionary<string, object> BuildLightGbmPipelineNodeProps(IEnume
9292
var parentArgParams = sweepParams.Except(treeBoosterParams);
9393

9494
var treeBoosterProps = treeBoosterParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue());
95-
var treeBoosterCustomProp = new CustomProperty("Microsoft.ML.LightGBM.TreeBooster", treeBoosterProps);
95+
var treeBoosterCustomProp = new CustomProperty("LightGbmArguments.TreeBooster.Arguments", treeBoosterProps);
9696

9797
var props = parentArgParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue());
9898
props[LightGbmTreeBoosterPropName] = treeBoosterCustomProp;
99-
99+
100100
return props;
101101
}
102102

src/Test/TrainerExtensionsTests.cs

+8-8
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public void TrainerExtensionInstanceTests()
1717
{
1818
var context = new MLContext();
1919
var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast<TrainerName>();
20-
foreach(var trainerName in trainerNames)
20+
foreach (var trainerName in trainerNames)
2121
{
2222
var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
2323
var instance = extension.CreateInstance(context, null);
@@ -33,7 +33,7 @@ public void GetTrainersByMaxIterations()
3333
var tasks = new TaskKind[] { TaskKind.BinaryClassification,
3434
TaskKind.MulticlassClassification, TaskKind.Regression };
3535

36-
foreach(var task in tasks)
36+
foreach (var task in tasks)
3737
{
3838
var trainerSet10 = TrainerExtensionCatalog.GetTrainers(task, 10);
3939
var trainerSet50 = TrainerExtensionCatalog.GetTrainers(task, 50);
@@ -52,7 +52,7 @@ public void GetTrainersByMaxIterations()
5252
public void BuildPipelineNodePropsLightGbm()
5353
{
5454
var sweepParams = SweepableParams.BuildLightGbmParams();
55-
foreach(var sweepParam in sweepParams)
55+
foreach (var sweepParam in sweepParams)
5656
{
5757
sweepParam.RawValue = 1;
5858
}
@@ -74,7 +74,7 @@ public void BuildPipelineNodePropsLightGbm()
7474
""CatSmooth"": 10,
7575
""CatL2"": 0.5,
7676
""TreeBooster"": {
77-
""Name"": ""Microsoft.ML.LightGBM.TreeBooster"",
77+
""Name"": ""LightGbmArguments.TreeBooster.Arguments"",
7878
""Properties"": {
7979
""RegLambda"": 0.5,
8080
""RegAlpha"": 0.5
@@ -90,7 +90,7 @@ public void BuildPipelineNodePropsLightGbm()
9090
public void BuildPipelineNodePropsSdca()
9191
{
9292
var sweepParams = SweepableParams.BuildSdcaParams();
93-
foreach(var sweepParam in sweepParams)
93+
foreach (var sweepParam in sweepParams)
9494
{
9595
sweepParam.RawValue = 1;
9696
}
@@ -106,7 +106,7 @@ public void BuildPipelineNodePropsSdca()
106106
}";
107107
Util.AssertObjectMatchesJson(expectedJson, sdcaBinaryProps);
108108
}
109-
109+
110110
[TestMethod]
111111
public void BuildParameterSetLightGbm()
112112
{
@@ -127,7 +127,7 @@ public void BuildParameterSetLightGbm()
127127
var multiParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmMulti, props);
128128
var regressionParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmRegression, props);
129129

130-
foreach(var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams })
130+
foreach (var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams })
131131
{
132132
Assert.AreEqual(4, paramSet.Count);
133133
Assert.AreEqual("1", paramSet["NumBoostRound"].ValueText);
@@ -146,7 +146,7 @@ public void BuildParameterSetSdca()
146146
};
147147

148148
var sdcaParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.SdcaBinary, props);
149-
149+
150150
Assert.AreEqual(1, sdcaParams.Count);
151151
Assert.AreEqual("1", sdcaParams["LearningRate"].ValueText);
152152
}

src/Test/UserInputValidationTests.cs

+6-4
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public void ValidateCreateTextReaderArgsNullColumn()
4444
public void ValidateCreateTextReaderArgsColumnWithNullSoure()
4545
{
4646
var input = new ColumnInferenceResult(
47-
new List<(TextLoader.Column, ColumnPurpose)>() { (new TextLoader.Column() { Name = "Column", Type = DataKind.R4 } , ColumnPurpose.CategoricalFeature) },
47+
new List<(TextLoader.Column, ColumnPurpose)>() { (new TextLoader.Column() { Name = "Column", Type = DataKind.R4 }, ColumnPurpose.CategoricalFeature) },
4848
false, false, "\t", false, false);
4949
UserInputValidationUtil.ValidateCreateTextReaderArgs(input);
5050
}
@@ -63,7 +63,7 @@ public void ValidateCreateTextReaderArgsNullSeparator()
6363
[ExpectedException(typeof(ArgumentNullException))]
6464
public void ValidateAutoFitNullTrainData()
6565
{
66-
UserInputValidationUtil.ValidateAutoFitArgs(null, DatasetUtil.UciAdultLabel,
66+
UserInputValidationUtil.ValidateAutoFitArgs(null, DatasetUtil.UciAdultLabel,
6767
DatasetUtil.GetUciAdultDataView(), null, null);
6868
}
6969

@@ -89,8 +89,10 @@ public void ValidateAutoFitArgsZeroMaxIterations()
8989
{
9090
UserInputValidationUtil.ValidateAutoFitArgs(DatasetUtil.GetUciAdultDataView(),
9191
DatasetUtil.UciAdultLabel, DatasetUtil.GetUciAdultDataView(),
92-
new AutoFitSettings() {
93-
StoppingCriteria = new ExperimentStoppingCriteria() {
92+
new AutoFitSettings()
93+
{
94+
StoppingCriteria = new ExperimentStoppingCriteria()
95+
{
9496
MaxIterations = 0,
9597
}
9698
}, null);

src/mlnet.Test/CodeGenTests.cs

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
using System.Collections.Generic;
2+
using Microsoft.ML;
3+
using Microsoft.ML.Auto;
4+
using Microsoft.ML.Data;
5+
using Microsoft.VisualStudio.TestTools.UnitTesting;
6+
using Microsoft.ML.CLI;
7+
using System;
8+
9+
namespace mlnet.Test
10+
{
11+
[TestClass]
12+
public class CodeGeneratorTests
13+
{
14+
[TestMethod]
15+
public void TrainerGeneratorBasicNamedParameterTest()
16+
{
17+
var context = new MLContext();
18+
19+
var elementProperties = new Dictionary<string, object>()
20+
{
21+
{"LearningRate", 0.1f },
22+
{"NumLeaves", 1 },
23+
};
24+
PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, new string[] { "Label" }, default(string), elementProperties);
25+
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
26+
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
27+
var actual = codeGenerator.GenerateTrainer();
28+
string expected = "LightGbm(learningRate:0.1f,numLeaves:1,labelColumn:\"Label\",featureColumn:\"Features\");";
29+
Assert.AreEqual(expected, actual);
30+
}
31+
32+
[TestMethod]
33+
public void TrainerGeneratorBasicAdvancedParameterTest()
34+
{
35+
var context = new MLContext();
36+
37+
var elementProperties = new Dictionary<string, object>()
38+
{
39+
{"LearningRate", 0.1f },
40+
{"NumLeaves", 1 },
41+
{"UseSoftmax", true }
42+
};
43+
PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, new string[] { "Label" }, default(string), elementProperties);
44+
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
45+
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
46+
var actual = codeGenerator.GenerateTrainer();
47+
string expected = "LightGbm(new LightGbm.Options(){LearningRate=0.1f,NumLeaves=1,UseSoftmax=true,LabelColumn=\"Label\",FeatureColumn=\"Features\"});";
48+
Assert.AreEqual(expected, actual);
49+
}
50+
51+
[TestMethod]
52+
public void TransformGeneratorBasicTest()
53+
{
54+
var context = new MLContext();
55+
var elementProperties = new Dictionary<string, object>();
56+
PipelineNode node = new PipelineNode("Normalizing", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties);
57+
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
58+
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
59+
var actual = codeGenerator.GenerateTransforms();
60+
string expected = "Normalize(\"Label\",\"Label\")";
61+
Assert.AreEqual(expected, actual[0]);
62+
}
63+
64+
[TestMethod]
65+
public void ClassLabelGenerationBasicTest()
66+
{
67+
List<(TextLoader.Column, ColumnPurpose)> list = new List<(TextLoader.Column, ColumnPurpose)>()
68+
{
69+
(new TextLoader.Column(){ Name = "Label", Source = new TextLoader.Range[]{new TextLoader.Range(0) }, Type = DataKind.Bool }, ColumnPurpose.Label),
70+
};
71+
ColumnInferenceResult result = new ColumnInferenceResult(list, false, false, ",", true, true);
72+
73+
CodeGenerator codeGenerator = new CodeGenerator(null, result);
74+
var actual = codeGenerator.GenerateClassLabels();
75+
var expected1 = "[ColumnName(\"Label\")]";
76+
var expected2 = "public bool Label{get; set;}";
77+
78+
Assert.AreEqual(expected1, actual[0]);
79+
Assert.AreEqual(expected2, actual[1]);
80+
}
81+
82+
[TestMethod]
83+
public void GenerateUsingsBasicTest()
84+
{
85+
var context = new MLContext();
86+
var elementProperties = new Dictionary<string, object>();
87+
PipelineNode node = new PipelineNode("TypeConverting", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties);
88+
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
89+
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
90+
var actual = codeGenerator.GenerateUsings();
91+
string expected = "using Microsoft.ML.Transforms.Conversions;\r\n";
92+
Assert.AreEqual(expected, actual);
93+
}
94+
95+
[TestMethod]
96+
public void ColumnGenerationTest()
97+
{
98+
List<(TextLoader.Column, ColumnPurpose)> list = new List<(TextLoader.Column, ColumnPurpose)>()
99+
{
100+
(new TextLoader.Column(){ Name = "Label", Source = new TextLoader.Range[]{new TextLoader.Range(0) }, Type = DataKind.Bool }, ColumnPurpose.Label),
101+
(new TextLoader.Column(){ Name = "Features", Source = new TextLoader.Range[]{new TextLoader.Range(1) }, Type = DataKind.R4 }, ColumnPurpose.NumericFeature),
102+
};
103+
ColumnInferenceResult result = new ColumnInferenceResult(list, false, false, ",", true, true);
104+
105+
var context = new MLContext();
106+
var elementProperties = new Dictionary<string, object>();
107+
PipelineNode node = new PipelineNode("Normalizing", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties);
108+
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
109+
CodeGenerator codeGenerator = new CodeGenerator(pipeline, result);
110+
var actual = codeGenerator.GenerateColumns();
111+
Assert.AreEqual(actual.Count, 2);
112+
string expectedColumn1 = "new Column(\"Label\",DataKind.BL,0),";
113+
string expectedColumn2 = "new Column(\"Features\",DataKind.R4,1),";
114+
Assert.AreEqual(expectedColumn1, actual[0]);
115+
Assert.AreEqual(expectedColumn2, actual[1]);
116+
}
117+
118+
[TestMethod]
119+
public void TrainerComplexParameterTest()
120+
{
121+
var context = new MLContext();
122+
123+
var elementProperties = new Dictionary<string, object>()
124+
{
125+
{"TreeBooster", new CustomProperty(){Properties= new Dictionary<string, object>(), Name = "TreeBooster"} },
126+
};
127+
PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, new string[] { "Label" }, default(string), elementProperties);
128+
Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
129+
CodeGenerator codeGenerator = new CodeGenerator(pipeline, null);
130+
var actual = codeGenerator.GenerateTrainer();
131+
string expected = "LightGbm(new LightGbm.Options(){TreeBooster=new TreeBooster(){},LabelColumn=\"Label\",FeatureColumn=\"Features\"});";
132+
Assert.AreEqual(expected, actual);
133+
134+
}
135+
136+
}
137+
}

src/mlnet.Test/mlnet.Test.csproj

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netcoreapp2.1</TargetFramework>
5+
<IsPackable>false</IsPackable>
6+
</PropertyGroup>
7+
8+
<ItemGroup>
9+
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
10+
<PackageReference Include="MSTest.TestAdapter" Version="1.3.2" />
11+
<PackageReference Include="MSTest.TestFramework" Version="1.3.2" />
12+
</ItemGroup>
13+
14+
<ItemGroup>
15+
<ProjectReference Include="..\AutoML\AutoML.csproj" />
16+
<ProjectReference Include="..\mlnet\mlnet.csproj" />
17+
</ItemGroup>
18+
19+
</Project>

0 commit comments

Comments
 (0)