Skip to content

Commit bf097ab

Browse files
author
Pete Luferenko
committed
Added auto-normalization to everything
1 parent ef416f7 commit bf097ab

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

test/Microsoft.ML.Tests/Scenarios/Api/TrainWithInitialPredictor.cs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ public void New_TrainWithInitialPredictor()
6969

7070
// Train the second predictor on the same data.
7171
var secondTrainer = new MyAveragedPerceptron(env, new AveragedPerceptronTrainer.Arguments(), "Features", "Label");
72-
var finalPredictor = secondTrainer.Train(trainData, firstPredictor.Model);
72+
var finalPredictor = secondTrainer.Train(trainData, firstPredictor.InnerModel);
7373
}
7474
}
7575

test/Microsoft.ML.Tests/Scenarios/Api/Wrappers.cs

Lines changed: 24 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,7 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
168168

169169
public interface IPredictorTransformer<out TModel> : ITransformer
170170
{
171-
TModel TrainedModel { get; }
171+
TModel InnerModel { get; }
172172
}
173173

174174
public class ScorerWrapper<TModel> : TransformWrapper, IPredictorTransformer<TModel>
@@ -177,12 +177,10 @@ public class ScorerWrapper<TModel> : TransformWrapper, IPredictorTransformer<TMo
177177
public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel)
178178
: base(env, scorer)
179179
{
180-
Model = trainedModel;
180+
InnerModel = trainedModel;
181181
}
182182

183-
public TModel TrainedModel => Model;
184-
185-
public TModel Model { get; }
183+
public TModel InnerModel { get; }
186184
}
187185

188186
public class MyTextLoader : IDataReaderEstimator<IMultiStreamSource, LoaderWrapper>
@@ -215,11 +213,13 @@ public abstract class TrainerBase<TModel> : IEstimator<ScorerWrapper<TModel>>
215213
private readonly string _featureCol;
216214
private readonly string _labelCol;
217215
private readonly bool _cache;
216+
private readonly bool _normalize;
218217

219-
protected TrainerBase(IHostEnvironment env, bool cache, string featureColumn, string labelColumn)
218+
protected TrainerBase(IHostEnvironment env, bool cache, bool normalize, string featureColumn, string labelColumn)
220219
{
221220
_env = env;
222221
_cache = cache;
222+
_normalize = normalize;
223223
_featureCol = featureColumn;
224224
_labelCol = labelColumn;
225225
}
@@ -229,28 +229,39 @@ public ScorerWrapper<TModel> Fit(IDataView input)
229229
return TrainTransformer(input);
230230
}
231231

232-
protected ScorerWrapper<TModel> TrainTransformer(IDataView trainSet, IDataView validationSet = null, IPredictor initPredictor = null)
232+
protected ScorerWrapper<TModel> TrainTransformer(IDataView trainSet,
233+
IDataView validationSet = null, IPredictor initPredictor = null)
233234
{
234235
var cachedTrain = _cache ? new CacheDataView(_env, trainSet, prefetch: null) : trainSet;
235236

236237
var trainRoles = new RoleMappedData(cachedTrain, label: _labelCol, feature: _featureCol);
238+
var emptyData = new EmptyDataView(_env, trainSet.Schema);
239+
IDataView normalizer = emptyData;
240+
241+
if (_normalize && trainRoles.Schema.FeaturesAreNormalized() == false)
242+
{
243+
var view = NormalizeTransform.CreateMinMaxNormalizer(_env, trainRoles.Data, name: trainRoles.Schema.Feature.Name);
244+
normalizer = ApplyTransformUtils.ApplyAllTransformsToData(_env, view, emptyData, cachedTrain);
245+
246+
trainRoles = new RoleMappedData(view, trainRoles.Schema.GetColumnRoleNames());
247+
}
248+
237249
RoleMappedData validRoles;
238250

239251
if (validationSet == null)
240252
validRoles = null;
241253
else
242254
{
243255
var cachedValid = _cache ? new CacheDataView(_env, validationSet, prefetch: null) : validationSet;
256+
cachedValid = ApplyTransformUtils.ApplyAllTransformsToData(_env, normalizer, cachedValid);
244257
validRoles = new RoleMappedData(cachedValid, label: _labelCol, feature: _featureCol);
245258
}
246259

247260
var pred = TrainCore(new TrainContext(trainRoles, validRoles, initPredictor));
248-
249-
var emptyData = new EmptyDataView(_env, trainSet.Schema);
250-
var scoreRoles = new RoleMappedData(emptyData, label: _labelCol, feature: _featureCol);
261+
262+
var scoreRoles = new RoleMappedData(normalizer, label: _labelCol, feature: _featureCol);
251263
IDataScorerTransform scorer = ScoreUtils.GetScorer(pred, scoreRoles, _env, trainRoles.Schema);
252264
return new ScorerWrapper<TModel>(_env, scorer, pred);
253-
254265
}
255266

256267
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
@@ -291,7 +302,7 @@ public sealed class MySdca : TrainerBase<IPredictor>
291302
private readonly LinearClassificationTrainer.Arguments _args;
292303

293304
public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args, string featureCol, string labelCol)
294-
: base(env, true, featureCol, labelCol)
305+
: base(env, true, true, featureCol, labelCol)
295306
{
296307
_args = args;
297308
}
@@ -306,7 +317,7 @@ public sealed class MyAveragedPerceptron : TrainerBase<IPredictor>
306317
private readonly AveragedPerceptronTrainer _trainer;
307318

308319
public MyAveragedPerceptron(IHostEnvironment env, AveragedPerceptronTrainer.Arguments args, string featureCol, string labelCol)
309-
: base(env, false, featureCol, labelCol)
320+
: base(env, false, true, featureCol, labelCol)
310321
{
311322
_trainer = new AveragedPerceptronTrainer(env, args);
312323
}

0 commit comments

Comments
 (0)