Skip to content

Commit bf344a3

Browse files
author
Pete Luferenko
committed
Lowered execution times on some tests.
Implemented new API for initial predictor and validation sets
1 parent d051138 commit bf344a3

File tree

6 files changed

+140
-18
lines changed

6 files changed

+140
-18
lines changed

test/Microsoft.ML.Tests/Scenarios/Api/AutoNormalizationAndCaching.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ public void AutoNormalizationAndCaching()
3030
// Train.
3131
var trainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments
3232
{
33-
NumThreads = 1
33+
NumThreads = 1,
34+
ConvergenceTolerance = 1f
3435
});
3536

3637
// Auto-caching.

test/Microsoft.ML.Tests/Scenarios/Api/EstimatorPipe.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public CompositeReader<TSource, TNewLastTransformer> Append<TNewLastTransformer>
152152
return new CompositeReader<TSource, TNewLastTransformer>(Reader, Transformer.Append(transformer));
153153
}
154154

155-
public void Save(IHostEnvironment env, Stream outputStream)
155+
public void SavePipeline(IHostEnvironment env, Stream outputStream)
156156
{
157157
using (var ch = env.Start("Saving model"))
158158
{
@@ -171,7 +171,7 @@ public void Save(IHostEnvironment env, Stream outputStream)
171171

172172
public static class CompositeReader
173173
{
174-
public static CompositeReader<IMultiStreamSource, ITransformer> LoadModel(IHostEnvironment env, Stream stream)
174+
public static CompositeReader<IMultiStreamSource, ITransformer> LoadPipeline(IHostEnvironment env, Stream stream)
175175
{
176176
using (var rep = RepositoryReader.Open(stream, env))
177177
{

test/Microsoft.ML.Tests/Scenarios/Api/TrainSaveModelAndPredict.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ public void New_TrainSaveModelAndPredict()
9393
{
9494
// Save model.
9595
using (var fs = file.CreateWriteStream())
96-
model.Save(env, fs);
96+
model.SavePipeline(env, fs);
9797

9898
// Load model.
99-
loadedModel = CompositeReader.LoadModel(env, file.OpenReadStream());
99+
loadedModel = CompositeReader.LoadPipeline(env, file.OpenReadStream());
100100

101101
}
102102

test/Microsoft.ML.Tests/Scenarios/Api/TrainWithInitialPredictor.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,39 @@ public void TrainWithInitialPredictor()
3939
var finalPredictor = secondTrainer.Train(new TrainContext(trainRoles, initialPredictor: predictor));
4040
}
4141
}
42+
43+
/// <summary>
44+
/// Train with initial predictor: Similar to the simple train scenario, but also accept a pre-trained initial model.
45+
/// The scenario might be one of the online linear learners that can take advantage of this, e.g., averaged perceptron.
46+
/// </summary>
47+
[Fact]
48+
public void New_TrainWithInitialPredictor()
49+
{
50+
var dataPath = GetDataPath(SentimentDataPath);
51+
52+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
53+
{
54+
// Pipeline.
55+
var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
56+
.Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()));
57+
58+
// Train the pipeline, prepare train set.
59+
var reader = pipeline.Fit(new MultiFileSource(dataPath));
60+
var trainData = reader.Read(new MultiFileSource(dataPath));
61+
62+
63+
// Train the first predictor.
64+
var trainer = new MySdca(env, new LinearClassificationTrainer.Arguments
65+
{
66+
NumThreads = 1
67+
}, "Features", "Label");
68+
var firstPredictor = trainer.Fit(trainData);
69+
70+
// Train the second predictor on the same data.
71+
var secondTrainer = new MyAveragedPerceptron(env, new AveragedPerceptronTrainer.Arguments(), "Features", "Label");
72+
var finalPredictor = secondTrainer.Train(trainData, firstPredictor);
73+
}
74+
}
75+
4276
}
4377
}

