Skip to content

Commit 2023d09

Browse files
yaeldMSeerhardt
authored andcommitted
Create CalibratedPredictor instead of SchemaBindableCalibratedPredictor (dotnet#338)
`CalibratorUtils.TrainCalibrator` and `TrainCalibratorIfNeeded` now creates `CalibratedPredictor` instead of `SchemaBindableCalibratedPredictor` whenever the predictor implements `IValueMapper`.
1 parent 6abc988 commit 2023d09

File tree

3 files changed

+66
-13
lines changed

3 files changed

+66
-13
lines changed

src/Microsoft.ML.Data/Prediction/Calibrator.cs

+6-12
Original file line numberDiff line numberDiff line change
@@ -746,13 +746,10 @@ private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibrat
746746
/// <param name="trainer">The trainer used to train the predictor.</param>
747747
/// <param name="predictor">The predictor that needs calibration.</param>
748748
/// <param name="data">The examples to used for calibrator training.</param>
749-
/// <param name="needValueMapper">Indicates whether the predictor returned needs to be an <see cref="IValueMapper"/>.
750-
/// This parameter is needed for OVA that uses the predictors as <see cref="IValueMapper"/>s. If it is false,
751-
/// The predictor returned is an an <see cref="ISchemaBindableMapper"/>.</param>
752749
/// <returns>The original predictor, if no calibration is needed,
753750
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.</returns>
754751
public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
755-
int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
752+
int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data)
756753
{
757754
Contracts.CheckValue(env, nameof(env));
758755
env.CheckValue(ch, nameof(ch));
@@ -763,7 +760,7 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
763760
if (!NeedCalibration(env, ch, calibrator, trainer, predictor, data.Schema))
764761
return predictor;
765762

766-
return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data, needValueMapper);
763+
return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data);
767764
}
768765

769766
/// <summary>
@@ -775,13 +772,10 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
775772
/// <param name="maxRows">The maximum rows to use for calibrator training.</param>
776773
/// <param name="predictor">The predictor that needs calibration.</param>
777774
/// <param name="data">The examples to used for calibrator training.</param>
778-
/// <param name="needValueMapper">Indicates whether the predictor returned needs to be an <see cref="IValueMapper"/>.
779-
/// This parameter is needed for OVA that uses the predictors as <see cref="IValueMapper"/>s. If it is false,
780-
/// The predictor returned is an an <see cref="ISchemaBindableMapper"/>.</param>
781775
/// <returns>The original predictor, if no calibration is needed,
782776
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.</returns>
783777
public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer,
784-
int maxRows, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
778+
int maxRows, IPredictor predictor, RoleMappedData data)
785779
{
786780
Contracts.CheckValue(env, nameof(env));
787781
env.CheckValue(ch, nameof(ch));
@@ -834,10 +828,10 @@ public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICal
834828
}
835829
}
836830
var cali = caliTrainer.FinishTraining(ch);
837-
return CreateCalibratedPredictor(env, (IPredictorProducing<Float>)predictor, cali, needValueMapper);
831+
return CreateCalibratedPredictor(env, (IPredictorProducing<Float>)predictor, cali);
838832
}
839833

840-
public static IPredictorProducing<Float> CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing<Float> predictor, ICalibrator cali, bool needValueMapper = false)
834+
public static IPredictorProducing<Float> CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing<Float> predictor, ICalibrator cali)
841835
{
842836
Contracts.Assert(predictor != null);
843837
if (cali == null)
@@ -853,7 +847,7 @@ public static IPredictorProducing<Float> CreateCalibratedPredictor(IHostEnvironm
853847
var predWithFeatureScores = predictor as IPredictorWithFeatureWeights<Float>;
854848
if (predWithFeatureScores != null && predictor is IParameterMixer<Float> && cali is IParameterMixer)
855849
return new ParameterMixingCalibratedPredictor(env, predWithFeatureScores, cali);
856-
if (needValueMapper)
850+
if (predictor is IValueMapper)
857851
return new CalibratedPredictor(env, predictor, cali);
858852
return new SchemaBindableCalibratedPredictor(env, predictor, cali);
859853
}

src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe
9292
else
9393
calibrator = Args.Calibrator.CreateInstance(Host);
9494
var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples,
95-
trainer, predictor, td, true);
95+
trainer, predictor, td);
9696
predictor = res as TScalarPredictor;
9797
Host.Check(predictor != null, "Calibrated predictor does not implement the expected interface");
9898
}

test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs

+59
Original file line numberDiff line numberDiff line change
@@ -798,5 +798,64 @@ public void TestOvaMacro()
798798
}
799799
}
800800
}
801+
802+
[Fact]
803+
public void TestOvaMacroWithUncalibratedLearner()
804+
{
805+
var dataPath = GetDataPath(@"iris.txt");
806+
using (var env = new TlcEnvironment(42))
807+
{
808+
// Specify subgraph for OVA
809+
var subGraph = env.CreateExperiment();
810+
var learnerInput = new Trainers.AveragedPerceptronBinaryClassifier { Shuffle = false };
811+
var learnerOutput = subGraph.Add(learnerInput);
812+
// Create pipeline with OVA and multiclass scoring.
813+
var experiment = env.CreateExperiment();
814+
var importInput = new ML.Data.TextLoader(dataPath);
815+
importInput.Arguments.Column = new TextLoaderColumn[]
816+
{
817+
new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
818+
new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
819+
};
820+
var importOutput = experiment.Add(importInput);
821+
var oneVersusAll = new Models.OneVersusAll
822+
{
823+
TrainingData = importOutput.Data,
824+
Nodes = subGraph,
825+
UseProbabilities = true,
826+
};
827+
var ovaOutput = experiment.Add(oneVersusAll);
828+
var scoreInput = new ML.Transforms.DatasetScorer
829+
{
830+
Data = importOutput.Data,
831+
PredictorModel = ovaOutput.PredictorModel
832+
};
833+
var scoreOutput = experiment.Add(scoreInput);
834+
var evalInput = new ML.Models.ClassificationEvaluator
835+
{
836+
Data = scoreOutput.ScoredData
837+
};
838+
var evalOutput = experiment.Add(evalInput);
839+
experiment.Compile();
840+
experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
841+
experiment.Run();
842+
843+
var data = experiment.GetOutput(evalOutput.OverallMetrics);
844+
var schema = data.Schema;
845+
var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
846+
Assert.True(b);
847+
using (var cursor = data.GetRowCursor(col => col == accCol))
848+
{
849+
var getter = cursor.GetGetter<double>(accCol);
850+
b = cursor.MoveNext();
851+
Assert.True(b);
852+
double acc = 0;
853+
getter(ref acc);
854+
Assert.Equal(0.71, acc, 2);
855+
b = cursor.MoveNext();
856+
Assert.False(b);
857+
}
858+
}
859+
}
801860
}
802861
}

0 commit comments

Comments
 (0)