21
21
[ assembly: LoadableClass ( CopyColumnsTransform . Summary , typeof ( CopyColumnsTransform ) , null , typeof ( SignatureLoadDataTransform ) ,
22
22
CopyColumnsTransform . UserName , CopyColumnsTransform . LoaderSignature ) ]
23
23
24
+ [ assembly: LoadableClass ( CopyColumnsTransform . Summary , typeof ( CopyColumnsTransformer ) , null , typeof ( SignatureLoadModel ) ,
25
+ CopyColumnsTransform . UserName , CopyColumnsTransformer . LoaderSignature ) ]
26
+
24
27
namespace Microsoft . ML . Runtime . Data
25
28
{
26
29
public sealed class CopyColumnsTransform : OneToOneTransformBase
@@ -169,35 +172,36 @@ protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, ou
169
172
170
173
public sealed class CopyColumnsEstimator : IEstimator < CopyColumnsTransformer >
171
174
{
172
- private readonly ( string source , string name ) [ ] _columnsMapping ;
175
+ private readonly ( string Source , string Name ) [ ] _columns ;
173
176
private readonly IHost _host ;
177
+
174
178
/// <summary>
/// Creates an estimator that copies a single column.
/// </summary>
/// <param name="env">Host environment used to register this estimator's host; must not be null.</param>
/// <param name="input">Name of the source column; must not be null or whitespace.</param>
/// <param name="output">Name of the new column; must not be null or whitespace.</param>
public CopyColumnsEstimator(IHostEnvironment env, string input, string output)
{
    // Validate env explicitly so a null environment is reported as a contract
    // violation rather than a NullReferenceException at env.Register below.
    Contracts.CheckValue(env, nameof(env));
    Contracts.CheckNonWhiteSpace(input, nameof(input));
    Contracts.CheckNonWhiteSpace(output, nameof(output));

    // Tuple order is (Source, Name): the input column is the source, the output column is the new name.
    _columns = new (string, string)[1] { (input, output) };
    _host = env.Register("CopyColumnsEstimator");
}
181
185
182
/// <summary>
/// Creates an estimator that copies several columns at once.
/// </summary>
/// <param name="env">Host environment used to register this estimator's host; must not be null.</param>
/// <param name="columns">
/// Pairs of (source column name, new column name); must not be null.
/// Each new column name may appear only once.
/// </param>
public CopyColumnsEstimator(IHostEnvironment env, (string Source, string Name)[] columns)
{
    // Validate env explicitly so a null environment is reported as a contract
    // violation rather than a NullReferenceException at env.Register below.
    Contracts.CheckValue(env, nameof(env));
    Contracts.CheckValue(columns, nameof(columns));

    var newNames = new HashSet<string>();
    foreach (var column in columns)
    {
        // HashSet<T>.Add returns false when the element is already present,
        // so one call both records the name and detects a duplicate.
        if (!newNames.Add(column.Name))
            throw Contracts.ExceptUserArg(nameof(columns), $"New column {column.Name} specified multiple times");
    }

    _columns = columns;
    _host = env.Register("CopyColumnsEstimator");
}
195
199
196
200
/// <summary>
/// Builds a <see cref="CopyColumnsTransformer"/> for the configured column pairs.
/// </summary>
/// <param name="input">Data whose schema is checked against the configured columns; must not be null.</param>
/// <returns>A transformer created with this estimator's host and column pairs.</returns>
public CopyColumnsTransformer Fit(IDataView input)
{
    // Fail with a contract exception instead of a NullReferenceException on input.Schema.
    _host.CheckValue(input, nameof(input));

    // Invoke schema validation; the returned shape itself is not needed here.
    GetOutputSchema(SchemaShape.Create(input.Schema));
    return new CopyColumnsTransformer(_host, _columns);
}
202
206
203
207
public SchemaShape GetOutputSchema ( SchemaShape inputSchema )
@@ -206,7 +210,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
206
210
Contracts . CheckValue ( inputSchema . Columns , nameof ( inputSchema . Columns ) ) ;
207
211
var originDic = inputSchema . Columns . ToDictionary ( x => x . Name ) ;
208
212
var resultDic = inputSchema . Columns . ToDictionary ( x => x . Name ) ;
209
- foreach ( ( string source , string name ) pair in _columnsMapping )
213
+ foreach ( ( string source , string name ) pair in _columns )
210
214
{
211
215
if ( originDic . ContainsKey ( pair . source ) )
212
216
{
@@ -225,7 +229,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
225
229
226
230
public sealed class CopyColumnsTransformer : ITransformer , ICanSaveModel
227
231
{
228
- private readonly ( string source , string name ) [ ] _columns ;
232
+ private readonly ( string Source , string Name ) [ ] _columns ;
229
233
private readonly IHost _host ;
230
234
231
235
private class CopyColumnsRowMapper : IRowMapper
@@ -239,7 +243,7 @@ public CopyColumnsRowMapper(ISchema schema, (string source, string name)[] colum
239
243
_schema = schema ;
240
244
_columns = columns ;
241
245
_originalColumnSources = new HashSet < int > ( ) ;
242
- HashSet < string > sources = new HashSet < string > ( ) ;
246
+ var sources = new HashSet < string > ( ) ;
243
247
foreach ( var source in columns . Select ( x => x . source ) )
244
248
sources . Add ( source ) ;
245
249
for ( int i = 0 ; i < _schema . ColumnCount ; i ++ )
@@ -286,7 +290,7 @@ public void Save(ModelSaveContext ctx)
286
290
throw new NotImplementedException ( ) ;
287
291
}
288
292
}
289
- public CopyColumnsTransformer ( IHostEnvironment env , ( string source , string name ) [ ] columns )
293
+ public CopyColumnsTransformer ( IHostEnvironment env , ( string Source , string Name ) [ ] columns )
290
294
{
291
295
_columns = columns ;
292
296
_host = env . Register ( "CopyColumnsTransformer" ) ;
@@ -298,9 +302,49 @@ public ISchema GetOutputSchema(ISchema inputSchema)
298
302
return Transform ( new EmptyDataView ( _host , inputSchema ) ) . Schema ;
299
303
}
300
304
305
/// <summary>
/// Describes the serialized model format of this transformer: the model signature,
/// the written/readable version numbers, and the loader signature.
/// </summary>
private static VersionInfo GetVersionInfo() =>
    new VersionInfo(
        modelSignature: "COPYCOLT",
        verWrittenCur: 0x00010001, // Initial
        verReadableCur: 0x00010001,
        verWeCanReadBack: 0x00010001,
        loaderSignature: LoaderSignature);
314
+
315
+ public const string LoaderSignature = "CopyTransform" ;
316
+
301
317
/// <summary>
/// Writes this transformer's column pairs to the given model save context.
/// </summary>
/// <param name="ctx">Save context to write to; checked for null and for being positioned at a model.</param>
public void Save(ModelSaveContext ctx)
{
    _host.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel();
    ctx.SetVersionInfo(GetVersionInfo());

    // *** Binary format ***
    // int: number of added columns
    // for each added column
    //   int: id of output column name
    //   int: id of input column name
    int count = _columns.Length;
    ctx.Writer.Write(count);
    for (int i = 0; i < count; i++)
    {
        // Output name first, then source name; Create reads them back in this order.
        ctx.SaveNonEmptyString(_columns[i].Name);
        ctx.SaveNonEmptyString(_columns[i].Source);
    }
}
335
/// <summary>
/// Loads a <see cref="CopyColumnsTransformer"/> from a model load context.
/// </summary>
/// <param name="env">Host environment passed to the new transformer; must not be null.</param>
/// <param name="ctx">Load context to read from; must not be null and must match this transformer's version info.</param>
/// <returns>A transformer holding the column pairs read from <paramref name="ctx"/>.</returns>
public static CopyColumnsTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // int: number of added columns
    // for each added column
    //   int: id of output column name
    //   int: id of input column name
    var length = ctx.Reader.ReadInt32();
    // A negative count can only come from a corrupt or mismatched stream; report it
    // as a decode failure instead of letting the array allocation below throw.
    env.CheckDecode(length >= 0);

    var columns = new (string Source, string Name)[length];
    for (int i = 0; i < length; i++)
    {
        // Read order mirrors Save: output name first, then source name.
        columns[i].Name = ctx.LoadNonEmptyString();
        columns[i].Source = ctx.LoadNonEmptyString();
    }
    return new CopyColumnsTransformer(env, columns);
}
305
349
306
350
public IDataView Transform ( IDataView input )
0 commit comments