Skip to content

Commit 7a66e19

Browse files
author
Pete Luferenko
committed
Reconfigurable prediction
1 parent 231516a commit 7a66e19

File tree

2 files changed

+97
-13
lines changed

2 files changed

+97
-13
lines changed

test/Microsoft.ML.Tests/Scenarios/Api/ReconfigurablePrediction.cs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.ML.Models;
22
using Microsoft.ML.Runtime.Data;
33
using Microsoft.ML.Runtime.Learners;
4+
using System.Collections;
45
using Xunit;
56

67
namespace Microsoft.ML.Tests.Scenarios.Api
@@ -56,5 +57,44 @@ void ReconfigurablePrediction()
5657
var new_metrics = BinaryClassificationMetrics.FromMetrics(env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"])[0];
5758
}
5859
}
60+
61+
/// <summary>
62+
/// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
63+
/// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold
64+
/// and configures the scorer (or more precisely instantiates a new scorer over the same predictor)
65+
/// with some threshold derived from that.
66+
/// </summary>
67+
[Fact]
68+
void New_ReconfigurablePrediction()
69+
{
70+
var dataPath = GetDataPath(SentimentDataPath);
71+
var testDataPath = GetDataPath(SentimentTestPath);
72+
73+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
74+
{
75+
var dataReader = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
76+
.Fit(new MultiFileSource(dataPath));
77+
78+
var data = dataReader.Read(new MultiFileSource(dataPath));
79+
var testData = dataReader.Read(new MultiFileSource(testDataPath));
80+
81+
// Pipeline.
82+
var pipeline = new MyTextTransform(env, MakeSentimentTextTransformArgs())
83+
.Fit(data);
84+
85+
var trainer = new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label");
86+
var trainData = pipeline.Transform(data);
87+
var model = trainer.Fit(trainData);
88+
89+
var scoredTest = model.Transform(pipeline.Transform(testData));
90+
var metrics = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments()).Evaluate(scoredTest, "Label", "Probability");
91+
92+
var newModel = model.Clone(new BinaryClassifierScorer.Arguments { Threshold = 0.01f, ThresholdColumn = DefaultColumnNames.Probability });
93+
var newScoredTest = newModel.Transform(pipeline.Transform(testData));
94+
var newMetrics = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments { Threshold = 0.01f, UseRawScoreThreshold = false }).Evaluate(newScoredTest, "Label", "Probability");
95+
}
96+
97+
}
98+
5999
}
60100
}

test/Microsoft.ML.Tests/Scenarios/Api/Wrappers.cs

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using Microsoft.ML.Models;
33
using Microsoft.ML.Runtime;
44
using Microsoft.ML.Runtime.Api;
5+
using Microsoft.ML.Runtime.CommandLine;
56
using Microsoft.ML.Runtime.Data;
67
using Microsoft.ML.Runtime.Data.IO;
78
using Microsoft.ML.Runtime.Learners;
@@ -90,8 +91,8 @@ public class TransformWrapper : ITransformer, ICanSaveModel
9091
public const string LoaderSignature = "TransformWrapper";
9192
private const string TransformDirTemplate = "Step_{0:000}";
9293

93-
private readonly IHostEnvironment _env;
94-
private readonly IDataView _xf;
94+
protected readonly IHostEnvironment _env;
95+
protected readonly IDataView _xf;
9596

9697
public TransformWrapper(IHostEnvironment env, IDataView xf)
9798
{
@@ -174,15 +175,42 @@ public interface IPredictorTransformer<out TModel> : ITransformer
174175
public class ScorerWrapper<TModel> : TransformWrapper, IPredictorTransformer<TModel>
175176
where TModel : IPredictor
176177
{
177-
public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel)
178+
protected readonly string _featureColumn;
179+
180+
public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel, string featureColumn)
178181
: base(env, scorer)
179182
{
183+
_featureColumn = featureColumn;
180184
InnerModel = trainedModel;
181185
}
182186

183187
public TModel InnerModel { get; }
184188
}
185189

190+
public class BinaryScorerWrapper<TModel>: ScorerWrapper<TModel>
191+
where TModel: IPredictor
192+
{
193+
public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, BinaryClassifierScorer.Arguments args)
194+
:base(env, MakeScorer(env, inputSchema, featureColumn, model, args), model, featureColumn)
195+
{
196+
}
197+
198+
private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string featureColumn, TModel model, BinaryClassifierScorer.Arguments args)
199+
{
200+
var settings = $"Binary{{{CmdParser.GetSettings(env, args, new BinaryClassifierScorer.Arguments())}}}";
201+
var mapper = ScoreUtils.GetSchemaBindableMapper(env, model, SubComponent.Parse<IDataScorerTransform, SignatureDataScorer>(settings));
202+
var edv = new EmptyDataView(env, schema);
203+
var data = new RoleMappedData(edv, "Label", featureColumn, opt: true);
204+
return new BinaryClassifierScorer(env, args, data.Data, mapper.Bind(env, data.Schema), data.Schema);
205+
}
206+
207+
public BinaryScorerWrapper<TModel> Clone(BinaryClassifierScorer.Arguments scorerArgs)
208+
{
209+
var scorer = _xf as IDataScorerTransform;
210+
return new BinaryScorerWrapper<TModel>(_env, InnerModel, scorer.Source.Schema, _featureColumn, scorerArgs);
211+
}
212+
}
213+
186214
public class MyTextLoader : IDataReaderEstimator<IMultiStreamSource, LoaderWrapper>
187215
{
188216
private readonly TextLoader.Arguments _args;
@@ -206,12 +234,13 @@ public SchemaShape GetOutputSchema()
206234
}
207235
}
208236

