diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index efa52d2ff6..487726572b 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -746,13 +746,10 @@ private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibrat
/// The trainer used to train the predictor.
/// The predictor that needs calibration.
/// The examples to used for calibrator training.
- /// Indicates whether the predictor returned needs to be an .
- /// This parameter is needed for OVA that uses the predictors as s. If it is false,
- /// The predictor returned is an an .
/// The original predictor, if no calibration is needed,
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.
public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
- int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
+ int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
@@ -763,7 +760,7 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
if (!NeedCalibration(env, ch, calibrator, trainer, predictor, data.Schema))
return predictor;
- return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data, needValueMapper);
+ return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data);
}
///
@@ -775,13 +772,10 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
/// The maximum rows to use for calibrator training.
/// The predictor that needs calibration.
/// The examples to used for calibrator training.
- /// Indicates whether the predictor returned needs to be an .
- /// This parameter is needed for OVA that uses the predictors as s. If it is false,
- /// The predictor returned is an an .
/// The original predictor, if no calibration is needed,
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.
public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer,
- int maxRows, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
+ int maxRows, IPredictor predictor, RoleMappedData data)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
@@ -834,10 +828,10 @@ public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICal
}
}
var cali = caliTrainer.FinishTraining(ch);
- return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, cali, needValueMapper);
+ return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, cali);
}
- public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator cali, bool needValueMapper = false)
+ public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator cali)
{
Contracts.Assert(predictor != null);
if (cali == null)
@@ -853,7 +847,7 @@ public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironm
var predWithFeatureScores = predictor as IPredictorWithFeatureWeights;
if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer)
return new ParameterMixingCalibratedPredictor(env, predWithFeatureScores, cali);
- if (needValueMapper)
+ if (predictor is IValueMapper)
return new CalibratedPredictor(env, predictor, cali);
return new SchemaBindableCalibratedPredictor(env, predictor, cali);
}
diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
index 3d6e1e67b2..8aa6e4b6e0 100644
--- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
+++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs
@@ -92,7 +92,7 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe
else
calibrator = Args.Calibrator.CreateInstance(Host);
var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples,
- trainer, predictor, td, true);
+ trainer, predictor, td);
predictor = res as TScalarPredictor;
Host.Check(predictor != null, "Calibrated predictor does not implement the expected interface");
}
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
index b42dee2d52..f8ac506f6e 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
@@ -798,5 +798,64 @@ public void TestOvaMacro()
}
}
}
+
+ [Fact]
+ public void TestOvaMacroWithUncalibratedLearner()
+ {
+ var dataPath = GetDataPath(@"iris.txt");
+ using (var env = new TlcEnvironment(42))
+ {
+ // Specify subgraph for OVA
+ var subGraph = env.CreateExperiment();
+ var learnerInput = new Trainers.AveragedPerceptronBinaryClassifier { Shuffle = false };
+ var learnerOutput = subGraph.Add(learnerInput);
+ // Create pipeline with OVA and multiclass scoring.
+ var experiment = env.CreateExperiment();
+ var importInput = new ML.Data.TextLoader(dataPath);
+ importInput.Arguments.Column = new TextLoaderColumn[]
+ {
+ new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
+ new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
+ };
+ var importOutput = experiment.Add(importInput);
+ var oneVersusAll = new Models.OneVersusAll
+ {
+ TrainingData = importOutput.Data,
+ Nodes = subGraph,
+ UseProbabilities = true,
+ };
+ var ovaOutput = experiment.Add(oneVersusAll);
+ var scoreInput = new ML.Transforms.DatasetScorer
+ {
+ Data = importOutput.Data,
+ PredictorModel = ovaOutput.PredictorModel
+ };
+ var scoreOutput = experiment.Add(scoreInput);
+ var evalInput = new ML.Models.ClassificationEvaluator
+ {
+ Data = scoreOutput.ScoredData
+ };
+ var evalOutput = experiment.Add(evalInput);
+ experiment.Compile();
+ experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
+ experiment.Run();
+
+ var data = experiment.GetOutput(evalOutput.OverallMetrics);
+ var schema = data.Schema;
+ var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
+ Assert.True(b);
+ using (var cursor = data.GetRowCursor(col => col == accCol))
+ {
+ var getter = cursor.GetGetter(accCol);
+ b = cursor.MoveNext();
+ Assert.True(b);
+ double acc = 0;
+ getter(ref acc);
+ Assert.Equal(0.71, acc, 2);
+ b = cursor.MoveNext();
+ Assert.False(b);
+ }
+ }
+ }
}
}