3
3
// See the LICENSE file in the project root for more information.
4
4
5
5
using System . Collections . Generic ;
6
+ using System . Linq ;
6
7
using Microsoft . ML . Trainers ;
7
8
using Microsoft . ML . Trainers . FastTree ;
8
9
using Microsoft . ML . Trainers . HalLearners ;
@@ -26,7 +27,7 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
26
27
ColumnInformation columnInfo )
27
28
{
28
29
AveragedPerceptronTrainer . Options options = null ;
29
- if ( sweepParams == null )
30
+ if ( sweepParams == null || ! sweepParams . Any ( ) )
30
31
{
31
32
options = new AveragedPerceptronTrainer . Options ( ) ;
32
33
options . NumberOfIterations = DefaultNumIterations ;
@@ -35,6 +36,10 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
35
36
else
36
37
{
37
38
options = TrainerExtensionUtil . CreateOptions < AveragedPerceptronTrainer . Options > ( sweepParams , columnInfo . LabelColumn ) ;
39
+ if ( ! sweepParams . Any ( p => p . Name == "NumberOfIterations" ) )
40
+ {
41
+ options . NumberOfIterations = DefaultNumIterations ;
42
+ }
38
43
}
39
44
return mlContext . BinaryClassification . Trainers . AveragedPerceptron ( options ) ;
40
45
}
@@ -43,11 +48,11 @@ public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams,
43
48
{
44
49
Dictionary < string , object > additionalProperties = null ;
45
50
46
- if ( sweepParams == null )
51
+ if ( sweepParams == null || ! sweepParams . Any ( p => p . Name != "NumberOfIterations" ) )
47
52
{
48
53
additionalProperties = new Dictionary < string , object > ( )
49
54
{
50
- { "NumIterations " , "10" }
55
+ { "NumberOfIterations " , DefaultNumIterations . ToString ( ) }
51
56
} ;
52
57
}
53
58
@@ -227,4 +232,4 @@ public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams,
227
232
columnInfo . LabelColumn ) ;
228
233
}
229
234
}
230
- }
235
+ }
0 commit comments