Skip to content

Commit 91dc0f2

Browse files
author
Pete Luferenko
committed
Added prediction engine to playground
1 parent 3826648 commit 91dc0f2

File tree

2 files changed

+83
-3
lines changed

2 files changed

+83
-3
lines changed

src/Microsoft.ML.Core/Data/IEstimator.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ public interface ITransformer<TIn>
111111
/// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
112112
/// </summary>
113113
IDataView Transform(TIn input);
114+
115+
/// <summary>
116+
/// The output schema of the transformer.
117+
/// </summary>
118+
ISchema GetOutputSchema();
114119
}
115120

116121
/// <summary>
@@ -139,14 +144,20 @@ public interface IEstimator<TIn>
139144
/// The data transformer, in addition to being a transformer, also exposes the input schema shape. It is handy for
140145
/// evaluating what kind of columns the transformer expects.
141146
/// </summary>
142-
public interface IDataTransformer : ITransformer<IDataView>
147+
public interface IDataTransformer
143148
{
144149
/// <summary>
145150
/// Schema propagation for transformers.
146151
/// Returns the output schema of the data, if the input schema is like the one provided.
147152
/// Returns <c>null</c> iff the schema is invalid (then a call to Transform with this data will fail).
148153
/// </summary>
149154
ISchema GetOutputSchema(ISchema inputSchema);
155+
156+
/// <summary>
157+
/// Take the data in, make transformations, output the data.
158+
/// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
159+
/// </summary>
160+
IDataView Transform(IDataView input);
150161
}
151162

152163
public interface IDataEstimator

test/Microsoft.ML.Core.Tests/UnitTests/AdHocTest.cs

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ public IDataView Transform(IMultiStreamSource input)
7878
{
7979
return new TextLoader(new TlcEnvironment(), _args, input);
8080
}
81+
82+
/// <summary>
/// Reports the schema this loader produces.
/// </summary>
/// <returns>
/// The <c>Schema</c> of a <see cref="TextLoader"/> built from the stored loader arguments.
/// </returns>
ISchema ITransformer<IMultiStreamSource>.GetOutputSchema()
{
    // Only the schema is wanted, so the loader is built over a MultiFileSource created from null
    // rather than over a real input file.
    var environment = new TlcEnvironment();
    var placeholderSource = new MultiFileSource(null);
    var schemaProbe = new TextLoader(environment, _args, placeholderSource);
    return schemaProbe.Schema;
}
8187
}
8288

8389
public class TransformerPipe<TIn> : ITransformer<TIn>
@@ -98,6 +104,19 @@ public IDataView Transform(TIn input)
98104
idv = xf.Transform(idv);
99105
return idv;
100106
}
107+
108+
/// <summary>
/// Exposes the two pieces this pipe is built from: the starting transformer and the
/// chain of data transformers that follows it.
/// </summary>
/// <returns>A tuple of the starting transformer and the transformer chain.</returns>
public (ITransformer<TIn>, IEnumerable<IDataTransformer>) GetParts() => (_start, _chain);
112+
113+
/// <summary>
/// The output schema of the whole pipe: the starting transformer's output schema,
/// propagated through every transformer in the chain in order.
/// </summary>
/// <returns>
/// The schema after the last step, or <c>null</c> as soon as any step yields <c>null</c>
/// (which <see cref="IDataTransformer"/> uses to signal an invalid input schema).
/// </returns>
public ISchema GetOutputSchema()
{
    var schema = _start.GetOutputSchema();
    foreach (var xf in _chain)
    {
        // A null schema means the previous step rejected its input; stop here instead of
        // handing null to the next transformer.
        if (schema == null)
            return null;
        schema = xf.GetOutputSchema(schema);
    }
    return schema;
}
101120
}
102121

103122
public class EstimatorPipe<TIn> : IEstimator<TIn>
@@ -118,7 +137,7 @@ public EstimatorPipe<TIn> Append(IDataEstimator est)
118137
return this;
119138
}
120139