209-
public abstract class TrainerBase<TModel> : IEstimator<ScorerWrapper<TModel>>
237+
public abstract class TrainerBase<TTransformer, TModel> : IEstimator<TTransformer>
238+
where TTransformer: ScorerWrapper<TModel>
210239
where TModel : IPredictor
211240
{
212241
protected readonly IHostEnvironment _env;
213-
private readonly string _featureCol;
214-
private readonly string _labelCol;
242+
protected readonly string _featureCol;
243+
protected readonly string _labelCol;
215244
private readonly bool _cache;
216245
private readonly bool _normalize;
217246

@@ -224,12 +253,12 @@ protected TrainerBase(IHostEnvironment env, bool cache, bool normalize, string f
224253
_labelCol = labelColumn;
225254
}
226255

227-
public ScorerWrapper<TModel> Fit(IDataView input)
256+
public TTransformer Fit(IDataView input)
228257
{
229258
return TrainTransformer(input);
230259
}
231260

232-
protected ScorerWrapper<TModel> TrainTransformer(IDataView trainSet,
261+
protected TTransformer TrainTransformer(IDataView trainSet,
233262
IDataView validationSet = null, IPredictor initPredictor = null)
234263
{
235264
var cachedTrain = _cache ? new CacheDataView(_env, trainSet, prefetch: null) : trainSet;
@@ -260,8 +289,7 @@ protected ScorerWrapper<TModel> TrainTransformer(IDataView trainSet,
260289
var pred = TrainCore(new TrainContext(trainRoles, validRoles, initPredictor));
261290

262291
var scoreRoles = new RoleMappedData(normalizer, label: _labelCol, feature: _featureCol);
263-
IDataScorerTransform scorer = ScoreUtils.GetScorer(pred, scoreRoles, _env, trainRoles.Schema);
264-
return new ScorerWrapper<TModel>(_env, scorer, pred);
292+
return MakeScorer(pred, scoreRoles);
265293
}
266294

267295
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
@@ -270,6 +298,14 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
270298
}
271299

272300
protected abstract TModel TrainCore(TrainContext trainContext);
301+
302+
protected abstract TTransformer MakeScorer(TModel predictor, RoleMappedData data);
303+
304+
protected ScorerWrapper<TModel> MakeScorerBasic(TModel predictor, RoleMappedData data)
305+
{
306+
var scorer = ScoreUtils.GetScorer(predictor, data, _env, data.Schema);
307+
return (TTransformer)(new ScorerWrapper<TModel>(_env, scorer, predictor, data.Schema.Feature.Name));
308+
}
273309
}
274310

275311
public class MyTextTransform : IEstimator<TransformWrapper>
@@ -378,7 +414,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
378414
}
379415
}
380416

381-
public sealed class MySdca : TrainerBase<IPredictor>
417+
public sealed class MySdca : TrainerBase<BinaryScorerWrapper<IPredictor>,IPredictor>
382418
{
383419
private readonly LinearClassificationTrainer.Arguments _args;
384420

@@ -391,9 +427,12 @@ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args,
391427
protected override IPredictor TrainCore(TrainContext context) => new LinearClassificationTrainer(_env, _args).Train(context);
392428

393429
public ITransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData);
430+
431+
protected override BinaryScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data)
432+
=> new BinaryScorerWrapper<IPredictor>(_env, predictor, data.Data.Schema, _featureCol, new BinaryClassifierScorer.Arguments());
394433
}
395434

396-
public sealed class MySdcaMulticlass : TrainerBase<IPredictor>
435+
public sealed class MySdcaMulticlass : TrainerBase<ScorerWrapper<IPredictor>, IPredictor>
397436
{
398437
private readonly SdcaMultiClassTrainer.Arguments _args;
399438

@@ -403,10 +442,12 @@ public MySdcaMulticlass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments ar
403442
_args = args;
404443
}
405444

445+
protected override ScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data) => MakeScorerBasic(predictor, data);
446+
406447
protected override IPredictor TrainCore(TrainContext context) => new SdcaMultiClassTrainer(_env, _args).Train(context);
407448
}
408449

409-
public sealed class MyAveragedPerceptron : TrainerBase<IPredictor>
450+
public sealed class MyAveragedPerceptron : TrainerBase<BinaryScorerWrapper<IPredictor>, IPredictor>
410451
{
411452
private readonly AveragedPerceptronTrainer _trainer;
412453

@@ -422,6 +463,9 @@ public ITransformer Train(IDataView trainData, IPredictor initialPredictor)
422463
{
423464
return TrainTransformer(trainData, initPredictor: initialPredictor);
424465
}
466+
467+
protected override BinaryScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data)
468+
=> new BinaryScorerWrapper<IPredictor>(_env, predictor, data.Data.Schema, _featureCol, new BinaryClassifierScorer.Arguments());
425469
}
426470

427471
public sealed class MyPredictionEngine<TSrc, TDst>

0 commit comments

Comments
 (0)