test/Microsoft.ML.Tests/Scenarios/Api/TrainWithValidationSet.cs

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public void TrainWithValidationSet()
3232
// Apply the same transformations on the validation set.
3333
// Sadly, there is no way to easily apply the same loader to different data, so we either have
3434
// to create another loader, or to save the loader to model file and then reload.
35-
35+
3636
// A new one is not always feasible, but this time it is.
3737
var validLoader = new TextLoader(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(validationDataPath));
3838
var validData = ApplyTransformUtils.ApplyAllTransformsToData(env, trainData, validLoader);
@@ -42,11 +42,41 @@ public void TrainWithValidationSet()
4242
var cachedValid = new CacheDataView(env, validData, prefetch: null);
4343

4444
// Train.
45-
var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments());
45+
var trainer = new FastTreeBinaryClassificationTrainer(env, new FastTreeBinaryClassificationTrainer.Arguments
46+
{
47+
NumTrees = 3
48+
});
4649
var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
47-
var validRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
50+
var validRoles = new RoleMappedData(cachedValid, label: "Label", feature: "Features");
4851
trainer.Train(new Runtime.TrainContext(trainRoles, validRoles));
4952
}
5053
}
54+
55+
/// <summary>
56+
/// Train with validation set: Similar to the simple train scenario, but also support a validation set.
57+
/// The learner might be trees with early stopping.
58+
/// </summary>
59+
[Fact]
60+
public void New_TrainWithValidationSet()
61+
{
62+
var dataPath = GetDataPath(SentimentDataPath);
63+
var validationDataPath = GetDataPath(SentimentTestPath);
64+
65+
using (var env = new TlcEnvironment(seed: 1, conc: 1))
66+
{
67+
// Pipeline.
68+
var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
69+
.Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()));
70+
71+
// Train the pipeline, prepare train and validation set.
72+
var reader = pipeline.Fit(new MultiFileSource(dataPath));
73+
var trainData = reader.Read(new MultiFileSource(dataPath));
74+
var validData = reader.Read(new MultiFileSource(validationDataPath));
75+
76+
// Train model with validation set.
77+
var trainer = new MySdca(env, new Runtime.Learners.LinearClassificationTrainer.Arguments(), "Features", "Label");
78+
var transformer = trainer.Train(trainData, validData);
79+
}
80+
}
5181
}
5282
}

test/Microsoft.ML.Tests/Scenarios/Api/Wrappers.cs

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
namespace Microsoft.ML.Tests.Scenarios.Api
1818
{
19+
using LinearModel = LinearPredictor;
20+
1921
public sealed class LoaderWrapper : IDataReader<IMultiStreamSource>, ICanSaveModel
2022
{
2123
public const string LoaderSignature = "LoaderWrapper";
@@ -161,6 +163,24 @@ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
161163
public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input);
162164
}
163165

166+
public interface IPredictorTransformer<out TModel>: ITransformer
167+
{
168+
TModel TrainedModel { get; }
169+
}
170+
171+
public class ScorerWrapper<TModel>: TransformWrapper, IPredictorTransformer<TModel>
172+
where TModel: IPredictor
173+
{
174+
public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel)
175+
:base(env, scorer)
176+
{
177+
Model = trainedModel;
178+
}
179+
180+
public TModel TrainedModel => Model;
181+
182+
public TModel Model { get; }
183+
}
164184

165185
public class MyTextLoader : IDataReaderEstimator<IMultiStreamSource, LoaderWrapper>
166186
{
@@ -185,7 +205,8 @@ public SchemaShape GetOutputSchema()
185205
}
186206
}
187207

