2
2
using Microsoft . ML . Models ;
3
3
using Microsoft . ML . Runtime ;
4
4
using Microsoft . ML . Runtime . Api ;
5
+ using Microsoft . ML . Runtime . CommandLine ;
5
6
using Microsoft . ML . Runtime . Data ;
6
7
using Microsoft . ML . Runtime . Data . IO ;
7
8
using Microsoft . ML . Runtime . Learners ;
@@ -90,8 +91,8 @@ public class TransformWrapper : ITransformer, ICanSaveModel
90
91
public const string LoaderSignature = "TransformWrapper" ;
91
92
private const string TransformDirTemplate = "Step_{0:000}" ;
92
93
93
- private readonly IHostEnvironment _env ;
94
- private readonly IDataView _xf ;
94
+ protected readonly IHostEnvironment _env ;
95
+ protected readonly IDataView _xf ;
95
96
96
97
public TransformWrapper ( IHostEnvironment env , IDataView xf )
97
98
{
@@ -174,15 +175,42 @@ public interface IPredictorTransformer<out TModel> : ITransformer
174
175
/// <summary>
/// A <see cref="TransformWrapper"/> around a scorer data view that also exposes the trained
/// model and remembers the name of the feature column it was built for.
/// </summary>
/// <typeparam name="TModel">Type of the trained predictor.</typeparam>
public class ScorerWrapper<TModel> : TransformWrapper, IPredictorTransformer<TModel>
    where TModel : IPredictor
{
    /// <summary>Feature column name supplied at construction; available to derived wrappers.</summary>
    protected readonly string _featureColumn;

    /// <summary>The trained model passed to the constructor.</summary>
    public TModel InnerModel { get; }

    /// <summary>
    /// Creates the wrapper.
    /// </summary>
    /// <param name="env">Host environment, forwarded to the base wrapper.</param>
    /// <param name="scorer">Scorer data view, forwarded to the base wrapper.</param>
    /// <param name="trainedModel">Model exposed through <see cref="InnerModel"/>.</param>
    /// <param name="featureColumn">Name of the feature column, stored in <see cref="_featureColumn"/>.</param>
    public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel, string featureColumn)
        : base(env, scorer)
    {
        InnerModel = trainedModel;
        _featureColumn = featureColumn;
    }
}
185
189
190
/// <summary>
/// A <see cref="ScorerWrapper{TModel}"/> whose scorer is a <see cref="BinaryClassifierScorer"/>
/// built from the model, an input schema and a set of scorer arguments. It can be cloned with
/// different scorer arguments without retraining the model.
/// </summary>
/// <typeparam name="TModel">Type of the trained predictor.</typeparam>
public class BinaryScorerWrapper<TModel> : ScorerWrapper<TModel>
    where TModel : IPredictor
{
    /// <summary>
    /// Creates the wrapper and the underlying binary classifier scorer.
    /// </summary>
    /// <param name="env">Host environment.</param>
    /// <param name="model">Trained model to score with.</param>
    /// <param name="inputSchema">Schema of the data the scorer will be bound to.</param>
    /// <param name="featureColumn">Name of the feature column.</param>
    /// <param name="args">Arguments for the binary classifier scorer.</param>
    public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, BinaryClassifierScorer.Arguments args)
        : base(env, MakeScorer(env, inputSchema, featureColumn, model, args), model, featureColumn)
    {
    }

    /// <summary>
    /// Builds a <see cref="BinaryClassifierScorer"/> for <paramref name="model"/> over an empty
    /// data view that carries <paramref name="schema"/>.
    /// </summary>
    private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string featureColumn, TModel model, BinaryClassifierScorer.Arguments args)
    {
        // Settings are computed against a default Arguments instance.
        // NOTE(review): presumably this emits only the non-default settings; confirm against CmdParser.GetSettings.
        var settings = $"Binary{{{CmdParser.GetSettings(env, args, new BinaryClassifierScorer.Arguments())}}}";
        var mapper = ScoreUtils.GetSchemaBindableMapper(env, model, SubComponent.Parse<IDataScorerTransform, SignatureDataScorer>(settings));

        // Only the schema is needed to bind the mapper, so an empty view stands in for real data.
        var edv = new EmptyDataView(env, schema);
        // The label role is requested as "Label" with opt: true.
        var data = new RoleMappedData(edv, "Label", featureColumn, opt: true);
        return new BinaryClassifierScorer(env, args, data.Data, mapper.Bind(env, data.Schema), data.Schema);
    }

    /// <summary>
    /// Creates a new wrapper around the same model, input schema and feature column, but with
    /// different scorer arguments.
    /// </summary>
    /// <param name="scorerArgs">Arguments for the new binary classifier scorer.</param>
    /// <returns>A new <see cref="BinaryScorerWrapper{TModel}"/>.</returns>
    /// <exception cref="System.InvalidOperationException">
    /// The wrapped data view is not an <see cref="IDataScorerTransform"/>, so its source schema cannot be recovered.
    /// </exception>
    public BinaryScorerWrapper<TModel> Clone(BinaryClassifierScorer.Arguments scorerArgs)
    {
        var scorer = _xf as IDataScorerTransform;
        if (scorer == null)
        {
            // Fail with a clear message instead of a NullReferenceException on scorer.Source.
            throw new System.InvalidOperationException(
                "The wrapped data view is not an IDataScorerTransform; cannot clone the scorer.");
        }

        return new BinaryScorerWrapper<TModel>(_env, InnerModel, scorer.Source.Schema, _featureColumn, scorerArgs);
    }
}
213
+
186
214
public class MyTextLoader : IDataReaderEstimator < IMultiStreamSource , LoaderWrapper >
187
215
{
188
216
private readonly TextLoader . Arguments _args ;
@@ -206,12 +234,13 @@ public SchemaShape GetOutputSchema()
206
234
}
207
235
}
208
236
209
- public abstract class TrainerBase < TModel > : IEstimator < ScorerWrapper < TModel > >
237
+ public abstract class TrainerBase < TTransformer , TModel > : IEstimator < TTransformer >
238
+ where TTransformer : ScorerWrapper < TModel >
210
239
where TModel : IPredictor
211
240
{
212
241
protected readonly IHostEnvironment _env ;
213
- private readonly string _featureCol ;
214
- private readonly string _labelCol ;
242
+ protected readonly string _featureCol ;
243
+ protected readonly string _labelCol ;
215
244
private readonly bool _cache ;
216
245
private readonly bool _normalize ;
217
246
@@ -224,12 +253,12 @@ protected TrainerBase(IHostEnvironment env, bool cache, bool normalize, string f
224
253
_labelCol = labelColumn ;
225
254
}
226
255
227
/// <summary>
/// Fits the estimator on <paramref name="input"/> by delegating to <c>TrainTransformer</c>
/// with no validation set and no initial predictor.
/// </summary>
/// <param name="input">Training data.</param>
/// <returns>The trained transformer.</returns>
public TTransformer Fit(IDataView input) => TrainTransformer(input);
231
260
232
- protected ScorerWrapper < TModel > TrainTransformer ( IDataView trainSet ,
261
+ protected TTransformer TrainTransformer ( IDataView trainSet ,
233
262
IDataView validationSet = null , IPredictor initPredictor = null )
234
263
{
235
264
var cachedTrain = _cache ? new CacheDataView ( _env , trainSet , prefetch : null ) : trainSet ;
@@ -260,8 +289,7 @@ protected ScorerWrapper<TModel> TrainTransformer(IDataView trainSet,
260
289
var pred = TrainCore ( new TrainContext ( trainRoles , validRoles , initPredictor ) ) ;
261
290
262
291
var scoreRoles = new RoleMappedData ( normalizer , label : _labelCol , feature : _featureCol ) ;
263
- IDataScorerTransform scorer = ScoreUtils . GetScorer ( pred , scoreRoles , _env , trainRoles . Schema ) ;
264
- return new ScorerWrapper < TModel > ( _env , scorer , pred ) ;
292
+ return MakeScorer ( pred , scoreRoles ) ;
265
293
}
266
294
267
295
public SchemaShape GetOutputSchema ( SchemaShape inputSchema )
@@ -270,6 +298,14 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
270
298
}
271
299
272
300
/// <summary>
/// Trains a model for the given training context. Implemented by each concrete trainer.
/// </summary>
protected abstract TModel TrainCore(TrainContext trainContext);

/// <summary>
/// Wraps a trained predictor into the transformer type this estimator produces.
/// <c>TrainTransformer</c> calls it with the trained predictor and the role-mapped scoring data.
/// </summary>
protected abstract TTransformer MakeScorer(TModel predictor, RoleMappedData data);

/// <summary>
/// Generic scorer construction: obtains a scorer from <c>ScoreUtils.GetScorer</c> and wraps it
/// in a plain <see cref="ScorerWrapper{TModel}"/>, using the name of the feature column from
/// the role-mapped schema.
/// </summary>
/// <param name="predictor">Trained predictor.</param>
/// <param name="data">Role-mapped data the scorer is created for.</param>
/// <returns>A new <see cref="ScorerWrapper{TModel}"/>.</returns>
protected ScorerWrapper<TModel> MakeScorerBasic(TModel predictor, RoleMappedData data)
{
    var scorer = ScoreUtils.GetScorer(predictor, data, _env, data.Schema);

    // No cast to TTransformer: the declared return type is the base wrapper, and casting a plain
    // ScorerWrapper<TModel> to a more derived TTransformer would throw InvalidCastException.
    return new ScorerWrapper<TModel>(_env, scorer, predictor, data.Schema.Feature.Name);
}
273
309
}
274
310
275
311
public class MyTextTransform : IEstimator < TransformWrapper >
@@ -378,7 +414,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
378
414
}
379
415
}
380
416
381
- public sealed class MySdca : TrainerBase < IPredictor >
417
+ public sealed class MySdca : TrainerBase < BinaryScorerWrapper < IPredictor > , IPredictor >
382
418
{
383
419
private readonly LinearClassificationTrainer . Arguments _args ;
384
420
@@ -391,9 +427,12 @@ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args,
391
427
/// <summary>
/// Trains the predictor by running a new <see cref="LinearClassificationTrainer"/>, built from
/// this estimator's environment and arguments, on <paramref name="context"/>.
/// </summary>
protected override IPredictor TrainCore(TrainContext context)
{
    var trainer = new LinearClassificationTrainer(_env, _args);
    return trainer.Train(context);
}
392
428
393
429
/// <summary>
/// Trains on <paramref name="trainData"/>, optionally using <paramref name="validationData"/>,
/// by delegating to <c>TrainTransformer</c>.
/// </summary>
/// <param name="trainData">Training data.</param>
/// <param name="validationData">Optional validation data; <c>null</c> when omitted.</param>
/// <returns>The trained transformer.</returns>
public ITransformer Train(IDataView trainData, IDataView validationData = null)
{
    return TrainTransformer(trainData, validationData);
}
430
+
431
/// <summary>
/// Wraps the trained predictor in a <see cref="BinaryScorerWrapper{TModel}"/> bound to the
/// schema of the scoring data, using default binary classifier scorer arguments.
/// </summary>
protected override BinaryScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data)
{
    var inputSchema = data.Data.Schema;
    var scorerArgs = new BinaryClassifierScorer.Arguments();
    return new BinaryScorerWrapper<IPredictor>(_env, predictor, inputSchema, _featureCol, scorerArgs);
}
394
433
}
395
434
396
- public sealed class MySdcaMulticlass : TrainerBase < IPredictor >
435
+ public sealed class MySdcaMulticlass : TrainerBase < ScorerWrapper < IPredictor > , IPredictor >
397
436
{
398
437
private readonly SdcaMultiClassTrainer . Arguments _args ;
399
438
@@ -403,10 +442,12 @@ public MySdcaMulticlass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments ar
403
442
_args = args ;
404
443
}
405
444
445
/// <summary>
/// Wraps the trained predictor using the generic <c>MakeScorerBasic</c> construction.
/// </summary>
protected override ScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data)
{
    return MakeScorerBasic(predictor, data);
}
446
+
406
447
/// <summary>
/// Trains the predictor by running a new <see cref="SdcaMultiClassTrainer"/>, built from this
/// estimator's environment and arguments, on <paramref name="context"/>.
/// </summary>
protected override IPredictor TrainCore(TrainContext context)
{
    var trainer = new SdcaMultiClassTrainer(_env, _args);
    return trainer.Train(context);
}
407
448
}
408
449
409
- public sealed class MyAveragedPerceptron : TrainerBase < IPredictor >
450
+ public sealed class MyAveragedPerceptron : TrainerBase < BinaryScorerWrapper < IPredictor > , IPredictor >
410
451
{
411
452
private readonly AveragedPerceptronTrainer _trainer ;
412
453
@@ -422,6 +463,9 @@ public ITransformer Train(IDataView trainData, IPredictor initialPredictor)
422
463
{
423
464
return TrainTransformer ( trainData , initPredictor : initialPredictor ) ;
424
465
}
466
+
467
/// <summary>
/// Wraps the trained predictor in a <see cref="BinaryScorerWrapper{TModel}"/> bound to the
/// schema of the scoring data, using default binary classifier scorer arguments.
/// </summary>
protected override BinaryScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data)
{
    var inputSchema = data.Data.Schema;
    var scorerArgs = new BinaryClassifierScorer.Arguments();
    return new BinaryScorerWrapper<IPredictor>(_env, predictor, inputSchema, _featureCol, scorerArgs);
}
425
469
}
426
470
427
471
public sealed class MyPredictionEngine < TSrc , TDst >
0 commit comments