Skip to content

Commit 660ebeb

Browse files
author
Pete Luferenko
committed
Added transformer scope enum for decomposability
1 parent b82f4c6 commit 660ebeb

File tree

4 files changed

+210
-32
lines changed

4 files changed

+210
-32
lines changed

test/Microsoft.ML.Tests/Scenarios/Api/ApiScenariosTests.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,17 @@ public ApiScenariosTests(ITestOutputHelper output) : base(output)
2121
public const string SentimentDataPath = "wikipedia-detox-250-line-data.tsv";
2222
public const string SentimentTestPath = "wikipedia-detox-250-line-test.tsv";
2323

24-
public class IrisData
24+
public class IrisData: IrisDataNoLabel
25+
{
26+
public string Label;
27+
}
28+
29+
public class IrisDataNoLabel
2530
{
2631
public float SepalLength;
2732
public float SepalWidth;
2833
public float PetalLength;
2934
public float PetalWidth;
30-
public string Label;
3135
}
3236

3337
public class IrisPrediction

test/Microsoft.ML.Tests/Scenarios/Api/DecomposableTrainAndPredict.cs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ void DecomposableTrainAndPredict()
3737

3838
var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
3939
IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);
40-
40+
4141
// Cut off the term transform from the pipeline.
4242
var new_scorer = ApplyTransformUtils.ApplyAllTransformsToData(env, scorer, loader, term);
4343
var keyToValue = new KeyToValueTransform(env, new_scorer, "PredictedLabel");
@@ -52,5 +52,40 @@ void DecomposableTrainAndPredict()
5252
}
5353
}
5454
}
55+
56+
/// <summary>
/// Decomposable train and predict (estimator-chain version): Train on Iris multiclass problem, which will require
/// a transform on labels. Be able to reconstitute the pipeline for a prediction only task,
/// which will essentially "drop" the transform over labels, while retaining the property
/// that the predicted label for this has a key-type, the probability outputs for the classes
/// have the class labels as slot names, etc. This should be doable without ugly compromises like,
/// say, injecting a dummy label.
/// </summary>
/// <remarks>
/// The label term transform is appended with <c>TransformerScope.TrainTest</c>, and the fitted chain is
/// then asked for its <c>TransformerScope.Scoring</c> model. The prediction engine is typed on
/// <c>IrisDataNoLabel</c>, which declares no label field.
/// </remarks>
[Fact]
void New_DecomposableTrainAndPredict()
{
    var dataPath = GetDataPath(IrisDataPath);
    using (var env = new TlcEnvironment())
    {
        var data = new MyTextLoader(env, MakeIrisTextLoaderArgs())
            .FitAndRead(new MultiFileSource(dataPath));

        // Only the term transform over "Label" is restricted to train/test;
        // every other step keeps the default scope.
        var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
            .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest)
            .Append(new MySdcaMulticlass(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"))
            .Append(new MyKeyToValueTransform(env, "PredictedLabel"));

        var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring);
        var engine = new MyPredictionEngine<IrisDataNoLabel, IrisPrediction>(env, model);

        var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
        var testData = testLoader.AsEnumerable<IrisData>(env, false);
        foreach (var input in testData.Take(20))
        {
            var prediction = engine.Predict(input);
            // Assert.Equal reports both values on failure, unlike Assert.True(a == b).
            Assert.Equal(input.Label, prediction.PredictedLabel);
        }
    }
}
5590
}
5691
}

test/Microsoft.ML.Tests/Scenarios/Api/EstimatorPipe.cs

