Skip to content

Commit a254bf5

Browse files
authored
Typed Calibrated Predictors (#2435)
* Make strongly-typed and weakly-typed accesses possible
1 parent 834e471 commit a254bf5

File tree

30 files changed

+311
-216
lines changed

30 files changed

+311
-216
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/PermutationFeatureImportance/PfiBinaryClassificationExample.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static void RunExample()
3030
var linearPredictor = model.LastTransformer;
3131
// Linear models for binary classification are wrapped by a calibrator as a generic predictor
3232
// To access it directly, we must extract it out and cast it to the proper class
33-
var weights = PfiHelper.GetLinearModelWeights(linearPredictor.Model.SubPredictor as LinearBinaryModelParameters);
33+
var weights = PfiHelper.GetLinearModelWeights(linearPredictor.Model.SubModel as LinearBinaryModelParameters);
3434

3535
// Compute the permutation metrics using the properly normalized data.
3636
var transformedData = model.Transform(data);

src/Microsoft.ML.Data/EntryPoints/PredictorModelImpl.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.IO;
88
using System.Linq;
99
using Microsoft.Data.DataView;
10+
using Microsoft.ML.Calibrator;
1011
using Microsoft.ML.Data;
1112
using Microsoft.ML.Internal.Calibration;
1213
using Microsoft.ML.Internal.Internallearn;
@@ -111,11 +112,11 @@ internal override string[] GetLabelInfo(IHostEnvironment env, out ColumnType lab
111112
{
112113
Contracts.CheckValue(env, nameof(env));
113114
var predictor = Predictor;
114-
var calibrated = predictor as CalibratedPredictorBase;
115+
var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
115116
while (calibrated != null)
116117
{
117-
predictor = calibrated.SubPredictor;
118-
calibrated = predictor as CalibratedPredictorBase;
118+
predictor = calibrated.WeeklyTypedSubModel;
119+
calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
119120
}
120121
var canGetTrainingLabelNames = predictor as ICanGetTrainingLabelNames;
121122
if (canGetTrainingLabelNames != null)

src/Microsoft.ML.Data/EntryPoints/SummarizePredictor.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.IO;
66
using System.Text;
77
using Microsoft.Data.DataView;
8+
using Microsoft.ML.Calibrator;
89
using Microsoft.ML.CommandLine;
910
using Microsoft.ML.Data;
1011
using Microsoft.ML.EntryPoints;
@@ -48,11 +49,11 @@ public static CommonOutputs.SummaryOutput Summarize(IHostEnvironment env, Summar
4849
[BestFriend]
4950
internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
5051
{
51-
var calibrated = predictor as CalibratedPredictorBase;
52+
var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
5253
while (calibrated != null)
5354
{
54-
predictor = calibrated.SubPredictor;
55-
calibrated = predictor as CalibratedPredictorBase;
55+
predictor = calibrated.WeeklyTypedSubModel;
56+
calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
5657
}
5758

5859
IDataView summary = null;

src/Microsoft.ML.Data/Prediction/Calibrator.cs

+137-80
Large diffs are not rendered by default.

src/Microsoft.ML.Data/Training/TrainerEstimatorBase.cs

-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ protected TTransformer TrainTransformer(IDataView trainSet,
141141
return MakeTransformer(pred, trainSet.Schema);
142142
}
143143

144-
[BestFriend]
145144
private protected abstract TModel TrainModelCore(TrainContext trainContext);
146145

147146
protected abstract TTransformer MakeTransformer(TModel model, Schema trainSchema);

src/Microsoft.ML.Ensemble/EntryPoints/PipelineEnsemble.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,13 @@ public static SummaryOutput Summarize(IHostEnvironment env, SummarizePredictor.I
3333

3434
input.PredictorModel.PrepareData(host,
3535
new EmptyDataView(host, input.PredictorModel.TransformModel.InputSchema),
36-
out RoleMappedData rmd, out IPredictor predictor
37-
);
36+
out RoleMappedData rmd, out IPredictor predictor);
3837

39-
var calibrated = predictor as CalibratedPredictorBase;
38+
var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
4039
while (calibrated != null)
4140
{
42-
predictor = calibrated.SubPredictor;
43-
calibrated = predictor as CalibratedPredictorBase;
41+
predictor = calibrated.WeeklyTypedSubModel;
42+
calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
4443
}
4544
var ensemble = predictor as SchemaBindablePipelineEnsembleBase;
4645
host.CheckUserArg(ensemble != null, nameof(input.PredictorModel.Predictor), "Predictor is not a pipeline ensemble predictor");

src/Microsoft.ML.FastTree/FastTreeClassification.cs

+10-7
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,17 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
9696
ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
9797
if (calibrator == null)
9898
return predictor;
99-
return new SchemaBindableCalibratedPredictor(env, predictor, calibrator);
99+
return new SchemaBindableCalibratedModelParameters<FastTreeBinaryModelParameters, ICalibrator>(env, predictor, calibrator);
100100
}
101101

102102
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
103103
}
104104

105105
/// <include file = 'doc.xml' path='doc/members/member[@name="FastTree"]/*' />
106106
public sealed partial class FastTreeBinaryClassificationTrainer :
107-
BoostingFastTreeTrainerBase<FastTreeBinaryClassificationTrainer.Options, BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>>, IPredictorWithFeatureWeights<float>>
107+
BoostingFastTreeTrainerBase<FastTreeBinaryClassificationTrainer.Options,
108+
BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>,
109+
CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>
108110
{
109111
/// <summary>
110112
/// The LoadName for the assembly containing the trainer.
@@ -156,7 +158,7 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options optio
156158

157159
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
158160

159-
private protected override IPredictorWithFeatureWeights<float> TrainModelCore(TrainContext context)
161+
private protected override CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
160162
{
161163
Host.CheckValue(context, nameof(context));
162164
var trainData = context.TrainingSet;
@@ -185,7 +187,7 @@ private protected override IPredictorWithFeatureWeights<float> TrainModelCore(Tr
185187
// BinaryClassificationObjectiveFunction.GetGradientInOneQuery being consistent with the
186188
// description in section 6 of the paper.
187189
var cali = new PlattCalibrator(Host, -1 * _sigmoidParameter, 0);
188-
return new FeatureWeightsCalibratedPredictor(Host, pred, cali);
190+
return new FeatureWeightsCalibratedModelParameters<FastTreeBinaryModelParameters, PlattCalibrator>(Host, pred, cali);
189191
}
190192

191193
protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
@@ -273,10 +275,11 @@ protected override void InitializeTests()
273275
}
274276
}
275277

276-
protected override BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> MakeTransformer(IPredictorWithFeatureWeights<float> model, Schema trainSchema)
277-
=> new BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>>(Host, model, trainSchema, FeatureColumn.Name);
278+
protected override BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> MakeTransformer(
279+
CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator> model, Schema trainSchema)
280+
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
278281

279-
public BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> Train(IDataView trainData, IDataView validationData = null)
282+
public BinaryPredictionTransformer<CalibratedModelParametersBase<FastTreeBinaryModelParameters, PlattCalibrator>> Train(IDataView trainData, IDataView validationData = null)
280283
=> TrainTransformer(trainData, validationData);
281284

282285
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)

src/Microsoft.ML.FastTree/GamClassification.cs

+10-8
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
namespace Microsoft.ML.Trainers.FastTree
3232
{
3333
public sealed class BinaryClassificationGamTrainer :
34-
GamTrainerBase<BinaryClassificationGamTrainer.Options, BinaryPredictionTransformer<CalibratedPredictorBase>, CalibratedPredictorBase>
34+
GamTrainerBase<BinaryClassificationGamTrainer.Options,
35+
BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>,
36+
CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>
3537
{
3638
public sealed class Options : ArgumentsBase
3739
{
@@ -103,14 +105,13 @@ private static bool[] ConvertTargetsToBool(double[] targets)
103105
Parallel.Invoke(new ParallelOptions { MaxDegreeOfParallelism = BlockingThreadPool.NumThreads }, actions);
104106
return boolArray;
105107
}
106-
107-
private protected override CalibratedPredictorBase TrainModelCore(TrainContext context)
108+
private protected override CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> TrainModelCore(TrainContext context)
108109
{
109110
TrainBase(context);
110111
var predictor = new BinaryClassificationGamModelParameters(Host,
111112
BinUpperBounds, BinEffects, MeanEffect, InputLength, FeatureMap);
112113
var calibrator = new PlattCalibrator(Host, -1.0 * _sigmoidParameter, 0);
113-
return new CalibratedPredictor(Host, predictor, calibrator);
114+
return new ValueMapperCalibratedModelParameters<BinaryClassificationGamModelParameters, PlattCalibrator>(Host, predictor, calibrator);
114115
}
115116

116117
protected override ObjectiveFunctionBase CreateObjectiveFunction()
@@ -139,10 +140,11 @@ protected override void DefinePruningTest()
139140
PruningTest = new TestHistory(validTest, PruningLossIndex);
140141
}
141142

142-
protected override BinaryPredictionTransformer<CalibratedPredictorBase> MakeTransformer(CalibratedPredictorBase model, Schema trainSchema)
143-
=> new BinaryPredictionTransformer<CalibratedPredictorBase>(Host, model, trainSchema, FeatureColumn.Name);
143+
protected override BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>
144+
MakeTransformer(CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator> model, Schema trainSchema)
145+
=> new BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>>(Host, model, trainSchema, FeatureColumn.Name);
144146

145-
public BinaryPredictionTransformer<CalibratedPredictorBase> Train(IDataView trainData, IDataView validationData = null)
147+
public BinaryPredictionTransformer<CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>> Train(IDataView trainData, IDataView validationData = null)
146148
=> TrainTransformer(trainData, validationData);
147149

148150
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
@@ -207,7 +209,7 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
207209
ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
208210
if (calibrator == null)
209211
return predictor;
210-
return new SchemaBindableCalibratedPredictor(env, predictor, calibrator);
212+
return new SchemaBindableCalibratedModelParameters<BinaryClassificationGamModelParameters, ICalibrator>(env, predictor, calibrator);
211213
}
212214

213215
private protected override void SaveCore(ModelSaveContext ctx)

src/Microsoft.ML.FastTree/GamModelParameters.cs

+8-3
Original file line numberDiff line numberDiff line change
@@ -876,12 +876,17 @@ private Context Init(IChannel ch)
876876
LoadModelObjects(ch, true, out rawPred, true, out schema, out loader);
877877
bool hadCalibrator = false;
878878

879-
var calibrated = rawPred as CalibratedPredictorBase;
879+
// The rawPred has two possible types:
880+
// 1. CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>
881+
// 2. RegressionGamModelParameters
882+
// For (1), the trained model, GamModelParametersBase, is a field we need to extract. For (2),
883+
// we don't need to do anything because RegressionGamModelParameters is derived from GamModelParametersBase.
884+
var calibrated = rawPred as CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>;
880885
while (calibrated != null)
881886
{
882887
hadCalibrator = true;
883-
rawPred = calibrated.SubPredictor;
884-
calibrated = rawPred as CalibratedPredictorBase;
888+
rawPred = calibrated.SubModel;
889+
calibrated = rawPred as CalibratedModelParametersBase<BinaryClassificationGamModelParameters, PlattCalibrator>;
885890
}
886891
var pred = rawPred as GamModelParametersBase;
887892
ch.CheckUserArg(pred != null, nameof(Args.InputModelFile), "Predictor was not a " + nameof(GamModelParametersBase));

src/Microsoft.ML.FastTree/RandomForestClassification.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
105105
ctx.LoadModelOrNull<ICalibrator, SignatureLoadModel>(env, out calibrator, @"Calibrator");
106106
if (calibrator == null)
107107
return predictor;
108-
return new SchemaBindableCalibratedPredictor(env, predictor, calibrator);
108+
return new SchemaBindableCalibratedModelParameters<FastForestClassificationModelParameters, ICalibrator>(env, predictor, calibrator);
109109
}
110110
}
111111

