@@ -168,7 +168,7 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
168
168
169
169
// An ITransformer that additionally exposes a read-only property of type TModel.
// TModel is declared covariant (`out`), so it may only appear in output positions.
public interface IPredictorTransformer < out TModel > : ITransformer
170
170
{
171
// Removed side of the diff: the property was previously named TrainedModel.
- TModel TrainedModel { get ; }
171
// Added side of the diff: the same get-only property, renamed to InnerModel.
+ TModel InnerModel { get ; }
172
172
}
173
173
174
174
public class ScorerWrapper < TModel > : TransformWrapper , IPredictorTransformer < TModel >
@@ -177,12 +177,10 @@ public class ScorerWrapper<TModel> : TransformWrapper, IPredictorTransformer<TMo
177
177
// Constructor: forwards env and scorer to the TransformWrapper base constructor
// and stores trainedModel in the wrapper's model property.
public ScorerWrapper ( IHostEnvironment env , IDataView scorer , TModel trainedModel )
178
178
: base ( env , scorer )
179
179
{
180
// Removed: assignment to the old Model property.
- Model = trainedModel ;
180
// Added: assignment to the renamed InnerModel property.
+ InnerModel = trainedModel ;
181
181
}
182
182
183
// Removed: TrainedModel was an expression-bodied alias that returned the Model auto-property.
- public TModel TrainedModel => Model ;
184
-
185
- public TModel Model { get ; }
183
// Added: a single get-only auto-property replaces both removed members;
// it is the implementation of IPredictorTransformer<TModel>.InnerModel.
+ public TModel InnerModel { get ; }
186
184
}
187
185
188
186
public class MyTextLoader : IDataReaderEstimator < IMultiStreamSource , LoaderWrapper >
@@ -215,11 +213,13 @@ public abstract class TrainerBase<TModel> : IEstimator<ScorerWrapper<TModel>>
215
213
// Column name passed as the `feature` role when building RoleMappedData in TrainTransformer.
private readonly string _featureCol ;
216
214
// Column name passed as the `label` role when building RoleMappedData in TrainTransformer.
private readonly string _labelCol ;
217
215
// When true, TrainTransformer wraps the training and validation sets in a CacheDataView.
private readonly bool _cache ;
216
// Added: when true, TrainTransformer inserts a min-max normalizer unless the
// features are already reported as normalized.
+ private readonly bool _normalize ;
218
217
219
// Constructor for derived trainers; it only stores its arguments in the corresponding fields.
// Removed signature: no normalize flag.
- protected TrainerBase ( IHostEnvironment env , bool cache , string featureColumn , string labelColumn )
218
// Added signature: a `normalize` flag is inserted after `cache`, so every derived
// class's base(...) call must pass one more positional argument.
+ protected TrainerBase ( IHostEnvironment env , bool cache , bool normalize , string featureColumn , string labelColumn )
220
219
{
221
220
_env = env ;
222
221
_cache = cache ;
222
+ _normalize = normalize ;
223
223
_featureCol = featureColumn ;
224
224
_labelCol = labelColumn ;
225
225
}
@@ -229,28 +229,39 @@ public ScorerWrapper<TModel> Fit(IDataView input)
229
229
return TrainTransformer ( input ) ;
230
230
}
231
231
232
// Trains a predictor on trainSet (optionally with a validation set and an initial
// predictor) via TrainCore, builds a scorer for it, and returns both wrapped in a
// ScorerWrapper<TModel>.
// The diff only reflows the signature across two lines; the parameters are unchanged.
- protected ScorerWrapper < TModel > TrainTransformer ( IDataView trainSet , IDataView validationSet = null , IPredictor initPredictor = null )
232
+ protected ScorerWrapper < TModel > TrainTransformer ( IDataView trainSet ,
233
+ IDataView validationSet = null , IPredictor initPredictor = null )
233
234
{
234
235
// Wrap the training data in a CacheDataView only when caching was requested.
var cachedTrain = _cache ? new CacheDataView ( _env , trainSet , prefetch : null ) : trainSet ;
235
236
236
237
var trainRoles = new RoleMappedData ( cachedTrain , label : _labelCol , feature : _featureCol ) ;
238
// Added: the empty view over the training schema is now created before training
// (it used to be created after TrainCore) because the normalizer chain is built on it.
+ var emptyData = new EmptyDataView ( _env , trainSet . Schema ) ;
239
// Defaults to the bare empty view, so the later uses of `normalizer` also work
// when no normalization is applied.
+ IDataView normalizer = emptyData ;
240
+
241
// Normalize only when requested and the features are not already reported as normalized.
// NOTE(review): the explicit `== false` suggests FeaturesAreNormalized() may return a
// nullable bool, in which case a null result skips normalization — confirm.
+ if ( _normalize && trainRoles . Schema . FeaturesAreNormalized ( ) == false )
242
+ {
243
// Build a min-max normalizer over the training data for the feature column.
+ var view = NormalizeTransform . CreateMinMaxNormalizer ( _env , trainRoles . Data , name : trainRoles . Schema . Feature . Name ) ;
244
// Presumably re-applies the transforms sitting between cachedTrain and `view`
// on top of emptyData, giving a data-free copy of the normalizer — confirm
// against ApplyTransformUtils.ApplyAllTransformsToData.
+ normalizer = ApplyTransformUtils . ApplyAllTransformsToData ( _env , view , emptyData , cachedTrain ) ;
245
+
246
// Train on the normalized view, keeping the same column-role assignments.
+ trainRoles = new RoleMappedData ( view , trainRoles . Schema . GetColumnRoleNames ( ) ) ;
247
+ }
248
+
237
249
RoleMappedData validRoles ;
238
250
239
251
// A validation set is optional; validRoles stays null when none is supplied.
if ( validationSet == null )
240
252
validRoles = null ;
241
253
else
242
254
{
243
255
var cachedValid = _cache ? new CacheDataView ( _env , validationSet , prefetch : null ) : validationSet ;
256
// Added: route the validation data through the same transform chain as `normalizer`.
// Presumably a no-op when `normalizer` is still the bare emptyData — confirm.
+ cachedValid = ApplyTransformUtils . ApplyAllTransformsToData ( _env , normalizer , cachedValid ) ;
244
257
validRoles = new RoleMappedData ( cachedValid , label : _labelCol , feature : _featureCol ) ;
245
258
}
246
259
247
260
var pred = TrainCore ( new TrainContext ( trainRoles , validRoles , initPredictor ) ) ;
248
-
249
// Removed: emptyData was created here and used directly for the scoring roles.
- var emptyData = new EmptyDataView ( _env , trainSet . Schema ) ;
250
- var scoreRoles = new RoleMappedData ( emptyData , label : _labelCol , feature : _featureCol ) ;
261
+
262
// Added: scoring roles are built over `normalizer` instead of the bare empty view, so
// any normalizer added above sits underneath the scorer.
+ var scoreRoles = new RoleMappedData ( normalizer , label : _labelCol , feature : _featureCol ) ;
251
263
IDataScorerTransform scorer = ScoreUtils . GetScorer ( pred , scoreRoles , _env , trainRoles . Schema ) ;
252
264
// The wrapper receives both the scorer and the predictor returned by TrainCore.
return new ScorerWrapper < TModel > ( _env , scorer , pred ) ;
253
-
254
265
}
255
266
256
267
public SchemaShape GetOutputSchema ( SchemaShape inputSchema )
@@ -291,7 +302,7 @@ public sealed class MySdca : TrainerBase<IPredictor>
291
302
// Trainer arguments captured by the constructor.
private readonly LinearClassificationTrainer . Arguments _args ;
292
303
293
304
// Constructor: passes the column names to TrainerBase and stores args.
public MySdca ( IHostEnvironment env , LinearClassificationTrainer . Arguments args , string featureCol , string labelCol )
294
// Removed: base call with cache = true and no normalize flag.
- : base ( env , true , featureCol , labelCol )
305
// Added: base call with cache = true and normalize = true.
+ : base ( env , true , true , featureCol , labelCol )
295
306
{
296
307
_args = args ;
297
308
}
@@ -306,7 +317,7 @@ public sealed class MyAveragedPerceptron : TrainerBase<IPredictor>
306
317
// Underlying trainer instance, created once in the constructor.
private readonly AveragedPerceptronTrainer _trainer ;
307
318
308
319
// Constructor: passes the column names to TrainerBase and builds the wrapped trainer from env and args.
public MyAveragedPerceptron ( IHostEnvironment env , AveragedPerceptronTrainer . Arguments args , string featureCol , string labelCol )
309
// Removed: base call with cache = false and no normalize flag.
- : base ( env , false , featureCol , labelCol )
320
// Added: base call with cache = false and normalize = true.
+ : base ( env , false , true , featureCol , labelCol )
310
321
{
311
322
_trainer = new AveragedPerceptronTrainer ( env , args ) ;
312
323
}
0 commit comments