@@ -96,15 +96,17 @@ private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoad
96
96
ctx . LoadModelOrNull < ICalibrator , SignatureLoadModel > ( env , out calibrator , @"Calibrator" ) ;
97
97
if ( calibrator == null )
98
98
return predictor ;
99
- return new SchemaBindableCalibratedPredictor ( env , predictor , calibrator ) ;
99
+ return new SchemaBindableCalibratedModelParameters < FastTreeBinaryModelParameters , ICalibrator > ( env , predictor , calibrator ) ;
100
100
}
101
101
102
102
public override PredictionKind PredictionKind => PredictionKind . BinaryClassification ;
103
103
}
104
104
105
105
/// <include file = 'doc.xml' path='doc/members/member[@name="FastTree"]/*' />
106
106
public sealed partial class FastTreeBinaryClassificationTrainer :
107
- BoostingFastTreeTrainerBase < FastTreeBinaryClassificationTrainer . Options , BinaryPredictionTransformer < IPredictorWithFeatureWeights < float > > , IPredictorWithFeatureWeights < float > >
107
+ BoostingFastTreeTrainerBase < FastTreeBinaryClassificationTrainer . Options ,
108
+ BinaryPredictionTransformer < CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > > ,
109
+ CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > >
108
110
{
109
111
/// <summary>
110
112
/// The LoadName for the assembly containing the trainer.
@@ -156,7 +158,7 @@ internal FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options optio
156
158
157
159
public override PredictionKind PredictionKind => PredictionKind . BinaryClassification ;
158
160
159
- private protected override IPredictorWithFeatureWeights < float > TrainModelCore ( TrainContext context )
161
+ private protected override CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > TrainModelCore ( TrainContext context )
160
162
{
161
163
Host . CheckValue ( context , nameof ( context ) ) ;
162
164
var trainData = context . TrainingSet ;
@@ -185,7 +187,7 @@ private protected override IPredictorWithFeatureWeights<float> TrainModelCore(Tr
185
187
// BinaryClassificationObjectiveFunction.GetGradientInOneQuery being consistent with the
186
188
// description in section 6 of the paper.
187
189
var cali = new PlattCalibrator ( Host , - 1 * _sigmoidParameter , 0 ) ;
188
- return new FeatureWeightsCalibratedPredictor ( Host , pred , cali ) ;
190
+ return new FeatureWeightsCalibratedModelParameters < FastTreeBinaryModelParameters , PlattCalibrator > ( Host , pred , cali ) ;
189
191
}
190
192
191
193
protected override ObjectiveFunctionBase ConstructObjFunc ( IChannel ch )
@@ -273,10 +275,11 @@ protected override void InitializeTests()
273
275
}
274
276
}
275
277
276
- protected override BinaryPredictionTransformer < IPredictorWithFeatureWeights < float > > MakeTransformer ( IPredictorWithFeatureWeights < float > model , Schema trainSchema )
277
- => new BinaryPredictionTransformer < IPredictorWithFeatureWeights < float > > ( Host , model , trainSchema , FeatureColumn . Name ) ;
278
+ protected override BinaryPredictionTransformer < CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > > MakeTransformer (
279
+ CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > model , Schema trainSchema )
280
+ => new BinaryPredictionTransformer < CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > > ( Host , model , trainSchema , FeatureColumn . Name ) ;
278
281
279
- public BinaryPredictionTransformer < IPredictorWithFeatureWeights < float > > Train ( IDataView trainData , IDataView validationData = null )
282
+ public BinaryPredictionTransformer < CalibratedModelParametersBase < FastTreeBinaryModelParameters , PlattCalibrator > > Train ( IDataView trainData , IDataView validationData = null )
280
283
=> TrainTransformer ( trainData , validationData ) ;
281
284
282
285
protected override SchemaShape . Column [ ] GetOutputColumnsCore ( SchemaShape inputSchema )
0 commit comments