@@ -125,12 +125,21 @@ internal static DatabaseLoader CreateDatabaseLoader<TInput>(IHostEnvironment hos
125
125
var column = new Column ( ) ;
126
126
column . Name = mappingAttrName ? . Name ?? memberInfo . Name ;
127
127
128
- var mappingAttr = memberInfo . GetCustomAttribute < LoadColumnAttribute > ( ) ;
128
+ var indexMappingAttr = memberInfo . GetCustomAttribute < LoadColumnAttribute > ( ) ;
129
+ var nameMappingAttr = memberInfo . GetCustomAttribute < LoadColumnNameAttribute > ( ) ;
129
130
130
- if ( mappingAttr is object )
131
+ if ( indexMappingAttr is object )
131
132
{
132
- var sources = mappingAttr . Sources . Select ( ( source ) => Range . FromTextLoaderRange ( source ) ) . ToArray ( ) ;
133
- column . Source = sources ;
133
+ if ( nameMappingAttr is object )
134
+ {
135
+ throw Contracts . Except ( $ "Cannot specify both { nameof ( LoadColumnAttribute ) } and { nameof ( LoadColumnNameAttribute ) } ") ;
136
+ }
137
+
138
+ column . Source = indexMappingAttr . Sources . Select ( ( source ) => Range . FromTextLoaderRange ( source ) ) . ToArray ( ) ;
139
+ }
140
+ else if ( nameMappingAttr is object )
141
+ {
142
+ column . Source = nameMappingAttr . Sources . Select ( ( source ) => new Range ( source ) ) . ToArray ( ) ;
134
143
}
135
144
136
145
InternalDataKind dk ;
@@ -228,7 +237,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu
228
237
public DbType Type = DbType . Single ;
229
238
230
239
/// <summary>
231
- /// Source index range(s) of the column.
240
+ /// Source index or name range(s) of the column.
232
241
/// </summary>
233
242
[ Argument ( ArgumentType . Multiple , HelpText = "Source index range(s) of the column" , ShortName = "src" ) ]
234
243
public Range [ ] Source ;
@@ -241,7 +250,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu
241
250
}
242
251
243
252
/// <summary>
244
- /// Specifies the range of indices of input columns that should be mapped to an output column.
253
+ /// Specifies the range of indices or names of input columns that should be mapped to an output column.
245
254
/// </summary>
246
255
public sealed class Range
247
256
{
@@ -256,6 +265,19 @@ public Range(int index)
256
265
Contracts . CheckParam ( index >= 0 , nameof ( index ) , "Must be non-negative" ) ;
257
266
Min = index ;
258
267
Max = index ;
268
+ Name = null ;
269
+ }
270
+
271
+ /// <summary>
272
+ /// A range representing a single value. Will result in a scalar column.
273
+ /// </summary>
274
+ /// <param name="name">The name of the field of the table to read.</param>
275
+ public Range ( string name )
276
+ {
277
+ Contracts . CheckValue ( name , nameof ( name ) ) ;
278
+ Min = - 1 ;
279
+ Max = - 1 ;
280
+ Name = name ;
259
281
}
260
282
261
283
/// <summary>
@@ -278,15 +300,30 @@ public Range(int min, int max)
278
300
/// <summary>
279
301
/// The minimum index of the column, inclusive.
280
302
/// </summary>
303
+ /// <remarks>
304
+ /// This value is ignored if <see cref="Name" /> is not <c>null</c>.
305
+ /// </remarks>
281
306
[ Argument ( ArgumentType . Required , HelpText = "First index in the range" ) ]
282
307
public int Min ;
283
308
284
309
/// <summary>
285
310
/// The maximum index of the column, inclusive.
286
311
/// </summary>
312
+ /// <remarks>
313
+ /// This value is ignored if <see cref="Name" /> is not <c>null</c>.
314
+ /// </remarks>
287
315
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Last index in the range" ) ]
288
316
public int Max ;
289
317
318
+ /// <summary>
319
+ /// The name of the input column.
320
+ /// </summary>
321
+ /// <remarks>
322
+ /// This value, if non-<c>null</c>, overrides <see cref="Min" /> and <see cref="Max" />.
323
+ /// </remarks>
324
+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Name of the column" ) ]
325
+ public string Name ;
326
+
290
327
/// <summary>
291
328
/// Force scalar columns to be treated as vectors of length one.
292
329
/// </summary>
@@ -318,17 +355,28 @@ public sealed class Options
318
355
/// </summary>
319
356
internal readonly struct Segment
320
357
{
358
+ public readonly string Name ;
321
359
public readonly int Min ;
322
360
public readonly int Lim ;
323
361
public readonly bool ForceVector ;
324
362
325
363
public Segment ( int min , int lim , bool forceVector )
326
364
{
327
365
Contracts . Assert ( 0 <= min & min < lim ) ;
366
+ Name = null ;
328
367
Min = min ;
329
368
Lim = lim ;
330
369
ForceVector = forceVector ;
331
370
}
371
+
372
+ public Segment ( string name , bool forceVector )
373
+ {
374
+ Contracts . Assert ( name != null ) ;
375
+ Name = name ;
376
+ Min = - 1 ;
377
+ Lim = - 1 ;
378
+ ForceVector = forceVector ;
379
+ }
332
380
}
333
381
334
382
/// <summary>
@@ -368,19 +416,23 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen
368
416
if ( segs != null )
369
417
{
370
418
var order = Utils . GetIdentityPermutation ( segs . Length ) ;
371
- Array . Sort ( order , ( x , y ) => segs [ x ] . Min . CompareTo ( segs [ y ] . Min ) ) ;
372
419
373
- // Check that the segments are disjoint.
374
- for ( int i = 1 ; i < order . Length ; i ++ )
420
+ if ( ( segs . Length != 0 ) && ( segs [ 0 ] . Name is null ) )
375
421
{
376
- int a = order [ i - 1 ] ;
377
- int b = order [ i ] ;
378
- Contracts . Assert ( segs [ a ] . Min <= segs [ b ] . Min ) ;
379
- if ( segs [ a ] . Lim > segs [ b ] . Min )
422
+ Array . Sort ( order , ( x , y ) => segs [ x ] . Min . CompareTo ( segs [ y ] . Min ) ) ;
423
+
424
+ // Check that the segments are disjoint.
425
+ for ( int i = 1 ; i < order . Length ; i ++ )
380
426
{
381
- throw user ?
382
- Contracts . ExceptUserArg ( nameof ( Column . Source ) , "Intervals specified for column '{0}' overlap" , name ) :
383
- Contracts . ExceptDecode ( "Intervals specified for column '{0}' overlap" , name ) ;
427
+ int a = order [ i - 1 ] ;
428
+ int b = order [ i ] ;
429
+ Contracts . Assert ( segs [ a ] . Min <= segs [ b ] . Min ) ;
430
+ if ( segs [ a ] . Lim > segs [ b ] . Min )
431
+ {
432
+ throw user ?
433
+ Contracts . ExceptUserArg ( nameof ( Column . Source ) , "Intervals specified for column '{0}' overlap" , name ) :
434
+ Contracts . ExceptDecode ( "Intervals specified for column '{0}' overlap" , name ) ;
435
+ }
384
436
}
385
437
}
386
438
@@ -389,7 +441,7 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen
389
441
for ( int i = 0 ; i < segs . Length ; i ++ )
390
442
{
391
443
var seg = segs [ i ] ;
392
- size += seg . Lim - seg . Min ;
444
+ size += ( seg . Name is null ) ? seg . Lim - seg . Min : 1 ;
393
445
}
394
446
Contracts . Assert ( size >= segs . Length ) ;
395
447
@@ -454,15 +506,23 @@ public Bindings(DatabaseLoader parent, Column[] cols)
454
506
for ( int i = 0 ; i < segs . Length ; i ++ )
455
507
{
456
508
var range = col . Source [ i ] ;
457
-
458
- int min = range . Min ;
459
- ch . CheckUserArg ( 0 <= min , nameof ( range . Min ) ) ;
460
-
461
509
Segment seg ;
462
510
463
- int max = range . Max ;
464
- ch . CheckUserArg ( min <= max , nameof ( range . Max ) ) ;
465
- seg = new Segment ( min , max + 1 , range . ForceVector ) ;
511
+ if ( range . Name is null )
512
+ {
513
+ int min = range . Min ;
514
+ ch . CheckUserArg ( 0 <= min , nameof ( range . Min ) ) ;
515
+
516
+ int max = range . Max ;
517
+ ch . CheckUserArg ( min <= max , nameof ( range . Max ) ) ;
518
+ seg = new Segment ( min , max + 1 , range . ForceVector ) ;
519
+ }
520
+ else
521
+ {
522
+ string columnName = range . Name ;
523
+ ch . CheckUserArg ( columnName != null , nameof ( range . Name ) ) ;
524
+ seg = new Segment ( columnName , range . ForceVector ) ;
525
+ }
466
526
467
527
segs [ i ] = seg ;
468
528
}
@@ -490,6 +550,7 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
490
550
// ulong: count for key range
491
551
// int: number of segments
492
552
// foreach segment:
553
+ // string id: name
493
554
// int: min
494
555
// int: lim
495
556
// byte: force vector (verWrittenCur: verIsVectorSupported)
@@ -532,11 +593,12 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
532
593
segs = new Segment [ cseg ] ;
533
594
for ( int iseg = 0 ; iseg < cseg ; iseg ++ )
534
595
{
596
+ string columnName = ctx . LoadStringOrNull ( ) ;
535
597
int min = ctx . Reader . ReadInt32 ( ) ;
536
598
int lim = ctx . Reader . ReadInt32 ( ) ;
537
599
Contracts . CheckDecode ( 0 <= min && min < lim ) ;
538
600
bool forceVector = ctx . Reader . ReadBoolByte ( ) ;
539
- segs [ iseg ] = new Segment ( min , lim , forceVector ) ;
601
+ segs [ iseg ] = ( columnName is null ) ? new Segment ( min , lim , forceVector ) : new Segment ( columnName , forceVector ) ;
540
602
}
541
603
}
542
604
@@ -563,6 +625,7 @@ internal void Save(ModelSaveContext ctx)
563
625
// ulong: count for key range
564
626
// int: number of segments
565
627
// foreach segment:
628
+ // string id: name
566
629
// int: min
567
630
// int: lim
568
631
// byte: force vector (verWrittenCur: verIsVectorSupported)
@@ -588,6 +651,7 @@ internal void Save(ModelSaveContext ctx)
588
651
ctx . Writer . Write ( info . Segments . Length ) ;
589
652
foreach ( var seg in info . Segments )
590
653
{
654
+ ctx . SaveStringOrNull ( seg . Name ) ;
591
655
ctx . Writer . Write ( seg . Min ) ;
592
656
ctx . Writer . Write ( seg . Lim ) ;
593
657
ctx . Writer . WriteBoolByte ( seg . ForceVector ) ;
0 commit comments