16
16
17
17
namespace Microsoft . ML . Tests . Scenarios . Api
18
18
{
19
+ using LinearModel = LinearPredictor ;
20
+
19
21
public sealed class LoaderWrapper : IDataReader < IMultiStreamSource > , ICanSaveModel
20
22
{
21
23
public const string LoaderSignature = "LoaderWrapper" ;
@@ -161,6 +163,24 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
161
163
/// <summary>
/// Runs <paramref name="input"/> through the wrapped transform chain by delegating to
/// <c>ApplyTransformUtils.ApplyAllTransformsToData</c>.
/// </summary>
/// <param name="input">The data view to transform.</param>
/// <returns>The data view produced by the helper.</returns>
public IDataView Transform(IDataView input)
{
    return ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input);
}
162
164
}
163
165
166
/// <summary>
/// An <see cref="ITransformer"/> that additionally exposes the model it holds.
/// </summary>
/// <typeparam name="TModel">Type of the exposed model (covariant).</typeparam>
public interface IPredictorTransformer<out TModel> : ITransformer
{
    /// <summary>The model exposed by this transformer.</summary>
    TModel TrainedModel { get; }
}
170
+
171
/// <summary>
/// A <see cref="TransformWrapper"/> built around a scorer data view that also
/// remembers the model handed to it at construction time.
/// </summary>
/// <typeparam name="TModel">Type of the stored model; must implement <see cref="IPredictor"/>.</typeparam>
public class ScorerWrapper<TModel> : TransformWrapper, IPredictorTransformer<TModel>
    where TModel : IPredictor
{
    /// <summary>
    /// Forwards <paramref name="env"/> and <paramref name="scorer"/> to the base wrapper
    /// and stores <paramref name="trainedModel"/>.
    /// </summary>
    public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel)
        : base(env, scorer)
    {
        TrainedModel = trainedModel;
    }

    /// <summary>The model supplied to the constructor.</summary>
    public TModel TrainedModel { get; }

    /// <summary>Same value as <see cref="TrainedModel"/>.</summary>
    public TModel Model => TrainedModel;
}
164
184
165
185
public class MyTextLoader : IDataReaderEstimator < IMultiStreamSource , LoaderWrapper >
166
186
{
@@ -185,7 +205,8 @@ public SchemaShape GetOutputSchema()
185
205
}
186
206
}
187
207
188
- public abstract class TrainerBase : IEstimator < TransformWrapper >
208
+ public abstract class TrainerBase < TModel > : IEstimator < ScorerWrapper < TModel > >
209
+ where TModel : IPredictor
189
210
{
190
211
protected readonly IHostEnvironment _env ;
191
212
private readonly string _featureCol ;
@@ -200,25 +221,41 @@ protected TrainerBase(IHostEnvironment env, bool cache, string featureColumn, st
200
221
_labelCol = labelColumn ;
201
222
}
202
223
203
- public TransformWrapper Fit ( IDataView input )
224
/// <summary>
/// Trains on <paramref name="input"/> with no validation set and no initial predictor.
/// </summary>
/// <param name="input">The training data.</param>
/// <returns>The scorer wrapper returned by <c>TrainTransformer</c>.</returns>
public ScorerWrapper<TModel> Fit(IDataView input) => TrainTransformer(input);
228
+
229
/// <summary>
/// Trains a model via <c>TrainCore</c> and wraps it, together with a scorer, in a
/// <see cref="ScorerWrapper{TModel}"/>.
/// </summary>
/// <param name="trainSet">Training data; its schema is also used to build the scorer.</param>
/// <param name="validationSet">Optional validation data; when null, a null validation role mapping is passed on.</param>
/// <param name="initPredictor">Optional predictor forwarded to the <c>TrainContext</c>.</param>
/// <returns>A wrapper holding the scorer and the trained model.</returns>
protected ScorerWrapper<TModel> TrainTransformer(IDataView trainSet, IDataView validationSet = null, IPredictor initPredictor = null)
{
    // Wrap the training data in a cache view only when caching was requested.
    IDataView trainView = trainSet;
    if (_cache)
        trainView = new CacheDataView(_env, trainSet, prefetch: null);
    var trainRoles = new RoleMappedData(trainView, label: _labelCol, feature: _featureCol);

    // The validation data gets the same cache / role-mapping treatment when present.
    RoleMappedData validRoles = null;
    if (validationSet != null)
    {
        IDataView validView = validationSet;
        if (_cache)
            validView = new CacheDataView(_env, validationSet, prefetch: null);
        validRoles = new RoleMappedData(validView, label: _labelCol, feature: _featureCol);
    }

    var context = new TrainContext(trainRoles, validRoles, initPredictor);
    TModel pred = TrainCore(context);

    // The scorer is bound to an empty view that carries the training set's schema.
    var schemaOnlyData = new EmptyDataView(_env, trainSet.Schema);
    var scoreRoles = new RoleMappedData(schemaOnlyData, label: _labelCol, feature: _featureCol);
    IDataScorerTransform scorer = ScoreUtils.GetScorer(pred, scoreRoles, _env, trainRoles.Schema);

    return new ScorerWrapper<TModel>(_env, scorer, pred);
}
215
252
216
253
/// <summary>
/// Not implemented: always throws <see cref="NotImplementedException"/>.
/// </summary>
/// <param name="inputSchema">Unused.</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
{
    throw new NotImplementedException();
}
220
257
221
- protected abstract IPredictor Train ( RoleMappedData data ) ;
258
/// <summary>
/// Produces the model for the given training context; called by <c>TrainTransformer</c>.
/// </summary>
/// <param name="trainContext">The context built by <c>TrainTransformer</c> from the training roles, validation roles and initial predictor.</param>
/// <returns>The trained model.</returns>
protected abstract TModel TrainCore(TrainContext trainContext);
222
259
}
223
260
224
261
public class MyTextTransform : IEstimator < TransformWrapper >
@@ -246,7 +283,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
246
283
}
247
284
}
248
285
249
- public sealed class MySdca : TrainerBase
286
+ public sealed class MySdca : TrainerBase < IPredictor >
250
287
{
251
288
private readonly LinearClassificationTrainer . Arguments _args ;
252
289
@@ -256,7 +293,27 @@ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args,
256
293
_args = args ;
257
294
}
258
295
259
- protected override IPredictor Train ( RoleMappedData data ) => new LinearClassificationTrainer ( _env , _args ) . Train ( data ) ;
296
/// <summary>
/// Creates a new <c>LinearClassificationTrainer</c> from the stored arguments and trains it on <paramref name="context"/>.
/// </summary>
protected override IPredictor TrainCore(TrainContext context)
{
    var trainer = new LinearClassificationTrainer(_env, _args);
    return trainer.Train(context);
}
297
+
298
/// <summary>
/// Trains on <paramref name="trainData"/>, optionally passing <paramref name="validationData"/> along as the validation set.
/// </summary>
/// <returns>The transformer returned by <c>TrainTransformer</c>.</returns>
public ITransformer Train(IDataView trainData, IDataView validationData = null)
{
    return TrainTransformer(trainData, validationData);
}
299
+ }
300
+
301
/// <summary>
/// Trainer wrapper around a single <c>AveragedPerceptronTrainer</c> instance that can be
/// trained starting from the model held by an existing predictor transformer.
/// </summary>
public sealed class MyAveragedPerceptron : TrainerBase<IPredictor>
{
    private readonly AveragedPerceptronTrainer _trainer;

