File tree 2 files changed +31
-11
lines changed
Microsoft.ML.Auto/TrainerExtensions
2 files changed +31
-11
lines changed Original file line number Diff line number Diff line change @@ -32,6 +32,7 @@ internal enum TrainerName
32
32
LogisticRegressionMulti ,
33
33
OnlineGradientDescentRegression ,
34
34
OrdinaryLeastSquaresRegression ,
35
+ Ova ,
35
36
PoissonRegression ,
36
37
SdcaBinary ,
37
38
SdcaMulti ,
@@ -79,8 +80,17 @@ public static LightGBM.Options CreateLightGbmOptions(IEnumerable<SweepableParam>
79
80
public static PipelineNode BuildOvaPipelineNode ( ITrainerExtension multiExtension , ITrainerExtension binaryExtension ,
80
81
IEnumerable < SweepableParam > sweepParams , ColumnInformation columnInfo )
81
82
{
82
- var ovaNode = binaryExtension . CreatePipelineNode ( sweepParams , columnInfo ) ;
83
- ovaNode . Name = TrainerExtensionCatalog . GetTrainerName ( multiExtension ) . ToString ( ) ;
83
+ var ovaNode = new PipelineNode ( )
84
+ {
85
+ Name = TrainerName . Ova . ToString ( ) ,
86
+ NodeType = PipelineNodeType . Trainer ,
87
+ Properties = new Dictionary < string , object > ( )
88
+ {
89
+ { LabelColumn , columnInfo . LabelColumn }
90
+ }
91
+ } ;
92
+ var binaryNode = binaryExtension . CreatePipelineNode ( sweepParams , columnInfo ) ;
93
+ ovaNode . Properties [ "BinaryTrainer" ] = binaryNode ;
84
94
return ovaNode ;
85
95
}
86
96
Original file line number Diff line number Diff line change @@ -17,7 +17,8 @@ public void TrainerExtensionInstanceTests()
17
17
{
18
18
var context = new MLContext ( ) ;
19
19
var columnInfo = new ColumnInformation ( ) ;
20
- var trainerNames = Enum . GetValues ( typeof ( TrainerName ) ) . Cast < TrainerName > ( ) ;
20
+ var trainerNames = Enum . GetValues ( typeof ( TrainerName ) ) . Cast < TrainerName > ( )
21
+ . Except ( new [ ] { TrainerName . Ova } ) ;
21
22
foreach ( var trainerName in trainerNames )
22
23
{
23
24
var extension = TrainerExtensionCatalog . GetTrainerExtension ( trainerName ) ;
@@ -194,16 +195,25 @@ public void BuildOvaPipelineNode()
194
195
{
195
196
var pipelineNode = new FastForestOvaExtension ( ) . CreatePipelineNode ( null , new ColumnInformation ( ) ) ;
196
197
var expectedJson = @"{
197
- ""Name"": ""FastForestOva "",
198
+ ""Name"": ""Ova "",
198
199
""NodeType"": ""Trainer"",
199
- ""InColumns"": [
200
- ""Features""
201
- ],
202
- ""OutColumns"": [
203
- ""Score""
204
- ],
200
+ ""InColumns"": null,
201
+ ""OutColumns"": null,
205
202
""Properties"": {
206
- ""LabelColumn"": ""Label""
203
+ ""LabelColumn"": ""Label"",
204
+ ""BinaryTrainer"": {
205
+ ""Name"": ""FastForestBinary"",
206
+ ""NodeType"": ""Trainer"",
207
+ ""InColumns"": [
208
+ ""Features""
209
+ ],
210
+ ""OutColumns"": [
211
+ ""Score""
212
+ ],
213
+ ""Properties"": {
214
+ ""LabelColumn"": ""Label""
215
+ }
216
+ }
207
217
}
208
218
}" ;
209
219
Util . AssertObjectMatchesJson ( expectedJson , pipelineNode ) ;
You can’t perform that action at this time.
0 commit comments