112112
/// <include file='doc.xml' path='doc/members/member[@name="FastForest"]/*' />
113113
public sealed partial class FastForestClassification :
114-
RandomForestTrainerBase<FastForestClassification.Options, BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>>, IPredictorWithFeatureWeights<float>>
114+
RandomForestTrainerBase<FastForestClassification.Options, BinaryPredictionTransformer<FastForestClassificationModelParameters>, FastForestClassificationModelParameters>
115115
{
116116
public sealed class Options : FastForestArgumentsBase
117117
{
@@ -170,7 +170,7 @@ internal FastForestClassification(IHostEnvironment env, Options options)
170170
{
171171
}
172172

173-
private protected override IPredictorWithFeatureWeights<float> TrainModelCore(TrainContext context)
173+
private protected override FastForestClassificationModelParameters TrainModelCore(TrainContext context)
174174
{
175175
Host.CheckValue(context, nameof(context));
176176
var trainData = context.TrainingSet;
@@ -213,10 +213,10 @@ protected override Test ConstructTestForTrainingData()
213213
return new BinaryClassificationTest(ConstructScoreTracker(TrainSet), _trainSetLabels, 1);
214214
}
215215

216-
protected override BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> MakeTransformer(IPredictorWithFeatureWeights<float> model, Schema trainSchema)
217-
=> new BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>>(Host, model, trainSchema, FeatureColumn.Name);
216+
protected override BinaryPredictionTransformer<FastForestClassificationModelParameters> MakeTransformer(FastForestClassificationModelParameters model, Schema trainSchema)
217+
=> new BinaryPredictionTransformer<FastForestClassificationModelParameters>(Host, model, trainSchema, FeatureColumn.Name);
218218

219-
public BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> Train(IDataView trainData, IDataView validationData = null)
219+
public BinaryPredictionTransformer<FastForestClassificationModelParameters> Train(IDataView trainData, IDataView validationData = null)
220220
=> TrainTransformer(trainData, validationData);
221221

222222
protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)

src/Microsoft.ML.FastTree/TreeEnsemble/TreeEnsembleCombiner.cs

+11-7
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
using System.Collections.Generic;
66
using Microsoft.ML;
7+
using Microsoft.ML.Calibrator;
78
using Microsoft.ML.Data;
89
using Microsoft.ML.Ensemble;
910
using Microsoft.ML.Internal.Calibration;
11+
using Microsoft.ML.Internal.Internallearn;
1012
using Microsoft.ML.Trainers.FastTree.Internal;
1113

1214
[assembly: LoadableClass(typeof(TreeEnsembleCombiner), null, typeof(SignatureModelCombiner), "Fast Tree Model Combiner", "FastTreeCombiner")]
@@ -49,16 +51,17 @@ public IPredictor CombineModels(IEnumerable<IPredictor> models)
4951
var predictor = model;
5052
_host.CheckValue(predictor, nameof(models), "One of the models is null");
5153

52-
var calibrated = predictor as CalibratedPredictorBase;
54+
var calibrated = predictor as IWeaklyTypedCalibratedModelParameters;
5355
double paramA = 1;
5456
if (calibrated != null)
55-
{
56-
_host.Check(calibrated.Calibrator is PlattCalibrator,
57+
_host.Check(calibrated.WeeklyTypedCalibrator is PlattCalibrator,
5758
"Combining FastTree models can only be done when the models are calibrated with Platt calibrator");
58-
predictor = calibrated.SubPredictor;
59-
paramA = -(calibrated.Calibrator as PlattCalibrator).Slope;
60-
}
59+
60+
predictor = calibrated.WeeklyTypedSubModel;
61+
paramA = -((PlattCalibrator)calibrated.WeeklyTypedCalibrator).Slope;
62+
6163
var tree = predictor as TreeEnsembleModelParameters;
64+
6265
if (tree == null)
6366
throw _host.Except("Model is not a tree ensemble");
6467
foreach (var t in tree.TrainedEnsemble.Trees)
@@ -103,7 +106,8 @@ public IPredictor CombineModels(IEnumerable<IPredictor> models)
103106
return new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null);
104107

105108
var cali = new PlattCalibrator(_host, -1, 0);
106-
return new FeatureWeightsCalibratedPredictor(_host, new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null), cali);
109+
var fastTreeModel = new FastTreeBinaryModelParameters(_host, ensemble, featureCount, null);
110+
return new FeatureWeightsCalibratedModelParameters<FastTreeBinaryModelParameters,PlattCalibrator>(_host, fastTreeModel, cali);
107111
case PredictionKind.Regression:
108112
return new FastTreeRegressionModelParameters(_host, ensemble, featureCount, null);
109113
case PredictionKind.Ranking:

0 commit comments

Comments
 (0)