121-
public ITransformer<TIn> Fit(TIn input)
140+
public TransformerPipe<TIn> Fit(TIn input)
122141
{
123142
var start = _start.Fit(input);
124143

@@ -140,7 +159,24 @@ public IEstimator<TIn> GetEstimator()
140159

141160
public SchemaShape GetOutputSchema()
142161
{
143-
throw new System.NotImplementedException();
162+
var shape = _start.GetOutputSchema();
163+
foreach (var xf in _estimatorChain)
164+
{
165+
shape = xf.GetOutputSchema(shape);
166+
if (shape == null)
167+
return null;
168+
}
169+
return shape;
170+
}
171+
172+
/// <summary>
/// Exposes the two pieces this estimator pipe is built from: the starting estimator and the
/// chain of data estimators that follows it.
/// </summary>
/// <returns>A tuple of the starting estimator and the estimator chain.</returns>
public (IEstimator<TIn>, IEnumerable<IDataEstimator>) GetParts() => (_start, _estimatorChain);
176+
177+
/// <summary>
/// Explicit <see cref="IEstimator{TIn}"/> implementation of Fit; forwards to this class's
/// own Fit method and returns its result through the interface's return type.
/// </summary>
ITransformer<TIn> IEstimator<TIn>.Fit(TIn input) => Fit(input);
145181
}
146182

@@ -300,6 +336,26 @@ public IDataView Transform(IDataView input)
300336
}
301337
}
302338

339+
/// <summary>
/// Wraps a <see cref="PredictionEngine{TSrc, TDst}"/> that is built by applying a sequence of
/// data transformers to an empty data view of a given input schema.
/// </summary>
/// <typeparam name="TSrc">The type of the input example.</typeparam>
/// <typeparam name="TDst">The type of the prediction result.</typeparam>
public class MyPredictionEngine<TSrc, TDst>
    where TSrc : class
    where TDst : class, new()
{
    private readonly PredictionEngine<TSrc, TDst> _engine;

    /// <summary>
    /// Builds the engine by chaining <paramref name="steps"/> over an empty data view.
    /// </summary>
    /// <param name="env">The host environment used to create the empty view and the engine.</param>
    /// <param name="inputSchema">The schema of the empty data view the steps are applied to.</param>
    /// <param name="steps">The transformers to apply, in enumeration order.</param>
    /// <exception cref="System.ArgumentNullException">
    /// <paramref name="env"/> or <paramref name="steps"/> is <c>null</c>.
    /// </exception>
    public MyPredictionEngine(IHostEnvironment env, ISchema inputSchema, IEnumerable<IDataTransformer> steps)
    {
        // Fail fast with a named argument instead of an opaque failure further down.
        if (env == null)
            throw new System.ArgumentNullException(nameof(env));
        if (steps == null)
            throw new System.ArgumentNullException(nameof(steps));

        // Start from an empty view of the input schema and thread it through every step.
        IDataView dv = new EmptyDataView(env, inputSchema);
        foreach (var s in steps)
            dv = s.Transform(dv);
        _engine = env.CreatePredictionEngine<TSrc, TDst>(dv);
    }

    /// <summary>
    /// Runs a single example through the wrapped prediction engine.
    /// </summary>
    /// <param name="example">The input example.</param>
    /// <returns>The result of the wrapped engine's <c>Predict</c> call.</returns>
    public TDst Predict(TSrc example)
    {
        return _engine.Predict(example);
    }
}
358+
303359

304360
public class IrisPrediction
305361
{
@@ -330,6 +386,19 @@ public void TestEstimatorPipe()
330386
var scoredTrainData = model.Transform(new MultiFileSource(@"e:\data\iris.txt"))
331387
.AsEnumerable<IrisPrediction>(env, reuseRowObject: false)
332388
.ToArray();
389+
390+
ITransformer<IMultiStreamSource> loader;
391+
IEnumerable<IDataTransformer> steps;
392+
(loader, steps) = model.GetParts();
393+
394+
var engine = new MyPredictionEngine<IrisData, IrisPrediction>(env, loader.GetOutputSchema(), steps);
395+
IrisPrediction prediction = engine.Predict(new IrisData()
396+
{
397+
SepalLength = 5.1f,
398+
SepalWidth = 3.3f,
399+
PetalLength = 1.6f,
400+
PetalWidth = 0.2f,
401+
});
333402
}
334403
}
335404
}

0 commit comments

Comments
 (0)