Skip to content

Commit 73d141b

Browse files
authored
Rev OVA pipeline node SDK output: wrap binary trainers as children inside parent OVA node (dotnet#317)
1 parent db3850c commit 73d141b

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

src/Microsoft.ML.Auto/TrainerExtensions/TrainerExtensionUtil.cs

+12-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ internal enum TrainerName
3232
LogisticRegressionMulti,
3333
OnlineGradientDescentRegression,
3434
OrdinaryLeastSquaresRegression,
35+
Ova,
3536
PoissonRegression,
3637
SdcaBinary,
3738
SdcaMulti,
@@ -79,8 +80,17 @@ public static LightGBM.Options CreateLightGbmOptions(IEnumerable<SweepableParam>
7980
public static PipelineNode BuildOvaPipelineNode(ITrainerExtension multiExtension, ITrainerExtension binaryExtension,
8081
IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
8182
{
82-
var ovaNode = binaryExtension.CreatePipelineNode(sweepParams, columnInfo);
83-
ovaNode.Name = TrainerExtensionCatalog.GetTrainerName(multiExtension).ToString();
83+
var ovaNode = new PipelineNode()
84+
{
85+
Name = TrainerName.Ova.ToString(),
86+
NodeType = PipelineNodeType.Trainer,
87+
Properties = new Dictionary<string, object>()
88+
{
89+
{ LabelColumn, columnInfo.LabelColumn }
90+
}
91+
};
92+
var binaryNode = binaryExtension.CreatePipelineNode(sweepParams, columnInfo);
93+
ovaNode.Properties["BinaryTrainer"] = binaryNode;
8494
return ovaNode;
8595
}
8696

src/Test/TrainerExtensionsTests.cs

+19-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ public void TrainerExtensionInstanceTests()
1717
{
1818
var context = new MLContext();
1919
var columnInfo = new ColumnInformation();
20-
var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast<TrainerName>();
20+
var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast<TrainerName>()
21+
.Except(new[] { TrainerName.Ova });
2122
foreach (var trainerName in trainerNames)
2223
{
2324
var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
@@ -194,16 +195,25 @@ public void BuildOvaPipelineNode()
194195
{
195196
var pipelineNode = new FastForestOvaExtension().CreatePipelineNode(null, new ColumnInformation());
196197
var expectedJson = @"{
197-
""Name"": ""FastForestOva"",
198+
""Name"": ""Ova"",
198199
""NodeType"": ""Trainer"",
199-
""InColumns"": [
200-
""Features""
201-
],
202-
""OutColumns"": [
203-
""Score""
204-
],
200+
""InColumns"": null,
201+
""OutColumns"": null,
205202
""Properties"": {
206-
""LabelColumn"": ""Label""
203+
""LabelColumn"": ""Label"",
204+
""BinaryTrainer"": {
205+
""Name"": ""FastForestBinary"",
206+
""NodeType"": ""Trainer"",
207+
""InColumns"": [
208+
""Features""
209+
],
210+
""OutColumns"": [
211+
""Score""
212+
],
213+
""Properties"": {
214+
""LabelColumn"": ""Label""
215+
}
216+
}
207217
}
208218
}";
209219
Util.AssertObjectMatchesJson(expectedJson, pipelineNode);

0 commit comments

Comments
 (0)