@@ -520,26 +520,26 @@ public static ColInfo Create(string name, PrimitiveType itemType, Segment[] segs
520
520
}
521
521
}
522
522
523
- private sealed class Bindings : ISchema
523
+ private sealed class Bindings
524
524
{
525
+ /// <summary>
526
+ /// <see cref="Infos"/>[i] stores the i-th column's name and type. Columns are loaded from the input text file.
527
+ /// </summary>
525
528
public readonly ColInfo [ ] Infos ;
526
- public readonly Dictionary < string , int > NameToInfoIndex ;
529
+ /// <summary>
530
+ /// <see cref="Infos"/>[i] stores the i-th column's metadata, named <see cref="MetadataUtils.Kinds.SlotNames"/>
531
+ /// in <see cref="Schema.Metadata"/>.
532
+ /// </summary>
527
533
private readonly VBuffer < ReadOnlyMemory < char > > [ ] _slotNames ;
528
- // Empty iff either header+ not set in args, or if no header present, or upon load
529
- // there was no header stored in the model.
534
+ /// <summary>
535
+ /// Empty if <see cref="ArgumentsCore.HasHeader"/> is <see langword="false"/>, no header presents, or upon load
536
+ /// there was no header stored in the model.
537
+ /// </summary>
530
538
private readonly ReadOnlyMemory < char > _header ;
531
539
532
- private readonly MetadataUtils . MetadataGetter < VBuffer < ReadOnlyMemory < char > > > _getSlotNames ;
533
-
534
- public Schema AsSchema { get ; }
535
-
536
- private Bindings ( )
537
- {
538
- _getSlotNames = GetSlotNames ;
539
- }
540
+ public Schema OutputSchema { get ; }
540
541
541
542
public Bindings ( TextLoader parent , Column [ ] cols , IMultiStreamSource headerFile , IMultiStreamSource dataSample )
542
- : this ( )
543
543
{
544
544
Contracts . AssertNonEmpty ( cols ) ;
545
545
Contracts . AssertValueOrNull ( headerFile ) ;
@@ -590,14 +590,17 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
590
590
int isegOther = - 1 ;
591
591
592
592
Infos = new ColInfo [ cols . Length ] ;
593
- NameToInfoIndex = new Dictionary < string , int > ( Infos . Length ) ;
593
+
594
+ // This dictionary is used only for detecting duplicated column names specified by user.
595
+ var nameToInfoIndex = new Dictionary < string , int > ( Infos . Length ) ;
596
+
594
597
for ( int iinfo = 0 ; iinfo < Infos . Length ; iinfo ++ )
595
598
{
596
599
var col = cols [ iinfo ] ;
597
600
598
601
ch . CheckNonWhiteSpace ( col . Name , nameof ( col . Name ) ) ;
599
602
string name = col . Name . Trim ( ) ;
600
- if ( iinfo == NameToInfoIndex . Count && NameToInfoIndex . ContainsKey ( name ) )
603
+ if ( iinfo == nameToInfoIndex . Count && nameToInfoIndex . ContainsKey ( name ) )
601
604
ch . Info ( "Duplicate name(s) specified - later columns will hide earlier ones" ) ;
602
605
603
606
PrimitiveType itemType ;
@@ -669,7 +672,7 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
669
672
if ( iinfoOther != iinfo )
670
673
Infos [ iinfo ] = ColInfo . Create ( name , itemType , segs , true ) ;
671
674
672
- NameToInfoIndex [ name ] = iinfo ;
675
+ nameToInfoIndex [ name ] = iinfo ;
673
676
}
674
677
675
678
// Note that segsOther[isegOther] is not a real segment to be included.
@@ -734,11 +737,10 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
734
737
if ( ! _header . IsEmpty )
735
738
Parser . ParseSlotNames ( parent , _header , Infos , _slotNames ) ;
736
739
}
737
- AsSchema = Schema . Create ( this ) ;
740
+ OutputSchema = ComputeOutputSchema ( ) ;
738
741
}
739
742
740
743
public Bindings ( ModelLoadContext ctx , TextLoader parent )
741
- : this ( )
742
744
{
743
745
Contracts . AssertValue ( ctx ) ;
744
746
@@ -760,7 +762,9 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
760
762
int cinfo = ctx . Reader . ReadInt32 ( ) ;
761
763
Contracts . CheckDecode ( cinfo > 0 ) ;
762
764
Infos = new ColInfo [ cinfo ] ;
763
- NameToInfoIndex = new Dictionary < string , int > ( Infos . Length ) ;
765
+
766
+ // This dictionary is used only for detecting duplicated column names specified by user.
767
+ var nameToInfoIndex = new Dictionary < string , int > ( Infos . Length ) ;
764
768
765
769
for ( int iinfo = 0 ; iinfo < cinfo ; iinfo ++ )
766
770
{
@@ -808,7 +812,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
808
812
// of multiple variable segments (since those segments will overlap and overlapping
809
813
// segments are illegal).
810
814
Infos [ iinfo ] = ColInfo . Create ( name , itemType , segs , false ) ;
811
- NameToInfoIndex [ name ] = iinfo ;
815
+ nameToInfoIndex [ name ] = iinfo ;
812
816
}
813
817
814
818
_slotNames = new VBuffer < ReadOnlyMemory < char > > [ Infos . Length ] ;
@@ -818,7 +822,7 @@ public Bindings(ModelLoadContext ctx, TextLoader parent)
818
822
if ( ! string . IsNullOrEmpty ( result ) )
819
823
Parser . ParseSlotNames ( parent , _header = result . AsMemory ( ) , Infos , _slotNames ) ;
820
824
821
- AsSchema = Schema . Create ( this ) ;
825
+ OutputSchema = ComputeOutputSchema ( ) ;
822
826
}
823
827
824
828
public void Save ( ModelSaveContext ctx )
@@ -869,86 +873,29 @@ public void Save(ModelSaveContext ctx)
869
873
ctx . SaveTextStream ( "Header.txt" , writer => writer . WriteLine ( _header . ToString ( ) ) ) ;
870
874
}
871
875
872
- public int ColumnCount
873
- {
874
- get { return Infos . Length ; }
875
- }
876
-
877
- public bool TryGetColumnIndex ( string name , out int col )
878
- {
879
- Contracts . CheckValueOrNull ( name ) ;
880
- return NameToInfoIndex . TryGetValue ( name , out col ) ;
881
- }
882
-
883
- public string GetColumnName ( int col )
884
- {
885
- Contracts . CheckParam ( 0 <= col && col < Infos . Length , nameof ( col ) ) ;
886
- return Infos [ col ] . Name ;
887
- }
888
-
889
- public ColumnType GetColumnType ( int col )
890
- {
891
- Contracts . CheckParam ( 0 <= col && col < Infos . Length , nameof ( col ) ) ;
892
- return Infos [ col ] . ColType ;
893
- }
894
-
895
- public IEnumerable < KeyValuePair < string , ColumnType > > GetMetadataTypes ( int col )
896
- {
897
- Contracts . CheckParam ( 0 <= col && col < ColumnCount , nameof ( col ) ) ;
898
-
899
- var names = _slotNames [ col ] ;
900
- if ( names . Length > 0 )
901
- {
902
- Contracts . Assert ( Infos [ col ] . ColType . VectorSize == names . Length ) ;
903
- yield return MetadataUtils . GetSlotNamesPair ( names . Length ) ;
904
- }
905
- }
906
-
907
- public ColumnType GetMetadataTypeOrNull ( string kind , int col )
908
- {
909
- Contracts . CheckNonEmpty ( kind , nameof ( kind ) ) ;
910
- Contracts . CheckParam ( 0 <= col && col < ColumnCount , nameof ( col ) ) ;
911
-
912
- switch ( kind )
913
- {
914
- case MetadataUtils . Kinds . SlotNames :
915
- var names = _slotNames [ col ] ;
916
- if ( names . Length == 0 )
917
- return null ;
918
- Contracts . Assert ( Infos [ col ] . ColType . VectorSize == names . Length ) ;
919
- return MetadataUtils . GetNamesType ( names . Length ) ;
920
-
921
- default :
922
- return null ;
923
- }
924
- }
925
-
926
- public void GetMetadata < TValue > ( string kind , int col , ref TValue value )
876
+ private Schema ComputeOutputSchema ( )
927
877
{
928
- Contracts . CheckNonEmpty ( kind , nameof ( kind ) ) ;
929
- Contracts . CheckParam ( 0 <= col && col < ColumnCount , nameof ( col ) ) ;
878
+ var schemaBuilder = new SchemaBuilder ( ) ;
930
879
931
- switch ( kind )
880
+ // Iterate through all loaded columns. The index i indicates the i-th column loaded.
881
+ for ( int i = 0 ; i < Infos . Length ; ++ i )
932
882
{
933
- case MetadataUtils . Kinds . SlotNames :
934
- _getSlotNames . Marshal ( col , ref value ) ;
935
- return ;
936
-
937
- default :
938
- throw MetadataUtils . ExceptGetMetadata ( ) ;
883
+ var info = Infos [ i ] ;
884
+ // Retrieve the only possible metadata of this class.
885
+ var names = _slotNames [ i ] ;
886
+ if ( names . Length > 0 )
887
+ {
888
+ // Slot names present! Let's add them.
889
+ var metadataBuilder = new MetadataBuilder ( ) ;
890
+ metadataBuilder . AddSlotNames ( names . Length , ( ref VBuffer < ReadOnlyMemory < char > > value ) => names . CopyTo ( ref value ) ) ;
891
+ schemaBuilder . AddColumn ( info . Name , info . ColType , metadataBuilder . GetMetadata ( ) ) ;
892
+ }
893
+ else
894
+ // Slot names is empty.
895
+ schemaBuilder . AddColumn ( info . Name , info . ColType ) ;
939
896
}
940
- }
941
-
942
- private void GetSlotNames ( int col , ref VBuffer < ReadOnlyMemory < char > > dst )
943
- {
944
- Contracts . Assert ( 0 <= col && col < ColumnCount ) ;
945
-
946
- var names = _slotNames [ col ] ;
947
- if ( names . Length == 0 )
948
- throw MetadataUtils . ExceptGetMetadata ( ) ;
949
897
950
- Contracts . Assert ( Infos [ col ] . ColType . VectorSize == names . Length ) ;
951
- names . CopyTo ( ref dst ) ;
898
+ return schemaBuilder . GetSchema ( ) ;
952
899
}
953
900
}
954
901
@@ -1355,7 +1302,7 @@ public void Save(ModelSaveContext ctx)
1355
1302
_bindings . Save ( ctx ) ;
1356
1303
}
1357
1304
1358
- public Schema GetOutputSchema ( ) => _bindings . AsSchema ;
1305
+ public Schema GetOutputSchema ( ) => _bindings . OutputSchema ;
1359
1306
1360
1307
public IDataView Read ( IMultiStreamSource source ) => new BoundLoader ( this , source ) ;
1361
1308
@@ -1455,21 +1402,21 @@ public BoundLoader(TextLoader reader, IMultiStreamSource files)
1455
1402
// REVIEW: Should we try to support shuffling?
1456
1403
public bool CanShuffle => false ;
1457
1404
1458
- public Schema Schema => _reader . _bindings . AsSchema ;
1405
+ public Schema Schema => _reader . _bindings . OutputSchema ;
1459
1406
1460
1407
public RowCursor GetRowCursor ( Func < int , bool > predicate , Random rand = null )
1461
1408
{
1462
1409
_host . CheckValue ( predicate , nameof ( predicate ) ) ;
1463
1410
_host . CheckValueOrNull ( rand ) ;
1464
- var active = Utils . BuildArray ( _reader . _bindings . ColumnCount , predicate ) ;
1411
+ var active = Utils . BuildArray ( _reader . _bindings . OutputSchema . Count , predicate ) ;
1465
1412
return Cursor . Create ( _reader , _files , active ) ;
1466
1413
}
1467
1414
1468
1415
public RowCursor [ ] GetRowCursorSet ( Func < int , bool > predicate , int n , Random rand = null )
1469
1416
{
1470
1417
_host . CheckValue ( predicate , nameof ( predicate ) ) ;
1471
1418
_host . CheckValueOrNull ( rand ) ;
1472
- var active = Utils . BuildArray ( _reader . _bindings . ColumnCount , predicate ) ;
1419
+ var active = Utils . BuildArray ( _reader . _bindings . OutputSchema . Count , predicate ) ;
1473
1420
return Cursor . CreateSet ( _reader , _files , active , n ) ;
1474
1421
}
1475
1422
0 commit comments