Skip to content

Commit 5730685

Browse files
authored
OVA should respect normalization in underlying learner (#310)
* Respect normalization in OVA. * some cleanup * fix copypaste issues
1 parent ab4108d commit 5730685

File tree

2 files changed

+65
-3
lines changed
  • src/Microsoft.ML.StandardLearners/Standard/MultiClass
  • test/Microsoft.ML.Core.Tests/UnitTests

2 files changed

+65
-3
lines changed

src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs

+6-3
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,21 @@ public static ModelOperations.PredictorModelOutput CombineOvaModels(IHostEnviron
200200
host.CheckValue(input, nameof(input));
201201
EntryPointUtils.CheckInputArgs(host, input);
202202
host.CheckNonEmpty(input.ModelArray, nameof(input.ModelArray));
203-
203+
// Something tells me we should put normalization as part of macro expansion, but since I get
204+
// subgraph instead of learner it's a bit tricky to get learner and decide whether we should add
205+
// normalization node or not, plus everywhere in the code we leave that responsibility to TransformModel.
206+
var normalizedView = input.ModelArray[0].TransformModel.Apply(host, input.TrainingData);
204207
using (var ch = host.Start("CombineOvaModels"))
205208
{
206-
ISchema schema = input.TrainingData.Schema;
209+
ISchema schema = normalizedView.Schema;
207210
var label = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.LabelColumn),
208211
input.LabelColumn,
209212
DefaultColumnNames.Label);
210213
var feature = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.FeatureColumn),
211214
input.FeatureColumn, DefaultColumnNames.Features);
212215
var weight = TrainUtils.MatchNameOrDefaultOrNull(ch, schema, nameof(input.WeightColumn),
213216
input.WeightColumn, DefaultColumnNames.Weight);
214-
var data = TrainUtils.CreateExamples(input.TrainingData, label, feature, null, weight);
217+
var data = TrainUtils.CreateExamples(normalizedView, label, feature, null, weight);
215218

216219
return new ModelOperations.PredictorModelOutput
217220
{

test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs

+59
Original file line numberDiff line numberDiff line change
@@ -739,5 +739,64 @@ public void TestCrossValidationMacroWithNonDefaultNames()
739739
}
740740
}
741741
}
742+
743+
[Fact]
744+
public void TestOvaMacro()
745+
{
746+
var dataPath = GetDataPath(@"iris.txt");
747+
using (var env = new TlcEnvironment(42))
748+
{
749+
// Specify subgraph for OVA
750+
var subGraph = env.CreateExperiment();
751+
var learnerInput = new Trainers.StochasticDualCoordinateAscentBinaryClassifier { NumThreads = 1 };
752+
var learnerOutput = subGraph.Add(learnerInput);
753+
// Create pipeline with OVA and multiclass scoring.
754+
var experiment = env.CreateExperiment();
755+
var importInput = new ML.Data.TextLoader(dataPath);
756+
importInput.Arguments.Column = new TextLoaderColumn[]
757+
{
758+
new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
759+
new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
760+
};
761+
var importOutput = experiment.Add(importInput);
762+
var oneVersusAll = new Models.OneVersusAll
763+
{
764+
TrainingData = importOutput.Data,
765+
Nodes = subGraph,
766+
UseProbabilities = true,
767+
};
768+
var ovaOutput = experiment.Add(oneVersusAll);
769+
var scoreInput = new ML.Transforms.DatasetScorer
770+
{
771+
Data = importOutput.Data,
772+
PredictorModel = ovaOutput.PredictorModel
773+
};
774+
var scoreOutput = experiment.Add(scoreInput);
775+
var evalInput = new ML.Models.ClassificationEvaluator
776+
{
777+
Data = scoreOutput.ScoredData
778+
};
779+
var evalOutput = experiment.Add(evalInput);
780+
experiment.Compile();
781+
experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
782+
experiment.Run();
783+
784+
var data = experiment.GetOutput(evalOutput.OverallMetrics);
785+
var schema = data.Schema;
786+
var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
787+
Assert.True(b);
788+
using (var cursor = data.GetRowCursor(col => col == accCol))
789+
{
790+
var getter = cursor.GetGetter<double>(accCol);
791+
b = cursor.MoveNext();
792+
Assert.True(b);
793+
double acc = 0;
794+
getter(ref acc);
795+
Assert.Equal(0.96, acc, 2);
796+
b = cursor.MoveNext();
797+
Assert.False(b);
798+
}
799+
}
800+
}
742801
}
743802
}

0 commit comments

Comments
 (0)