    /// <summary>
    /// Creates the underlying trainer from <paramref name="args"/>. The base class is
    /// constructed with caching turned off (second argument <c>false</c>).
    /// </summary>
    public MyAveragedPerceptron(IHostEnvironment env, AveragedPerceptronTrainer.Arguments args, string featureCol, string labelCol)
        : base(env, false, featureCol, labelCol)
    {
        _trainer = new AveragedPerceptronTrainer(env, args);
    }

    /// <summary>Delegates training to the wrapped trainer.</summary>
    protected override IPredictor TrainCore(TrainContext trainContext) => _trainer.Train(trainContext);

    /// <summary>
    /// Trains on <paramref name="trainData"/>, passing the model of
    /// <paramref name="initialPredictor"/> as the initial predictor.
    /// </summary>
    /// <param name="trainData">The training data.</param>
    /// <param name="initialPredictor">Transformer whose <c>TrainedModel</c> seeds the training; must not be null.</param>
    /// <returns>The transformer returned by <c>TrainTransformer</c>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="initialPredictor"/> is null.</exception>
    public ITransformer Train(IDataView trainData, IPredictorTransformer<IPredictor> initialPredictor)
    {
        // Fail with a named argument error instead of a bare NullReferenceException
        // when dereferencing TrainedModel below.
        if (initialPredictor == null)
            throw new ArgumentNullException(nameof(initialPredictor));

        return TrainTransformer(trainData, initPredictor: initialPredictor.TrainedModel);
    }
}
261
318
262
319
public sealed class MyPredictionEngine < TSrc , TDst >
0 commit comments