Skip to content

Commit 41c663c

Browse files
authored
Set Nullable Auto params to null values (#50)
* Added sequential grouping of columns
* reverted the file
* added auto params as null
* change to the update fields method
1 parent d254f4e commit 41c663c

File tree

3 files changed

+48
-38
lines changed

3 files changed

+48
-38
lines changed

src/AutoML/TrainerExtensions/SweepableParams.cs

+14-11
Original file line number | Diff line number | Diff line change
@@ -31,14 +31,14 @@ private static IEnumerable<SweepableParam> BuildOnlineLinearArgsParams()
3131

3232
private static IEnumerable<SweepableParam> BuildTreeArgsParams()
3333
{
34-
return new SweepableParam[]
35-
{
34+
return new SweepableParam[]
35+
{
3636
new SweepableLongParam("NumLeaves", 2, 128, isLogScale: true, stepSize: 4),
3737
new SweepableDiscreteParam("MinDocumentsInLeafs", new object[] { 1, 10, 50 }),
3838
new SweepableDiscreteParam("NumTrees", new object[] { 20, 100, 500 }),
3939
new SweepableFloatParam("LearningRates", 0.025f, 0.4f, isLogScale: true),
4040
new SweepableFloatParam("Shrinkage", 0.025f, 4f, isLogScale: true),
41-
};
41+
};
4242
}
4343

4444
private static IEnumerable<SweepableParam> BuildLbfgsArgsParams()
@@ -123,22 +123,24 @@ public static IEnumerable<SweepableParam> BuildPoissonRegressionParams()
123123
public static IEnumerable<SweepableParam> BuildSdcaParams()
124124
{
125125
return new SweepableParam[] {
126-
new SweepableDiscreteParam("L2Const", new object[] { "<Auto>", 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f }),
127-
new SweepableDiscreteParam("L1Threshold", new object[] { "<Auto>", 0f, 0.25f, 0.5f, 0.75f, 1f }),
126+
new SweepableDiscreteParam("L2Const", new object[] { null, 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f }),
127+
new SweepableDiscreteParam("L1Threshold", new object[] { null, 0f, 0.25f, 0.5f, 0.75f, 1f }),
128128
new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 0.001f, 0.01f, 0.1f, 0.2f }),
129-
new SweepableDiscreteParam("MaxIterations", new object[] { "<Auto>", 10, 20, 100 }),
129+
new SweepableDiscreteParam("MaxIterations", new object[] { null, 10, 20, 100 }),
130130
new SweepableDiscreteParam("Shuffle", null, isBool: true),
131131
new SweepableDiscreteParam("BiasLearningRate", new object[] { 0.0f, 0.01f, 0.1f, 1f })
132132
};
133133
}
134134

135-
public static IEnumerable<SweepableParam> BuildOrdinaryLeastSquaresParams() {
135+
public static IEnumerable<SweepableParam> BuildOrdinaryLeastSquaresParams()
136+
{
136137
return new SweepableParam[] {
137138
new SweepableDiscreteParam("L2Weight", new object[] { 1e-6f, 0.1f, 1f })
138139
};
139140
}
140141

141-
public static IEnumerable<SweepableParam> BuildSgdParams() {
142+
public static IEnumerable<SweepableParam> BuildSgdParams()
143+
{
142144
return new SweepableParam[] {
143145
new SweepableDiscreteParam("L2Weight", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f }),
144146
new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f }),
@@ -147,12 +149,13 @@ public static IEnumerable<SweepableParam> BuildSgdParams() {
147149
};
148150
}
149151

150-
public static IEnumerable<SweepableParam> BuildSymSgdParams() {
152+
public static IEnumerable<SweepableParam> BuildSymSgdParams()
153+
{
151154
return new SweepableParam[] {
152155
new SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20, 30, 40, 50 }),
153-
new SweepableDiscreteParam("LearningRate", new object[] { "<Auto>", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f }),
156+
new SweepableDiscreteParam("LearningRate", new object[] { null, 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f }),
154157
new SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f }),
155-
new SweepableDiscreteParam("UpdateFrequency", new object[] { "<Auto>", 5, 20 })
158+
new SweepableDiscreteParam("UpdateFrequency", new object[] { null, 5, 20 })
156159
};
157160
}
158161
}

src/AutoML/TrainerExtensions/TrainerExtensionUtil.cs

