4
4
using Microsoft . ML . Runtime . Internal . Utilities ;
5
5
using Microsoft . ML . Runtime . Model ;
6
6
using Microsoft . ML . Tests . Scenarios . Api ;
7
+ using System ;
7
8
using System . Collections . Generic ;
8
9
using System . IO ;
9
10
using System . Linq ;
@@ -17,10 +18,20 @@ public sealed class TransformerChain<TLastTransformer> : ITransformer, ICanSaveM
17
18
where TLastTransformer : class , ITransformer
18
19
{
19
20
private readonly ITransformer [ ] _transformers ;
21
+ private readonly TransformerScope [ ] _scopes ;
20
22
public readonly TLastTransformer LastTransformer ;
21
23
22
24
private const string TransformDirTemplate = "Transform_{0:000}" ;
23
25
26
+ internal TransformerChain ( ITransformer [ ] transformers , TransformerScope [ ] scopes )
27
+ {
28
+ _transformers = transformers . ToArray ( ) ;
29
+ _scopes = scopes . ToArray ( ) ;
30
+ LastTransformer = transformers . Last ( ) as TLastTransformer ;
31
+ Contracts . Check ( LastTransformer != null ) ;
32
+ Contracts . Check ( transformers . Length == scopes . Length ) ;
33
+ }
34
+
24
35
public TransformerChain ( params ITransformer [ ] transformers )
25
36
{
26
37
if ( Utils . Size ( transformers ) == 0 )
@@ -31,6 +42,7 @@ public TransformerChain(params ITransformer[] transformers)
31
42
else
32
43
{
33
44
_transformers = transformers . ToArray ( ) ;
45
+ _scopes = transformers . Select ( x => TransformerScope . Everything ) . ToArray ( ) ;
34
46
LastTransformer = transformers . Last ( ) as TLastTransformer ;
35
47
Contracts . Check ( LastTransformer != null ) ;
36
48
}
@@ -63,11 +75,26 @@ public IEnumerable<ITransformer> GetParts()
63
75
return _transformers ;
64
76
}
65
77
66
- public TransformerChain < TNewLast > Append < TNewLast > ( TNewLast transformer )
78
+ public TransformerChain < ITransformer > GetModelFor ( TransformerScope scopeFilter )
79
+ {
80
+ var xfs = new List < ITransformer > ( ) ;
81
+ var scopes = new List < TransformerScope > ( ) ;
82
+ for ( int i = 0 ; i < _transformers . Length ; i ++ )
83
+ {
84
+ if ( ( _scopes [ i ] & scopeFilter ) != TransformerScope . None )
85
+ {
86
+ xfs . Add ( _transformers [ i ] ) ;
87
+ scopes . Add ( _scopes [ i ] ) ;
88
+ }
89
+ }
90
+ return new TransformerChain < ITransformer > ( xfs . ToArray ( ) , scopes . ToArray ( ) ) ;
91
+ }
92
+
93
+ public TransformerChain < TNewLast > Append < TNewLast > ( TNewLast transformer , TransformerScope scope )
67
94
where TNewLast : class , ITransformer
68
95
{
69
96
Contracts . CheckValue ( transformer , nameof ( transformer ) ) ;
70
- return new TransformerChain < TNewLast > ( _transformers . Append ( transformer ) . ToArray ( ) ) ;
97
+ return new TransformerChain < TNewLast > ( _transformers . Append ( transformer ) . ToArray ( ) , _scopes . Append ( scope ) . ToArray ( ) ) ;
71
98
}
72
99
73
100
public void Save ( ModelSaveContext ctx )
@@ -79,6 +106,7 @@ public void Save(ModelSaveContext ctx)
79
106
80
107
for ( int i = 0 ; i < _transformers . Length ; i ++ )
81
108
{
109
+ ctx . Writer . Write ( ( int ) _scopes [ i ] ) ;
82
110
var dirName = string . Format ( TransformDirTemplate , i ) ;
83
111
ctx . SaveModel ( _transformers [ i ] , dirName ) ;
84
112
}
@@ -88,8 +116,10 @@ internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx)
88
116
{
89
117
int len = ctx . Reader . ReadInt32 ( ) ;
90
118
_transformers = new ITransformer [ len ] ;
119
+ _scopes = new TransformerScope [ len ] ;
91
120
for ( int i = 0 ; i < len ; i ++ )
92
121
{
122
+ _scopes [ i ] = ( TransformerScope ) ( ctx . Reader . ReadInt32 ( ) ) ;
93
123
var dirName = string . Format ( TransformDirTemplate , i ) ;
94
124
ctx . LoadModel < ITransformer , SignatureLoadModel > ( env , out _transformers [ i ] , dirName ) ;
95
125
}
@@ -146,12 +176,6 @@ public ISchema GetOutputSchema()
146
176
return s ;
147
177
}
148
178
149
- public CompositeReader < TSource , TNewLastTransformer > Append < TNewLastTransformer > ( TNewLastTransformer transformer )
150
- where TNewLastTransformer : class , ITransformer
151
- {
152
- return new CompositeReader < TSource , TNewLastTransformer > ( Reader , Transformer . Append ( transformer ) ) ;
153
- }
154
-
155
179
public void SavePipeline ( IHostEnvironment env , Stream outputStream )
156
180
{
157
181
using ( var ch = env . Start ( "Saving model" ) )
@@ -182,26 +206,40 @@ public static CompositeReader<IMultiStreamSource, ITransformer> LoadPipeline(IHo
182
206
}
183
207
}
184
208
209
+ [ Flags ]
210
+ public enum TransformerScope
211
+ {
212
+ None = 0 ,
213
+ Training = 1 << 0 ,
214
+ Testing = 1 << 1 ,
215
+ Scoring = 1 << 2 ,
216
+ TrainTest = Training | Testing ,
217
+ Everything = Training | Testing | Scoring
218
+ }
219
+
185
220
public sealed class EstimatorChain < TLastTransformer > : IEstimator < TransformerChain < TLastTransformer > >
186
221
where TLastTransformer : class , ITransformer
187
222
{
223
+ private readonly TransformerScope [ ] _scopes ;
224
+
188
225
private readonly IEstimator < ITransformer > [ ] _estimators ;
189
226
public readonly IEstimator < TLastTransformer > LastEstimator ;
190
227
191
- public EstimatorChain ( params IEstimator < ITransformer > [ ] estimators )
228
+ private EstimatorChain ( IEstimator < ITransformer > [ ] estimators , TransformerScope [ ] scopes )
192
229
{
193
- Contracts . CheckValueOrNull ( estimators ) ;
194
- if ( Utils . Size ( estimators ) == 0 )
195
- {
196
- _estimators = new IEstimator < ITransformer > [ 0 ] ;
197
- LastEstimator = null ;
198
- }
199
- else
200
- {
201
- _estimators = estimators ;
202
- LastEstimator = estimators . Last ( ) as IEstimator < TLastTransformer > ;
203
- Contracts . Check ( LastEstimator != null ) ;
204
- }
230
+ _estimators = estimators ;
231
+ _scopes = scopes ;
232
+ LastEstimator = estimators . Last ( ) as IEstimator < TLastTransformer > ;
233
+
234
+ Contracts . Check ( LastEstimator != null ) ;
235
+ Contracts . Check ( Utils . Size ( estimators ) == Utils . Size ( scopes ) ) ;
236
+ }
237
+
238
+ public EstimatorChain ( )
239
+ {
240
+ _estimators = new IEstimator < ITransformer > [ 0 ] ;
241
+ LastEstimator = null ;
242
+ _scopes = new TransformerScope [ 0 ] ;
205
243
}
206
244
207
245
public TransformerChain < TLastTransformer > Fit ( IDataView input )
@@ -215,7 +253,7 @@ public TransformerChain<TLastTransformer> Fit(IDataView input)
215
253
dv = xfs [ i ] . Transform ( dv ) ;
216
254
}
217
255
218
- return new TransformerChain < TLastTransformer > ( xfs ) ;
256
+ return new TransformerChain < TLastTransformer > ( xfs , _scopes ) ;
219
257
}
220
258
221
259
public SchemaShape GetOutputSchema ( SchemaShape inputSchema )
@@ -230,11 +268,11 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
230
268
return s ;
231
269
}
232
270
233
- public EstimatorChain < TNewTrans > Append < TNewTrans > ( IEstimator < TNewTrans > estimator )
271
+ public EstimatorChain < TNewTrans > Append < TNewTrans > ( IEstimator < TNewTrans > estimator , TransformerScope scope = TransformerScope . Everything )
234
272
where TNewTrans : class , ITransformer
235
273
{
236
274
Contracts . CheckValue ( estimator , nameof ( estimator ) ) ;
237
- return new EstimatorChain < TNewTrans > ( _estimators . Append ( estimator ) . ToArray ( ) ) ;
275
+ return new EstimatorChain < TNewTrans > ( _estimators . Append ( estimator ) . ToArray ( ) , _scopes . Append ( scope ) . ToArray ( ) ) ;
238
276
}
239
277
}
240
278
@@ -282,16 +320,18 @@ public CompositeReaderEstimator<TSource, TNewTrans> Append<TNewTrans>(IEstimator
282
320
283
321
public static class LearningPipelineExtensions
284
322
{
285
- public static CompositeReaderEstimator < TSource , ITransformer > StartPipe < TSource > ( this IDataReaderEstimator < TSource , IDataReader < TSource > > start )
286
- {
287
- return new CompositeReaderEstimator < TSource , ITransformer > ( start ) ;
288
- }
289
-
290
323
public static CompositeReaderEstimator < TSource , TTrans > Append < TSource , TTrans > (
291
324
this IDataReaderEstimator < TSource , IDataReader < TSource > > start , IEstimator < TTrans > estimator )
292
325
where TTrans : class , ITransformer
293
326
{
294
327
return new CompositeReaderEstimator < TSource , ITransformer > ( start ) . Append ( estimator ) ;
295
328
}
329
+
330
+ public static EstimatorChain < TTrans > Append < TTrans > (
331
+ this IEstimator < ITransformer > start , IEstimator < TTrans > estimator , TransformerScope scope = TransformerScope . Everything )
332
+ where TTrans : class , ITransformer
333
+ {
334
+ return new EstimatorChain < ITransformer > ( ) . Append ( start ) . Append ( estimator , scope ) ;
335
+ }
296
336
}
297
337
}
0 commit comments