@@ -195,14 +195,18 @@ private sealed class GroupSchema : ISchema
195
195
private readonly IExceptionContext _ectx ;
196
196
private readonly Schema _input ;
197
197
198
+ // Column names in source schema used to group rows.
198
199
private readonly string [ ] _groupColumns ;
200
+ // Column names in source schema grouped into rows.
199
201
private readonly string [ ] _keepColumns ;
200
202
203
+ // GroupIds[i] is the i-th group-key column's column index in the source schema.
201
204
public readonly int [ ] GroupIds ;
205
+ // GroupIds[i] is the i-th grouped column's column index in the source schema.
202
206
public readonly int [ ] KeepIds ;
203
207
204
208
private readonly int _groupCount ;
205
- private readonly ColumnType [ ] _columnTypes ;
209
+ private readonly ColumnType [ ] _keepColumnTypes ;
206
210
207
211
private readonly Dictionary < string , int > _columnNameMap ;
208
212
@@ -224,10 +228,20 @@ public GroupSchema(IExceptionContext ectx, Schema inputSchema, string[] groupCol
224
228
_keepColumns = keepColumns ;
225
229
KeepIds = GetColumnIds ( inputSchema , keepColumns , x => _ectx . ExceptUserArg ( nameof ( Arguments . Column ) , x ) ) ;
226
230
227
- _columnTypes = BuildColumnTypes ( _input , KeepIds ) ;
231
+ _keepColumnTypes = BuildColumnTypes ( _input , KeepIds ) ;
228
232
_columnNameMap = BuildColumnNameMap ( ) ;
229
233
230
- AsSchema = Schema . Create ( this ) ;
234
+ var schemaBuilder = new SchemaBuilder ( ) ;
235
+ foreach ( var groupKeyColumnName in groupColumns )
236
+ schemaBuilder . AddColumn ( groupKeyColumnName , inputSchema [ groupKeyColumnName ] . Type , inputSchema [ groupKeyColumnName ] . Metadata ) ;
237
+ foreach ( var groupValueColumnName in keepColumns )
238
+ {
239
+ var metadataBuilder = new MetadataBuilder ( ) ;
240
+ metadataBuilder . Add ( inputSchema [ groupValueColumnName ] . Metadata ,
241
+ s => s == MetadataUtils . Kinds . IsNormalized || s == MetadataUtils . Kinds . KeyValues ) ;
242
+ schemaBuilder . AddColumn ( groupValueColumnName , inputSchema [ groupValueColumnName ] . Type , metadataBuilder . GetMetadata ( ) ) ;
243
+ }
244
+ AsSchema = schemaBuilder . GetSchema ( ) ;
231
245
}
232
246
233
247
public GroupSchema ( Schema inputSchema , IHostEnvironment env , ModelLoadContext ctx )
@@ -261,7 +275,7 @@ public GroupSchema(Schema inputSchema, IHostEnvironment env, ModelLoadContext ct
261
275
262
276
KeepIds = GetColumnIds ( inputSchema , _keepColumns , _ectx . Except ) ;
263
277
264
- _columnTypes = BuildColumnTypes ( _input , KeepIds ) ;
278
+ _keepColumnTypes = BuildColumnTypes ( _input , KeepIds ) ;
265
279
_columnNameMap = BuildColumnNameMap ( ) ;
266
280
267
281
AsSchema = Schema . Create ( this ) ;
@@ -319,23 +333,34 @@ public void Save(ModelSaveContext ctx)
319
333
}
320
334
}
321
335
336
+ /// <summary>
337
+ /// Given column names, extract and return column indexes from source schema.
338
+ /// </summary>
339
+ /// <param name="schema">Source schema</param>
340
+ /// <param name="names">Column names</param>
341
+ /// <param name="except">Marked exception function</param>
342
+ /// <returns>column indexes</returns>
322
343
private int [ ] GetColumnIds ( Schema schema , string [ ] names , Func < string , Exception > except )
323
344
{
324
345
Contracts . AssertValue ( schema ) ;
325
346
Contracts . AssertValue ( names ) ;
326
347
327
348
var ids = new int [ names . Length ] ;
349
+
328
350
for ( int i = 0 ; i < names . Length ; i ++ )
329
351
{
330
- int col ;
331
- if ( ! schema . TryGetColumnIndex ( names [ i ] , out col ) )
352
+ // Find column called names[i] from input schema.
353
+ var retrievedColumn = schema . GetColumnOrNull ( names [ i ] ) ;
354
+
355
+ // Throw if no such a schema.
356
+ if ( ! retrievedColumn . HasValue )
332
357
throw except ( string . Format ( "Could not find column '{0}'" , names [ i ] ) ) ;
333
358
334
- var colType = schema [ col ] . Type ;
359
+ var colType = retrievedColumn . Value . Type ;
335
360
if ( ! colType . IsPrimitive )
336
361
throw except ( string . Format ( "Column '{0}' has type '{1}', but must have a primitive type" , names [ i ] , colType ) ) ;
337
362
338
- ids [ i ] = col ;
363
+ ids [ i ] = retrievedColumn . Value . Index ;
339
364
}
340
365
341
366
return ids ;
@@ -366,7 +391,7 @@ public ColumnType GetColumnType(int col)
366
391
CheckColumnInRange ( col ) ;
367
392
if ( col < _groupCount )
368
393
return _input [ GroupIds [ col ] ] . Type ;
369
- return _columnTypes [ col - _groupCount ] ;
394
+ return _keepColumnTypes [ col - _groupCount ] ;
370
395
}
371
396
372
397
public IEnumerable < KeyValuePair < string , ColumnType > > GetMetadataTypes ( int col )
0 commit comments