Skip to content

Commit 3b7b923

Browse files
daholsteDmitry-A
authored andcommitted
fix for defaulting Averaged Perceptron # of iterations to 10 (dotnet#237)
1 parent a69f688 commit 3b7b923

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

src/Microsoft.ML.Auto/TrainerExtensions/BinaryTrainerExtensions.cs

+9-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.Collections.Generic;
6+
using System.Linq;
67
using Microsoft.ML.Trainers;
78
using Microsoft.ML.Trainers.FastTree;
89
using Microsoft.ML.Trainers.HalLearners;
@@ -26,7 +27,7 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
2627
ColumnInformation columnInfo)
2728
{
2829
AveragedPerceptronTrainer.Options options = null;
29-
if (sweepParams == null)
30+
if (sweepParams == null || !sweepParams.Any())
3031
{
3132
options = new AveragedPerceptronTrainer.Options();
3233
options.NumberOfIterations = DefaultNumIterations;
@@ -35,6 +36,10 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
3536
else
3637
{
3738
options = TrainerExtensionUtil.CreateOptions<AveragedPerceptronTrainer.Options>(sweepParams, columnInfo.LabelColumn);
39+
if (!sweepParams.Any(p => p.Name == "NumberOfIterations"))
40+
{
41+
options.NumberOfIterations = DefaultNumIterations;
42+
}
3843
}
3944
return mlContext.BinaryClassification.Trainers.AveragedPerceptron(options);
4045
}
@@ -43,11 +48,11 @@ public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams,
4348
{
4449
Dictionary<string, object> additionalProperties = null;
4550

46-
if(sweepParams == null)
51+
if (sweepParams == null || !sweepParams.Any(p => p.Name != "NumberOfIterations"))
4752
{
4853
additionalProperties = new Dictionary<string, object>()
4954
{
50-
{ "NumIterations", "10" }
55+
{ "NumberOfIterations", DefaultNumIterations.ToString() }
5156
};
5257
}
5358

@@ -227,4 +232,4 @@ public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams,
227232
columnInfo.LabelColumn);
228233
}
229234
}
230-
}
235+
}

src/Test/TrainerExtensionsTests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ public void BuildDefaultAveragedPerceptronPipelineNode()
161161
],
162162
""Properties"": {
163163
""LabelColumn"": ""L"",
164-
""NumIterations"": ""10""
164+
""NumberOfIterations"": ""10""
165165
}
166166
}";
167167
Util.AssertObjectMatchesJson(expectedJson, pipelineNode);

0 commit comments

Comments
 (0)