Lines changed: 69 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using Microsoft.ML.Runtime.Internal.Utilities;
55
using Microsoft.ML.Runtime.Model;
66
using Microsoft.ML.Tests.Scenarios.Api;
7+
using System;
78
using System.Collections.Generic;
89
using System.IO;
910
using System.Linq;
@@ -17,10 +18,20 @@ public sealed class TransformerChain<TLastTransformer> : ITransformer, ICanSaveM
1718
where TLastTransformer : class, ITransformer
1819
{
1920
private readonly ITransformer[] _transformers;
21+
private readonly TransformerScope[] _scopes;
2022
public readonly TLastTransformer LastTransformer;
2123

2224
private const string TransformDirTemplate = "Transform_{0:000}";
2325

26+
/// <summary>
/// Creates a chain from already-fitted transformers together with the scope of each one.
/// Both arrays are copied, so later changes to the arguments do not affect the chain.
/// </summary>
/// <param name="transformers">The transformers of the chain, in order. May be empty.</param>
/// <param name="scopes">
/// The scope of each transformer. Must have the same length as <paramref name="transformers"/>.
/// </param>
internal TransformerChain(ITransformer[] transformers, TransformerScope[] scopes)
{
    // Validate before touching the arrays, so bad input surfaces as a contract failure
    // instead of a NullReferenceException or LINQ's "sequence contains no elements".
    Contracts.CheckValue(transformers, nameof(transformers));
    Contracts.CheckValue(scopes, nameof(scopes));
    Contracts.Check(transformers.Length == scopes.Length);

    _transformers = transformers.ToArray();
    _scopes = scopes.ToArray();

    if (_transformers.Length == 0)
    {
        // An empty chain is legitimate here: GetModelFor produces one when
        // no transformer matches the requested scope filter.
        LastTransformer = null;
    }
    else
    {
        LastTransformer = _transformers[_transformers.Length - 1] as TLastTransformer;
        Contracts.Check(LastTransformer != null);
    }
}
34+
2435
public TransformerChain(params ITransformer[] transformers)
2536
{
2637
if (Utils.Size(transformers) == 0)
@@ -31,6 +42,7 @@ public TransformerChain(params ITransformer[] transformers)
3142
else
3243
{
3344
_transformers = transformers.ToArray();
45+
_scopes = transformers.Select(x => TransformerScope.Everything).ToArray();
3446
LastTransformer = transformers.Last() as TLastTransformer;
3547
Contracts.Check(LastTransformer != null);
3648
}
@@ -63,11 +75,26 @@ public IEnumerable<ITransformer> GetParts()
6375
return _transformers;
6476
}
6577

66-
public TransformerChain<TNewLast> Append<TNewLast>(TNewLast transformer)
78+
/// <summary>
/// Builds a new chain holding only those transformers whose scope shares at least one flag
/// with <paramref name="scopeFilter"/>. The relative order of the kept transformers is preserved,
/// and each keeps its original scope.
/// </summary>
/// <param name="scopeFilter">The scope flags to select by.</param>
/// <returns>A new chain made of the matching transformers.</returns>
public TransformerChain<ITransformer> GetModelFor(TransformerScope scopeFilter)
{
    var keptTransformers = new List<ITransformer>();
    var keptScopes = new List<TransformerScope>();

    int index = 0;
    foreach (var transformer in _transformers)
    {
        var scope = _scopes[index++];
        if ((scope & scopeFilter) == TransformerScope.None)
        {
            // No flag in common with the filter: leave this transformer out.
            continue;
        }

        keptTransformers.Add(transformer);
        keptScopes.Add(scope);
    }

    return new TransformerChain<ITransformer>(keptTransformers.ToArray(), keptScopes.ToArray());
}
92+
93+
public TransformerChain<TNewLast> Append<TNewLast>(TNewLast transformer, TransformerScope scope)
6794
where TNewLast : class, ITransformer
6895
{
6996
Contracts.CheckValue(transformer, nameof(transformer));
70-
return new TransformerChain<TNewLast>(_transformers.Append(transformer).ToArray());
97+
return new TransformerChain<TNewLast>(_transformers.Append(transformer).ToArray(), _scopes.Append(scope).ToArray());
7198
}
7299

73100
public void Save(ModelSaveContext ctx)
@@ -79,6 +106,7 @@ public void Save(ModelSaveContext ctx)
79106

80107
for (int i = 0; i < _transformers.Length; i++)
81108
{
109+
ctx.Writer.Write((int)_scopes[i]);
82110
var dirName = string.Format(TransformDirTemplate, i);
83111
ctx.SaveModel(_transformers[i], dirName);
84112
}
@@ -88,8 +116,10 @@ internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx)
88116
{
89117
int len = ctx.Reader.ReadInt32();
90118
_transformers = new ITransformer[len];
119+
_scopes = new TransformerScope[len];
91120
for (int i = 0; i < len; i++)
92121
{
122+
_scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32());
93123
var dirName = string.Format(TransformDirTemplate, i);
94124
ctx.LoadModel<ITransformer, SignatureLoadModel>(env, out _transformers[i], dirName);
95125
}
@@ -146,12 +176,6 @@ public ISchema GetOutputSchema()
146176
return s;
147177
}
148178

