Skip to content

Commit 20e337f

Browse files
srsaggamDmitry-A
authored andcommitted
Revert "Set Nullable Auto params to null values" (dotnet#53)
* Revert "First public api propsal (dotnet#52)" This reverts commit e4a64cf. * Revert "Set Nullable Auto params to null values (dotnet#50)" This reverts commit 41c663c.
1 parent d880ed1 commit 20e337f

File tree

3 files changed

+38
-48
lines changed

3 files changed

+38
-48
lines changed

src/AutoML/TrainerExtensions/SweepableParams.cs

+11-14
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,14 @@ private static IEnumerable<SweepableParam> BuildOnlineLinearArgsParams()
3131

3232
private static IEnumerable<SweepableParam> BuildTreeArgsParams()
3333
{
34-
return new SweepableParam[]
35-
{
34+
return new SweepableParam[]
35+
{
3636
new SweepableLongParam("NumLeaves", 2, 128, isLogScale: true, stepSize: 4),
3737
new SweepableDiscreteParam("MinDocumentsInLeafs", new object[] { 1, 10, 50 }),
3838
new SweepableDiscreteParam("NumTrees", new object[] { 20, 100, 500 }),
3939
new SweepableFloatParam("LearningRates", 0.025f, 0.4f, isLogScale: true),
4040
new SweepableFloatParam("Shrinkage", 0.025f, 4f, isLogScale: true),
41-
};
41+
};
4242
}
4343

4444
private static IEnumerable<SweepableParam> BuildLbfgsArgsParams()
@@ -123,24 +123,22 @@ public static IEnumerable<SweepableParam> BuildPoissonRegressionParams()
123123
public static IEnumerable<SweepableParam> BuildSdcaParams()
124124
{
125125
return new SweepableParam[] {
126-
new SweepableDiscreteParam("L2Const", new object[] { null, 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f }),
127-
new SweepableDiscreteParam("L1Threshold", new object[] { null, 0f, 0.25f, 0.5f, 0.75f, 1f }),
126+
new SweepableDiscreteParam("L2Const", new object[] { "<Auto>", 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f }),
127+
new SweepableDiscreteParam("L1Threshold", new object[] { "<Auto>", 0f, 0.25f, 0.5f, 0.75f, 1f }),
128128
new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 0.001f, 0.01f, 0.1f, 0.2f }),
129-
new SweepableDiscreteParam("MaxIterations", new object[] { null, 10, 20, 100 }),
129+
new SweepableDiscreteParam("MaxIterations", new object[] { "<Auto>", 10, 20, 100 }),
130130
new SweepableDiscreteParam("Shuffle", null, isBool: true),
131131
new SweepableDiscreteParam("BiasLearningRate", new object[] { 0.0f, 0.01f, 0.1f, 1f })
132132
};
133133
}
134134

135-
public static IEnumerable<SweepableParam> BuildOrdinaryLeastSquaresParams()
136-
{
135+
public static IEnumerable<SweepableParam> BuildOrdinaryLeastSquaresParams() {
137136
return new SweepableParam[] {
138137
new SweepableDiscreteParam("L2Weight", new object[] { 1e-6f, 0.1f, 1f })
139138
};
140139
}
141140

142-
public static IEnumerable<SweepableParam> BuildSgdParams()
143-
{
141+
public static IEnumerable<SweepableParam> BuildSgdParams() {
144142
return new SweepableParam[] {
145143
new SweepableDiscreteParam("L2Weight", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f }),
146144
new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f }),
@@ -149,13 +147,12 @@ public static IEnumerable<SweepableParam> BuildSgdParams()
149147
};
150148
}
151149

152-
public static IEnumerable<SweepableParam> BuildSymSgdParams()
153-
{
150+
public static IEnumerable<SweepableParam> BuildSymSgdParams() {
154151
return new SweepableParam[] {
155152
new SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20, 30, 40, 50 }),
156-
new SweepableDiscreteParam("LearningRate", new object[] { null, 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f }),
153+
new SweepableDiscreteParam("LearningRate", new object[] { "<Auto>", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f }),
157154
new SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f }),
158-
new SweepableDiscreteParam("UpdateFrequency", new object[] { null, 5, 20 })
155+
new SweepableDiscreteParam("UpdateFrequency", new object[] { "<Auto>", 5, 20 })
159156
};
160157
}
161158
}

src/AutoML/TrainerExtensions/TrainerExtensionUtil.cs