188-
public abstract class TrainerBase : IEstimator<TransformWrapper>
208+
public abstract class TrainerBase<TModel> : IEstimator<ScorerWrapper<TModel>>
209+
where TModel: IPredictor
189210
{
190211
protected readonly IHostEnvironment _env;
191212
private readonly string _featureCol;
@@ -200,25 +221,41 @@ protected TrainerBase(IHostEnvironment env, bool cache, string featureColumn, st
200221
_labelCol = labelColumn;
201222
}
202223

203-
public TransformWrapper Fit(IDataView input)
224+
public ScorerWrapper<TModel> Fit(IDataView input)
204225
{
205-
var cached = _cache ? new CacheDataView(_env, input, prefetch: null) : input;
226+
return TrainTransformer(input);
227+
}
228+
229+
protected ScorerWrapper<TModel> TrainTransformer(IDataView trainSet, IDataView validationSet = null, IPredictor initPredictor = null)
230+
{
231+
var cachedTrain = _cache ? new CacheDataView(_env, trainSet, prefetch: null) : trainSet;
232+
233+
var trainRoles = new RoleMappedData(cachedTrain, label: _labelCol, feature: _featureCol);
234+
RoleMappedData validRoles;
206235

207-
var trainRoles = new RoleMappedData(cached, label: _labelCol, feature: _featureCol);
208-
var pred = Train(trainRoles);
236+
if (validationSet == null)
237+
validRoles = null;
238+
else
239+
{
240+
var cachedValid = _cache ? new CacheDataView(_env, validationSet, prefetch: null) : validationSet;
241+
validRoles = new RoleMappedData(cachedValid, label: _labelCol, feature: _featureCol);
242+
}
209243

210-
var emptyData = new EmptyDataView(_env, input.Schema);
244+
var pred = TrainCore(new TrainContext(trainRoles, validRoles, initPredictor));
245+
246+
var emptyData = new EmptyDataView(_env, trainSet.Schema);
211247
var scoreRoles = new RoleMappedData(emptyData, label: _labelCol, feature: _featureCol);
212248
IDataScorerTransform scorer = ScoreUtils.GetScorer(pred, scoreRoles, _env, trainRoles.Schema);
213-
return new TransformWrapper(_env, scorer);
249+
return new ScorerWrapper<TModel>(_env, scorer, pred);
250+
214251
}
215252

216253
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
217254
{
218255
throw new NotImplementedException();
219256
}
220257

221-
protected abstract IPredictor Train(RoleMappedData data);
258+
protected abstract TModel TrainCore(TrainContext trainContext);
222259
}
223260

224261
public class MyTextTransform : IEstimator<TransformWrapper>
@@ -246,7 +283,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
246283
}
247284
}
248285

249-
public sealed class MySdca : TrainerBase
286+
public sealed class MySdca : TrainerBase<IPredictor>
250287
{
251288
private readonly LinearClassificationTrainer.Arguments _args;
252289

@@ -256,7 +293,27 @@ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args,
256293
_args = args;
257294
}
258295

259-
protected override IPredictor Train(RoleMappedData data) => new LinearClassificationTrainer(_env, _args).Train(data);
296+
protected override IPredictor TrainCore(TrainContext context) => new LinearClassificationTrainer(_env, _args).Train(context);
297+
298+
public ITransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData);
299+
}
300+
301+
public sealed class MyAveragedPerceptron: TrainerBase<IPredictor>
302+
{
303+
private readonly AveragedPerceptronTrainer _trainer;
304+
305+
public MyAveragedPerceptron(IHostEnvironment env, AveragedPerceptronTrainer.Arguments args, string featureCol, string labelCol)
306+
:base(env, false, featureCol, labelCol)
307+
{
308+
_trainer = new AveragedPerceptronTrainer(env, args);
309+
}
310+
311+
protected override IPredictor TrainCore(TrainContext trainContext) => _trainer.Train(trainContext);
312+
313+
public ITransformer Train(IDataView trainData, IPredictorTransformer<IPredictor> initialPredictor)
314+
{
315+
return TrainTransformer(trainData, initPredictor: initialPredictor.TrainedModel);
316+
}
260317
}
261318

262319
public sealed class MyPredictionEngine<TSrc, TDst>

0 commit comments

Comments
 (0)