149-
public CompositeReader<TSource, TNewLastTransformer> Append<TNewLastTransformer>(TNewLastTransformer transformer)
150-
where TNewLastTransformer : class, ITransformer
151-
{
152-
return new CompositeReader<TSource, TNewLastTransformer>(Reader, Transformer.Append(transformer));
153-
}
154-
155179
public void SavePipeline(IHostEnvironment env, Stream outputStream)
156180
{
157181
using (var ch = env.Start("Saving model"))
@@ -182,26 +206,40 @@ public static CompositeReader<IMultiStreamSource, ITransformer> LoadPipeline(IHo
182206
}
183207
}
184208

209+
/// <summary>
/// Flags describing the scopes a transformer of a chain belongs to.
/// <c>TransformerChain.GetModelFor</c> keeps a transformer when its scope shares
/// at least one flag with the requested filter.
/// </summary>
[Flags]
public enum TransformerScope
{
    /// <summary>No scope; the zero value of the flags.</summary>
    None = 0,

    /// <summary>The training scope flag.</summary>
    Training = 0x1,

    /// <summary>The testing scope flag.</summary>
    Testing = 0x2,

    /// <summary>The scoring scope flag.</summary>
    Scoring = 0x4,

    /// <summary>Combination of <see cref="Training"/> and <see cref="Testing"/>.</summary>
    TrainTest = Training | Testing,

    /// <summary>Combination of all three scope flags.</summary>
    Everything = TrainTest | Scoring
}
219+
185220
public sealed class EstimatorChain<TLastTransformer> : IEstimator<TransformerChain<TLastTransformer>>
186221
where TLastTransformer : class, ITransformer
187222
{
223+
private readonly TransformerScope[] _scopes;
224+
188225
private readonly IEstimator<ITransformer>[] _estimators;
189226
public readonly IEstimator<TLastTransformer> LastEstimator;
190227

191-
public EstimatorChain(params IEstimator<ITransformer>[] estimators)
228+
private EstimatorChain(IEstimator<ITransformer>[] estimators, TransformerScope[] scopes)
192229
{
193-
Contracts.CheckValueOrNull(estimators);
194-
if (Utils.Size(estimators) == 0)
195-
{
196-
_estimators = new IEstimator<ITransformer>[0];
197-
LastEstimator = null;
198-
}
199-
else
200-
{
201-
_estimators = estimators;
202-
LastEstimator = estimators.Last() as IEstimator<TLastTransformer>;
203-
Contracts.Check(LastEstimator != null);
204-
}
230+
_estimators = estimators;
231+
_scopes = scopes;
232+
LastEstimator = estimators.Last() as IEstimator<TLastTransformer>;
233+
234+
Contracts.Check(LastEstimator != null);
235+
Contracts.Check(Utils.Size(estimators) == Utils.Size(scopes));
236+
}
237+
238+
public EstimatorChain()
239+
{
240+
_estimators = new IEstimator<ITransformer>[0];
241+
LastEstimator = null;
242+
_scopes = new TransformerScope[0];
205243
}
206244

207245
public TransformerChain<TLastTransformer> Fit(IDataView input)
@@ -215,7 +253,7 @@ public TransformerChain<TLastTransformer> Fit(IDataView input)
215253
dv = xfs[i].Transform(dv);
216254
}
217255

218-
return new TransformerChain<TLastTransformer>(xfs);
256+
return new TransformerChain<TLastTransformer>(xfs, _scopes);
219257
}
220258

221259
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
@@ -230,11 +268,11 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
230268
return s;
231269
}
232270

233-
public EstimatorChain<TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator)
271+
public EstimatorChain<TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator, TransformerScope scope = TransformerScope.Everything)
234272
where TNewTrans : class, ITransformer
235273
{
236274
Contracts.CheckValue(estimator, nameof(estimator));
237-
return new EstimatorChain<TNewTrans>(_estimators.Append(estimator).ToArray());
275+
return new EstimatorChain<TNewTrans>(_estimators.Append(estimator).ToArray(), _scopes.Append(scope).ToArray());
238276
}
239277
}
240278

