@@ -78,6 +78,12 @@ public IDataView Transform(IMultiStreamSource input)
78
78
{
79
79
return new TextLoader ( new TlcEnvironment ( ) , _args , input ) ;
80
80
}
81
+
82
+ ISchema ITransformer < IMultiStreamSource > . GetOutputSchema ( )
83
+ {
84
+ var emptyData = new TextLoader ( new TlcEnvironment ( ) , _args , new MultiFileSource ( null ) ) ;
85
+ return emptyData . Schema ;
86
+ }
81
87
}
82
88
83
89
public class TransformerPipe < TIn > : ITransformer < TIn >
@@ -98,6 +104,19 @@ public IDataView Transform(TIn input)
98
104
idv = xf . Transform ( idv ) ;
99
105
return idv ;
100
106
}
107
+
108
+ public ( ITransformer < TIn > , IEnumerable < IDataTransformer > ) GetParts ( )
109
+ {
110
+ return ( _start , _chain ) ;
111
+ }
112
+
113
+ public ISchema GetOutputSchema ( )
114
+ {
115
+ var s = _start . GetOutputSchema ( ) ;
116
+ foreach ( var xf in _chain )
117
+ s = xf . GetOutputSchema ( s ) ;
118
+ return s ;
119
+ }
101
120
}
102
121
103
122
public class EstimatorPipe < TIn > : IEstimator < TIn >
@@ -118,7 +137,7 @@ public EstimatorPipe<TIn> Append(IDataEstimator est)
118
137
return this ;
119
138
}
120
139
121
- public ITransformer < TIn > Fit ( TIn input )
140
+ public TransformerPipe < TIn > Fit ( TIn input )
122
141
{
123
142
var start = _start . Fit ( input ) ;
124
143
@@ -140,7 +159,24 @@ public IEstimator<TIn> GetEstimator()
140
159
141
160
public SchemaShape GetOutputSchema ( )
142
161
{
143
- throw new System . NotImplementedException ( ) ;
162
+ var shape = _start . GetOutputSchema ( ) ;
163
+ foreach ( var xf in _estimatorChain )
164
+ {
165
+ shape = xf . GetOutputSchema ( shape ) ;
166
+ if ( shape == null )
167
+ return null ;
168
+ }
169
+ return shape ;
170
+ }
171
+
172
+ public ( IEstimator < TIn > , IEnumerable < IDataEstimator > ) GetParts ( )
173
+ {
174
+ return ( _start , _estimatorChain ) ;
175
+ }
176
+
177
+ ITransformer < TIn > IEstimator < TIn > . Fit ( TIn input )
178
+ {
179
+ return Fit ( input ) ;
144
180
}
145
181
}
146
182
@@ -300,6 +336,26 @@ public IDataView Transform(IDataView input)
300
336
}
301
337
}
302
338
339
+ public class MyPredictionEngine < TSrc , TDst >
340
+ where TSrc : class
341
+ where TDst : class , new ( )
342
+ {
343
+ private readonly PredictionEngine < TSrc , TDst > _engine ;
344
+
345
+ public MyPredictionEngine ( IHostEnvironment env , ISchema inputSchema , IEnumerable < IDataTransformer > steps )
346
+ {
347
+ IDataView dv = new EmptyDataView ( env , inputSchema ) ;
348
+ foreach ( var s in steps )
349
+ dv = s . Transform ( dv ) ;
350
+ _engine = env . CreatePredictionEngine < TSrc , TDst > ( dv ) ;
351
+ }
352
+
353
+ public TDst Predict ( TSrc example )
354
+ {
355
+ return _engine . Predict ( example ) ;
356
+ }
357
+ }
358
+
303
359
304
360
public class IrisPrediction
305
361
{
@@ -330,6 +386,19 @@ public void TestEstimatorPipe()
330
386
var scoredTrainData = model . Transform ( new MultiFileSource ( @"e:\data\iris.txt" ) )
331
387
. AsEnumerable < IrisPrediction > ( env , reuseRowObject : false )
332
388
. ToArray ( ) ;
389
+
390
+ ITransformer < IMultiStreamSource > loader ;
391
+ IEnumerable < IDataTransformer > steps ;
392
+ ( loader , steps ) = model . GetParts ( ) ;
393
+
394
+ var engine = new MyPredictionEngine < IrisData , IrisPrediction > ( env , loader . GetOutputSchema ( ) , steps ) ;
395
+ IrisPrediction prediction = engine . Predict ( new IrisData ( )
396
+ {
397
+ SepalLength = 5.1f ,
398
+ SepalWidth = 3.3f ,
399
+ PetalLength = 1.6f ,
400
+ PetalWidth = 0.2f ,
401
+ } ) ;
333
402
}
334
403
}
335
404
}
0 commit comments