Skip to content

Commit 3826648

Browse files
author
Pete Luferenko
committed
Added an ad hoc test playground
1 parent 3b2edab commit 3826648

File tree

2 files changed

+340
-2
lines changed

2 files changed

+340
-2
lines changed

src/Microsoft.ML.Core/Data/IEstimator.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,12 @@ public interface IDataTransformer : ITransformer<IDataView>
149149
ISchema GetOutputSchema(ISchema inputSchema);
150150
}
151151

152-
public interface IDataEstimator : IEstimator<IDataView>
152+
public interface IDataEstimator
153153
{
154-
new IDataTransformer Fit(IDataView input);
154+
/// <summary>
155+
/// Train and return a transformer.
156+
/// </summary>
157+
IDataTransformer Fit(IDataView input);
155158

156159
/// <summary>
157160
/// Schema propagation for estimators.
Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,335 @@
1+
using Microsoft.ML.Core.Data;
2+
using Microsoft.ML.Runtime;
3+
using Microsoft.ML.Runtime.Api;
4+
using Microsoft.ML.Runtime.Data;
5+
using Microsoft.ML.Runtime.Data.IO;
6+
using Microsoft.ML.Runtime.Learners;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
using Xunit;
10+
11+
namespace Microsoft.ML.Core.Tests.UnitTests
12+
{
13+
public class AdHocTest
14+
{
15+
private static TextLoader.Arguments MakeTextLoaderArgs()
16+
{
17+
return new TextLoader.Arguments()
18+
{
19+
HasHeader = false,
20+
Column = new[] {
21+
new TextLoader.Column()
22+
{
23+
Name = "Label",
24+
Source = new [] { new TextLoader.Range() { Min = 0, Max = 0} },
25+
Type = DataKind.R4
26+
},
27+
new TextLoader.Column()
28+
{
29+
Name = "SepalLength",
30+
Source = new [] { new TextLoader.Range() { Min = 1, Max = 1} },
31+
Type = DataKind.R4
32+
},
33+
new TextLoader.Column()
34+
{
35+
Name = "SepalWidth",
36+
Source = new [] { new TextLoader.Range() { Min = 2, Max = 2} },
37+
Type = DataKind.R4
38+
},
39+
new TextLoader.Column()
40+
{
41+
Name = "PetalLength",
42+
Source = new [] { new TextLoader.Range() { Min = 3, Max = 3} },
43+
Type = DataKind.R4
44+
},
45+
new TextLoader.Column()
46+
{
47+
Name = "PetalWidth",
48+
Source = new [] { new TextLoader.Range() { Min = 4, Max = 4} },
49+
Type = DataKind.R4
50+
}
51+
}
52+
};
53+
}
54+
55+
public class MyTextLoader : IEstimator<IMultiStreamSource>, ITransformer<IMultiStreamSource>
56+
{
57+
private readonly TextLoader.Arguments _args;
58+
private readonly IHostEnvironment _env;
59+
60+
public MyTextLoader(IHostEnvironment env, TextLoader.Arguments args)
61+
{
62+
_env = env;
63+
_args = args;
64+
}
65+
66+
public ITransformer<IMultiStreamSource> Fit(IMultiStreamSource input)
67+
{
68+
return this;
69+
}
70+
71+
public SchemaShape GetOutputSchema()
72+
{
73+
var emptyData = new TextLoader(new TlcEnvironment(), _args, new MultiFileSource(null));
74+
return SchemaShape.Create(emptyData.Schema);
75+
}
76+
77+
public IDataView Transform(IMultiStreamSource input)
78+
{
79+
return new TextLoader(new TlcEnvironment(), _args, input);
80+
}
81+
}
82+
83+
public class TransformerPipe<TIn> : ITransformer<TIn>
84+
{
85+
private readonly ITransformer<TIn> _start;
86+
private readonly IDataTransformer[] _chain;
87+
88+
public TransformerPipe(ITransformer<TIn> start, IDataTransformer[] chain)
89+
{
90+
_start = start;
91+
_chain = chain;
92+
}
93+
94+
public IDataView Transform(TIn input)
95+
{
96+
var idv = _start.Transform(input);
97+
foreach (var xf in _chain)
98+
idv = xf.Transform(idv);
99+
return idv;
100+
}
101+
}
102+
103+
public class EstimatorPipe<TIn> : IEstimator<TIn>
104+
{
105+
private readonly IEstimator<TIn> _start;
106+
private readonly List<IDataEstimator> _estimatorChain = new List<IDataEstimator>();
107+
private readonly IHostEnvironment _env = new TlcEnvironment();
108+
109+
110+
public EstimatorPipe(IEstimator<TIn> start)
111+
{
112+
_start = start;
113+
}
114+
115+
public EstimatorPipe<TIn> Append(IDataEstimator est)
116+
{
117+
_estimatorChain.Add(est);
118+
return this;
119+
}
120+
121+
public ITransformer<TIn> Fit(TIn input)
122+
{
123+
var start = _start.Fit(input);
124+
125+
var idv = start.Transform(input);
126+
var xfs = new List<IDataTransformer>();
127+
foreach (var est in _estimatorChain)
128+
{
129+
var xf = est.Fit(idv);
130+
xfs.Add(xf);
131+
idv = xf.Transform(idv);
132+
}
133+
return new TransformerPipe<TIn>(start, xfs.ToArray());
134+
}
135+
136+
public IEstimator<TIn> GetEstimator()
137+
{
138+
return this;
139+
}
140+
141+
public SchemaShape GetOutputSchema()
142+
{
143+
throw new System.NotImplementedException();
144+
}
145+
}
146+
147+
public class MyConcatTransformer : IDataEstimator, IDataTransformer
148+
{
149+
private readonly ConcatTransform _xf;
150+
private readonly IHostEnvironment _env;
151+
private readonly string _name;
152+
private readonly string[] _source;
153+
154+
public MyConcatTransformer(IHostEnvironment env, string name, params string[] source)
155+
{
156+
_env = env;
157+
_name = name;
158+
_source = source;
159+
}
160+
161+
private MyConcatTransformer(IHostEnvironment env, ConcatTransform xf)
162+
{
163+
_env = env;
164+
_xf = xf;
165+
}
166+
167+
public IDataTransformer Fit(IDataView input)
168+
{
169+
var xf = new ConcatTransform(_env, input, _name, _source);
170+
return new MyConcatTransformer(_env, xf);
171+
}
172+
173+
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
174+
{
175+
var cols = inputSchema.Columns.ToList();
176+
177+
var selectedCols = cols.Where(x => _source.Contains(x.Name)).Cast<SchemaShape.RelaxedColumn>();
178+
var isFixed = selectedCols.All(x => x.Kind != SchemaShape.RelaxedColumn.VectorKind.VariableVector);
179+
var newCol = new SchemaShape.RelaxedColumn(_name,
180+
isFixed ? SchemaShape.RelaxedColumn.VectorKind.Vector : SchemaShape.RelaxedColumn.VectorKind.VariableVector,
181+
selectedCols.First().ItemKind, selectedCols.First().IsKey);
182+
183+
cols.Add(newCol);
184+
return new SchemaShape(cols.ToArray());
185+
}
186+
187+
public ISchema GetOutputSchema(ISchema inputSchema)
188+
{
189+
var dv = new EmptyDataView(_env, inputSchema);
190+
var output = ApplyTransformUtils.ApplyTransformToData(_env, _xf, dv);
191+
return output.Schema;
192+
}
193+
194+
public IDataView Transform(IDataView input)
195+
{
196+
return ApplyTransformUtils.ApplyTransformToData(_env, _xf, input);
197+
}
198+
}
199+
200+
public class MyNormalizer : IDataEstimator
201+
{
202+
private readonly IHostEnvironment _env;
203+
private readonly string _col;
204+
205+
public MyNormalizer(IHostEnvironment env, string col)
206+
{
207+
_env = env;
208+
_col = col;
209+
}
210+
211+
public IDataTransformer Fit(IDataView input)
212+
{
213+
return new Transformer(_env, input, _col);
214+
}
215+
216+
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
217+
{
218+
return inputSchema;
219+
}
220+
221+
private class Transformer : IDataTransformer
222+
{
223+
private IHostEnvironment _env;
224+
private IDataTransform _xf;
225+
226+
public Transformer(IHostEnvironment env, IDataView input, string col)
227+
{
228+
_env = env;
229+
_xf = NormalizeTransform.CreateMinMaxNormalizer(env, input, col);
230+
}
231+
232+
public ISchema GetOutputSchema(ISchema inputSchema)
233+
{
234+
var dv = new EmptyDataView(_env, inputSchema);
235+
var output = ApplyTransformUtils.ApplyTransformToData(_env, _xf, dv);
236+
return output.Schema;
237+
}
238+
239+
public IDataView Transform(IDataView input)
240+
{
241+
return ApplyTransformUtils.ApplyTransformToData(_env, _xf, input);
242+
}
243+
}
244+
}
245+
246+
public class MySdca : IDataEstimator
247+
{
248+
249+
private readonly IHostEnvironment _env;
250+
251+
public MySdca(IHostEnvironment env)
252+
{
253+
_env = env;
254+
}
255+
256+
public IDataTransformer Fit(IDataView input)
257+
{
258+
// Train
259+
var trainer = new SdcaMultiClassTrainer(_env, new SdcaMultiClassTrainer.Arguments() { NumThreads = 1 });
260+
261+
// Explicity adding CacheDataView since caching is not working though trainer has 'Caching' On/Auto
262+
var cached = new CacheDataView(_env, input, prefetch: null);
263+
var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
264+
var pred = trainer.Train(trainRoles);
265+
266+
var scoreRoles = new RoleMappedData(input, label: "Label", feature: "Features");
267+
IDataScorerTransform scorer = ScoreUtils.GetScorer(pred, scoreRoles, _env, trainRoles.Schema);
268+
return new Transformer(_env, pred, scorer);
269+
}
270+
271+
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
272+
{
273+
throw new System.NotImplementedException();
274+
}
275+
276+
private sealed class Transformer : IDataTransformer
277+
{
278+
private IHostEnvironment _env;
279+
private IPredictor _pred;
280+
private IDataScorerTransform _xf;
281+
282+
public Transformer(IHostEnvironment env, IPredictorProducing<VBuffer<float>> pred, IDataScorerTransform scorer)
283+
{
284+
_env = env;
285+
_pred = pred;
286+
_xf = scorer;
287+
}
288+
289+
public ISchema GetOutputSchema(ISchema inputSchema)
290+
{
291+
var dv = new EmptyDataView(_env, inputSchema);
292+
var output = ApplyTransformUtils.ApplyTransformToData(_env, _xf, dv);
293+
return output.Schema;
294+
}
295+
296+
public IDataView Transform(IDataView input)
297+
{
298+
return ApplyTransformUtils.ApplyTransformToData(_env, _xf, input);
299+
}
300+
}
301+
}
302+
303+
304+
public class IrisPrediction
305+
{
306+
[ColumnName("Score")]
307+
public float[] PredictedLabels;
308+
}
309+
310+
public class IrisData
311+
{
312+
public float SepalLength;
313+
public float SepalWidth;
314+
public float PetalLength;
315+
public float PetalWidth;
316+
}
317+
318+
[Fact]
319+
public void TestEstimatorPipe()
320+
{
321+
var env = new TlcEnvironment();
322+
323+
var pipeline = new EstimatorPipe<IMultiStreamSource>(new MyTextLoader(env, MakeTextLoaderArgs()));
324+
pipeline.Append(new MyConcatTransformer(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"))
325+
.Append(new MyNormalizer(env, "Features"))
326+
.Append(new MySdca(env));
327+
328+
var model = pipeline.Fit(new MultiFileSource(@"e:\data\iris.txt"));
329+
330+
var scoredTrainData = model.Transform(new MultiFileSource(@"e:\data\iris.txt"))
331+
.AsEnumerable<IrisPrediction>(env, reuseRowObject: false)
332+
.ToArray();
333+
}
334+
}
335+
}

0 commit comments

Comments
 (0)