@@ -282,16 +320,18 @@ public CompositeReaderEstimator<TSource, TNewTrans> Append<TNewTrans>(IEstimator
282320

283321
public static class LearningPipelineExtensions
284322
{
285-
public static CompositeReaderEstimator<TSource, ITransformer> StartPipe<TSource>(this IDataReaderEstimator<TSource, IDataReader<TSource>> start)
286-
{
287-
return new CompositeReaderEstimator<TSource, ITransformer>(start);
288-
}
289-
290323
public static CompositeReaderEstimator<TSource, TTrans> Append<TSource, TTrans>(
291324
this IDataReaderEstimator<TSource, IDataReader<TSource>> start, IEstimator<TTrans> estimator)
292325
where TTrans : class, ITransformer
293326
{
294327
return new CompositeReaderEstimator<TSource, ITransformer>(start).Append(estimator);
295328
}
329+
330+
/// <summary>
/// Starts a new estimator chain from <paramref name="start"/> and appends
/// <paramref name="estimator"/> to it with the given <paramref name="scope"/>.
/// The starting estimator itself is appended without an explicit scope, so it
/// gets the default scope of <c>EstimatorChain.Append</c>.
/// </summary>
/// <param name="start">The first estimator of the new chain.</param>
/// <param name="estimator">The estimator to append after <paramref name="start"/>.</param>
/// <param name="scope">The scope recorded for <paramref name="estimator"/>.</param>
/// <returns>A two-element estimator chain.</returns>
public static EstimatorChain<TTrans> Append<TTrans>(
    this IEstimator<ITransformer> start, IEstimator<TTrans> estimator, TransformerScope scope = TransformerScope.Everything)
    where TTrans : class, ITransformer
{
    var emptyChain = new EstimatorChain<ITransformer>();
    var chainWithStart = emptyChain.Append(start);
    return chainWithStart.Append(estimator, scope);
}
296336
}
297337
}

test/Microsoft.ML.Tests/Scenarios/Api/Wrappers.cs

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,87 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
297297
}
298298
}
299299

300+
/// <summary>
/// Estimator wrapper around <c>TermTransform</c>: fitting constructs the transform over the
/// input data and returns it packaged as a <c>TransformWrapper</c>.
/// </summary>
public class MyTermTransform : IEstimator<TransformWrapper>
{
    private readonly IHostEnvironment _env;
    private readonly string _column;
    private readonly string _srcColumn;

    /// <summary>
    /// Stores the arguments that <see cref="Fit"/> later forwards to the <c>TermTransform</c> constructor.
    /// </summary>
    public MyTermTransform(IHostEnvironment env, string column, string srcColumn = null)
    {
        _env = env;
        _column = column;
        _srcColumn = srcColumn;
    }

    /// <summary>
    /// Constructs a <c>TermTransform</c> over <paramref name="input"/>, re-applies it on top of an
    /// empty data view that has the input's schema, and wraps the result.
    /// </summary>
    public TransformWrapper Fit(IDataView input)
    {
        var termTransform = new TermTransform(_env, input, _column, _srcColumn);
        // NOTE(review): presumably re-rooting on an empty view detaches the fitted
        // transform from the training data; confirm against ApplyTransformUtils.
        var emptyView = new EmptyDataView(_env, input.Schema);
        return new TransformWrapper(
            _env,
            ApplyTransformUtils.ApplyAllTransformsToData(_env, termTransform, emptyView, input));
    }

    /// <summary>
    /// Not implemented; always throws <see cref="NotImplementedException"/>.
    /// </summary>
    public SchemaShape GetOutputSchema(SchemaShape inputSchema)
    {
        throw new NotImplementedException();
    }
}
326+
327+
/// <summary>
/// Estimator wrapper around <c>ConcatTransform</c>: fitting constructs the transform over the
/// input data and returns it packaged as a <c>TransformWrapper</c>.
/// </summary>
public class MyConcatTransform : IEstimator<TransformWrapper>
{
    private readonly IHostEnvironment _env;
    private readonly string _name;
    private readonly string[] _source;

    /// <summary>
    /// Stores the arguments that <see cref="Fit"/> later forwards to the <c>ConcatTransform</c> constructor.
    /// </summary>
    public MyConcatTransform(IHostEnvironment env, string name, params string[] source)
    {
        _env = env;
        _name = name;
        _source = source;
    }

    /// <summary>
    /// Constructs a <c>ConcatTransform</c> over <paramref name="input"/>, re-applies it on top of an
    /// empty data view that has the input's schema, and wraps the result.
    /// </summary>
    public TransformWrapper Fit(IDataView input)
    {
        var concatTransform = new ConcatTransform(_env, input, _name, _source);
        // NOTE(review): presumably re-rooting on an empty view detaches the
        // transform from the fitting data; confirm against ApplyTransformUtils.
        var emptyView = new EmptyDataView(_env, input.Schema);
        return new TransformWrapper(
            _env,
            ApplyTransformUtils.ApplyAllTransformsToData(_env, concatTransform, emptyView, input));
    }