+5-20
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ public static Action<LightGbmArguments> CreateLightGbmArgsFunc(IEnumerable<Sweep
7777

7878
public static IDictionary<string, object> BuildPipelineNodeProps(TrainerName trainerName, IEnumerable<SweepableParam> sweepParams)
7979
{
80-
if(trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti ||
80+
if (trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti ||
8181
trainerName == TrainerName.LightGbmRegression)
8282
{
8383
return BuildLightGbmPipelineNodeProps(sweepParams);
@@ -96,7 +96,7 @@ private static IDictionary<string, object> BuildLightGbmPipelineNodeProps(IEnume
9696

9797
var props = parentArgParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue());
9898
props[LightGbmTreeBoosterPropName] = treeBoosterCustomProp;
99-
99+
100100
return props;
101101
}
102102

@@ -155,24 +155,9 @@ public static void UpdateFields(object obj, IEnumerable<SweepableParam> sweepPar
155155
{
156156
var optIndex = (int)dp.RawValue;
157157
//Contracts.Assert(0 <= optIndex && optIndex < dp.Options.Length, $"Options index out of range: {optIndex}");
158-
var option = dp.Options[optIndex].ToString().ToLower();
159-
160-
// Handle <Auto> string values in sweep params
161-
if (option == "auto" || option == "<auto>" || option == "< auto >")
162-
{
163-
//Check if nullable type, in which case 'null' is the auto value.
164-
if (Nullable.GetUnderlyingType(fi.FieldType) != null)
165-
fi.SetValue(obj, null);
166-
else if (fi.FieldType.IsEnum)
167-
{
168-
// Check if there is an enum option named Auto
169-
var enumDict = fi.FieldType.GetEnumValues().Cast<int>()
170-
.ToDictionary(v => Enum.GetName(fi.FieldType, v), v => v);
171-
if (enumDict.ContainsKey("Auto"))
172-
fi.SetValue(obj, enumDict["Auto"]);
173-
}
174-
}
175-
else
158+
var option = dp.Options[optIndex];
159+
160+
if (option != null)
176161
SetValue(fi, (IComparable)dp.Options[optIndex], obj, propType);
177162
}
178163
else

src/Test/TrainerExtensionsTests.cs

+29-7
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ public void TrainerExtensionInstanceTests()
1717
{
1818
var context = new MLContext();
1919
var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast<TrainerName>();
20-
foreach(var trainerName in trainerNames)
20+
foreach (var trainerName in trainerNames)
2121
{
2222
var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
2323
var instance = extension.CreateInstance(context, null);
@@ -33,7 +33,7 @@ public void GetTrainersByMaxIterations()
3333
var tasks = new TaskKind[] { TaskKind.BinaryClassification,
3434
TaskKind.MulticlassClassification, TaskKind.Regression };
3535

36-
foreach(var task in tasks)
36+
foreach (var task in tasks)
3737
{
3838
var trainerSet10 = TrainerExtensionCatalog.GetTrainers(task, 10);
3939
var trainerSet50 = TrainerExtensionCatalog.GetTrainers(task, 50);
@@ -52,7 +52,7 @@ public void GetTrainersByMaxIterations()
5252
public void BuildPipelineNodePropsLightGbm()
5353
{
5454
var sweepParams = SweepableParams.BuildLightGbmParams();
55-
foreach(var sweepParam in sweepParams)
55+
foreach (var sweepParam in sweepParams)
5656
{
5757
sweepParam.RawValue = 1;
5858
}
@@ -91,7 +91,7 @@ public void BuildPipelineNodePropsLightGbm()
9191
public void BuildPipelineNodePropsSdca()
9292
{
9393
var sweepParams = SweepableParams.BuildSdcaParams();
94-
foreach(var sweepParam in sweepParams)
94+
foreach (var sweepParam in sweepParams)
9595
{
9696
sweepParam.RawValue = 1;
9797
}
@@ -108,7 +108,29 @@ public void BuildPipelineNodePropsSdca()
108108
}";
109109
Util.AssertObjectMatchesJson(expectedJson, sdcaBinaryProps);
110110
}
111-
111+
112+
[TestMethod]
113+
public void BuildPipelineNodePropsSdcaWithNullValues()
114+
{
115+
var sweepParams = SweepableParams.BuildSdcaParams();
116+
foreach (var sweepParam in sweepParams)
117+
{
118+
sweepParam.RawValue = 0;
119+
}
120+
121+
var sdcaBinaryProps = TrainerExtensionUtil.BuildPipelineNodeProps(TrainerName.SdcaBinary, sweepParams);
122+
var expectedJson = @"
123+
{
124+
""L2Const"": null,
125+
""L1Threshold"": null,
126+
""ConvergenceTolerance"": 0.001,
127+
""MaxIterations"": null,
128+
""Shuffle"": false,
129+
""BiasLearningRate"": 0.0
130+
}";
131+
Util.AssertObjectMatchesJson(expectedJson, sdcaBinaryProps);
132+
}
133+
112134
[TestMethod]
113135
public void BuildParameterSetLightGbm()
114136
{
@@ -129,7 +151,7 @@ public void BuildParameterSetLightGbm()
129151
var multiParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmMulti, props);
130152
var regressionParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmRegression, props);
131153

132-
foreach(var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams })
154+
foreach (var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams })
133155
{
134156
Assert.AreEqual(4, paramSet.Count);
135157
Assert.AreEqual("1", paramSet["NumBoostRound"].ValueText);
@@ -148,7 +170,7 @@ public void BuildParameterSetSdca()
148170
};
149171

150172
var sdcaParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.SdcaBinary, props);
151-
173+
152174
Assert.AreEqual(1, sdcaParams.Count);
153175
Assert.AreEqual("1", sdcaParams["LearningRate"].ValueText);
154176
}

0 commit comments

Comments
 (0)