+20-5
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ public static Action<LightGbmArguments> CreateLightGbmArgsFunc(IEnumerable<Sweep
7777

7878
public static IDictionary<string, object> BuildPipelineNodeProps(TrainerName trainerName, IEnumerable<SweepableParam> sweepParams)
7979
{
80-
if (trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti ||
80+
if(trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti ||
8181
trainerName == TrainerName.LightGbmRegression)
8282
{
8383
return BuildLightGbmPipelineNodeProps(sweepParams);
@@ -96,7 +96,7 @@ private static IDictionary<string, object> BuildLightGbmPipelineNodeProps(IEnume
9696

9797
var props = parentArgParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue());
9898
props[LightGbmTreeBoosterPropName] = treeBoosterCustomProp;
99-
99+
100100
return props;
101101
}
102102

@@ -155,9 +155,24 @@ public static void UpdateFields(object obj, IEnumerable<SweepableParam> sweepPar
155155
{
156156
var optIndex = (int)dp.RawValue;
157157
//Contracts.Assert(0 <= optIndex && optIndex < dp.Options.Length, $"Options index out of range: {optIndex}");
158-
var option = dp.Options[optIndex];
159-
160-
if (option != null)
158+
var option = dp.Options[optIndex].ToString().ToLower();
159+
160+
// Handle <Auto> string values in sweep params
161+
if (option == "auto" || option == "<auto>" || option == "< auto >")
162+
{
163+
//Check if nullable type, in which case 'null' is the auto value.
164+
if (Nullable.GetUnderlyingType(fi.FieldType) != null)
165+
fi.SetValue(obj, null);
166+
else if (fi.FieldType.IsEnum)
167+
{
168+
// Check if there is an enum option named Auto
169+
var enumDict = fi.FieldType.GetEnumValues().Cast<int>()
170+
.ToDictionary(v => Enum.GetName(fi.FieldType, v), v => v);
171+
if (enumDict.ContainsKey("Auto"))
172+
fi.SetValue(obj, enumDict["Auto"]);
173+
}
174+
}
175+
else
161176
SetValue(fi, (IComparable)dp.Options[optIndex], obj, propType);
162177
}
163178
else

src/Test/TrainerExtensionsTests.cs

+7-29
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public void TrainerExtensionInstanceTests()
1717
{
1818
var context = new MLContext();
1919
var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast<TrainerName>();
20-
foreach (var trainerName in trainerNames)
20+
foreach(var trainerName in trainerNames)
2121
{
2222
var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
2323
var instance = extension.CreateInstance(context, null);
@@ -33,7 +33,7 @@ public void GetTrainersByMaxIterations()
3333
var tasks = new TaskKind[] { TaskKind.BinaryClassification,
3434
TaskKind.MulticlassClassification, TaskKind.Regression };
3535

36-
foreach (var task in tasks)
36+
foreach(var task in tasks)
3737
{
3838
var trainerSet10 = TrainerExtensionCatalog.GetTrainers(task, 10);
3939
var trainerSet50 = TrainerExtensionCatalog.GetTrainers(task, 50);
@@ -52,7 +52,7 @@ public void GetTrainersByMaxIterations()
5252
public void BuildPipelineNodePropsLightGbm()
5353
{
5454
var sweepParams = SweepableParams.BuildLightGbmParams();
55-
foreach (var sweepParam in sweepParams)
55+
foreach(var sweepParam in sweepParams)
5656
{
5757
sweepParam.RawValue = 1;
5858
}
@@ -91,7 +91,7 @@ public void BuildPipelineNodePropsLightGbm()
9191
public void BuildPipelineNodePropsSdca()
9292
{
9393
var sweepParams = SweepableParams.BuildSdcaParams();
94-
foreach (var sweepParam in sweepParams)
94+
foreach(var sweepParam in sweepParams)
9595
{
9696
sweepParam.RawValue = 1;
9797
}
@@ -108,29 +108,7 @@ public void BuildPipelineNodePropsSdca()
108108
}";
109109
Util.AssertObjectMatchesJson(expectedJson, sdcaBinaryProps);
110110
}
111-
112-
[TestMethod]
113-
public void BuildPipelineNodePropsSdcaWithNullValues()
114-
{
115-
var sweepParams = SweepableParams.BuildSdcaParams();
116-
foreach (var sweepParam in sweepParams)
117-
{
118-
sweepParam.RawValue = 0;
119-
}
120-
121-
var sdcaBinaryProps = TrainerExtensionUtil.BuildPipelineNodeProps(TrainerName.SdcaBinary, sweepParams);
122-
var expectedJson = @"
123-
{
124-
""L2Const"": null,
125-
""L1Threshold"": null,
126-
""ConvergenceTolerance"": 0.001,
127-
""MaxIterations"": null,
128-
""Shuffle"": false,
129-
""BiasLearningRate"": 0.0
130-
}";
131-
Util.AssertObjectMatchesJson(expectedJson, sdcaBinaryProps);
132-
}
133-
111+
134112
[TestMethod]
135113
public void BuildParameterSetLightGbm()
136114
{
@@ -151,7 +129,7 @@ public void BuildParameterSetLightGbm()
151129
var multiParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmMulti, props);
152130
var regressionParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmRegression, props);
153131

154-
foreach (var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams })
132+
foreach(var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams })
155133
{
156134
Assert.AreEqual(4, paramSet.Count);
157135
Assert.AreEqual("1", paramSet["NumBoostRound"].ValueText);
@@ -170,7 +148,7 @@ public void BuildParameterSetSdca()
170148
};
171149

172150
var sdcaParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.SdcaBinary, props);
173-
151+
174152
Assert.AreEqual(1, sdcaParams.Count);
175153
Assert.AreEqual("1", sdcaParams["LearningRate"].ValueText);
176154
}

0 commit comments

Comments
 (0)