    /// <summary>
    /// Not implemented; always throws <see cref="NotImplementedException"/>.
    /// </summary>
    public SchemaShape GetOutputSchema(SchemaShape inputSchema)
    {
        throw new NotImplementedException();
    }
}
353+
354+
/// <summary>
/// Estimator wrapper around <c>KeyToValueTransform</c>: fitting constructs the transform over the
/// input data and returns it packaged as a <c>TransformWrapper</c>.
/// </summary>
public class MyKeyToValueTransform : IEstimator<TransformWrapper>
{
    private readonly IHostEnvironment _env;
    private readonly string _name;
    private readonly string _source;

    /// <summary>
    /// Stores the arguments that <see cref="Fit"/> later forwards to the <c>KeyToValueTransform</c> constructor.
    /// </summary>
    public MyKeyToValueTransform(IHostEnvironment env, string name, string source = null)
    {
        _env = env;
        _name = name;
        _source = source;
    }

    /// <summary>
    /// Constructs a <c>KeyToValueTransform</c> over <paramref name="input"/>, re-applies it on top of an
    /// empty data view that has the input's schema, and wraps the result.
    /// </summary>
    public TransformWrapper Fit(IDataView input)
    {
        var keyToValue = new KeyToValueTransform(_env, input, _name, _source);
        // NOTE(review): presumably re-rooting on an empty view detaches the
        // transform from the fitting data; confirm against ApplyTransformUtils.
        var emptyView = new EmptyDataView(_env, input.Schema);
        return new TransformWrapper(
            _env,
            ApplyTransformUtils.ApplyAllTransformsToData(_env, keyToValue, emptyView, input));
    }

    /// <summary>
    /// Not implemented; always throws <see cref="NotImplementedException"/>.
    /// </summary>
    public SchemaShape GetOutputSchema(SchemaShape inputSchema)
    {
        throw new NotImplementedException();
    }
}
380+
300381
public sealed class MySdca : TrainerBase<IPredictor>
301382
{
302383
private readonly LinearClassificationTrainer.Arguments _args;
@@ -312,6 +393,19 @@ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args,
312393
public ITransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData);
313394
}
314395

396+
/// <summary>
/// Trainer wrapper that delegates training to <c>SdcaMultiClassTrainer</c>,
/// using the arguments supplied at construction.
/// </summary>
public sealed class MySdcaMulticlass : TrainerBase<IPredictor>
{
    private readonly SdcaMultiClassTrainer.Arguments _args;

    /// <summary>
    /// Keeps <paramref name="args"/> for training and passes the environment and the
    /// feature/label column names on to the base class.
    /// </summary>
    public MySdcaMulticlass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments args, string featureCol, string labelCol)
        : base(env, true, true, featureCol, labelCol)
    {
        _args = args;
    }

    /// <summary>
    /// Creates a fresh <c>SdcaMultiClassTrainer</c> and trains it on <paramref name="context"/>.
    /// </summary>
    protected override IPredictor TrainCore(TrainContext context)
    {
        var trainer = new SdcaMultiClassTrainer(_env, _args);
        return trainer.Train(context);
    }
}
408+
315409
public sealed class MyAveragedPerceptron : TrainerBase<IPredictor>
316410
{
317411
private readonly AveragedPerceptronTrainer _trainer;
@@ -382,5 +476,10 @@ public static void SaveAsBinary(this IDataView data, IHostEnvironment env, Strea
382476
using (var ch = env.Start("SaveData"))
383477
DataSaverUtils.SaveDataView(ch, saver, data, stream);
384478
}
479+
480+
/// <summary>
/// Fits <paramref name="est"/> on <paramref name="data"/> and applies the resulting
/// transformer to that same data.
/// </summary>
public static IDataView FitAndTransform(this IEstimator<ITransformer> est, IDataView data)
{
    var fitted = est.Fit(data);
    return fitted.Transform(data);
}
481+
482+
/// <summary>
/// Fits the reader estimator <paramref name="est"/> on <paramref name="source"/> and reads
/// that same source with the resulting reader.
/// </summary>
public static IDataView FitAndRead<TSource>(this IDataReaderEstimator<TSource, IDataReader<TSource>> est, TSource source)
{
    var reader = est.Fit(source);
    return reader.Read(source);
}
385484
}
386485
}

0 commit comments

Comments
 (0)