@@ -552,25 +552,27 @@ public static IDataView AddFoldIndex(IHostEnvironment env, IDataView input, int
552
552
}
553
553
}
554
554
555
- private static int [ ] [ ] MapKeys ( ISchema [ ] schemas , string columnName , bool isVec ,
556
- out int [ ] indices , out Dictionary < DvText , int > reconciledKeyNames )
555
+ private static int [ ] [ ] MapKeys < T > ( ISchema [ ] schemas , string columnName , bool isVec ,
556
+ int [ ] indices , Dictionary < DvText , int > reconciledKeyNames )
557
557
{
558
+ Contracts . AssertValue ( indices ) ;
559
+ Contracts . AssertValue ( reconciledKeyNames ) ;
560
+
558
561
var dvCount = schemas . Length ;
559
562
var keyValueMappers = new int [ dvCount ] [ ] ;
560
- var keyNamesCur = default ( VBuffer < DvText > ) ;
561
- indices = new int [ dvCount ] ;
562
- reconciledKeyNames = new Dictionary < DvText , int > ( ) ;
563
+ var keyNamesCur = default ( VBuffer < T > ) ;
563
564
for ( int i = 0 ; i < dvCount ; i ++ )
564
565
{
565
566
var schema = schemas [ i ] ;
566
567
if ( ! schema . TryGetColumnIndex ( columnName , out indices [ i ] ) )
567
568
throw Contracts . Except ( $ "Schema number { i } does not contain column '{ columnName } '") ;
568
569
569
570
var type = schema . GetColumnType ( indices [ i ] ) ;
571
+ var keyValueType = schema . GetMetadataTypeOrNull ( MetadataUtils . Kinds . KeyValues , indices [ i ] ) ;
570
572
if ( type . IsVector != isVec )
571
573
throw Contracts . Except ( $ "Column '{ columnName } ' in schema number { i } does not have the correct type") ;
572
- if ( ! schema . HasKeyNames ( indices [ i ] , type . ItemType . KeyCount ) )
573
- throw Contracts . Except ( $ "Column '{ columnName } ' in schema number { i } does not have text key values") ;
574
+ if ( keyValueType == null || keyValueType . ItemType . RawType != typeof ( T ) )
575
+ throw Contracts . Except ( $ "Column '{ columnName } ' in schema number { i } does not have the correct type of key values") ;
574
576
if ( ! type . ItemType . IsKey || type . ItemType . RawKind != DataKind . U4 )
575
577
throw Contracts . Except ( $ "Column '{ columnName } ' must be a U4 key type, but is '{ type . ItemType } '") ;
576
578
@@ -580,7 +582,7 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
580
582
foreach ( var kvp in keyNamesCur . Items ( true ) )
581
583
{
582
584
var key = kvp . Key ;
583
- var name = kvp . Value ;
585
+ var name = new DvText ( kvp . Value . ToString ( ) ) ;
584
586
if ( ! reconciledKeyNames . ContainsKey ( name ) )
585
587
reconciledKeyNames [ name ] = reconciledKeyNames . Count ;
586
588
keyValueMappers [ i ] [ key ] = reconciledKeyNames [ name ] ;
@@ -595,17 +597,18 @@ private static int[][] MapKeys(ISchema[] schemas, string columnName, bool isVec,
595
597
/// data view, with the union of the key values as the new key values. For each data view, the value in the output column is the value
596
598
/// corresponding to the key value in the original column.
597
599
/// </summary>
598
- public static void ReconcileKeyValues ( IHostEnvironment env , IDataView [ ] views , string columnName )
600
+ public static void ReconcileKeyValues ( IHostEnvironment env , IDataView [ ] views , string columnName , ColumnType keyValueType )
599
601
{
600
602
Contracts . CheckNonEmpty ( views , nameof ( views ) ) ;
601
603
Contracts . CheckNonEmpty ( columnName , nameof ( columnName ) ) ;
602
604
603
605
var dvCount = views . Length ;
604
606
605
- Dictionary < DvText , int > keyNames ;
606
- int [ ] indices ;
607
607
// Create mappings from the original key types to the reconciled key type.
608
- var keyValueMappers = MapKeys ( views . Select ( view => view . Schema ) . ToArray ( ) , columnName , false , out indices , out keyNames ) ;
608
+ var indices = new int [ dvCount ] ;
609
+ var keyNames = new Dictionary < DvText , int > ( ) ;
610
+ // We use MarshalInvoke so that we can call MapKeys with the correct generic: keyValueType.RawType.
611
+ var keyValueMappers = Utils . MarshalInvoke ( MapKeys < int > , keyValueType . RawType , views . Select ( view => view . Schema ) . ToArray ( ) , columnName , false , indices , keyNames ) ;
609
612
var keyType = new KeyType ( DataKind . U4 , 0 , keyNames . Count ) ;
610
613
var keyNamesVBuffer = new VBuffer < DvText > ( keyNames . Count , keyNames . Keys . ToArray ( ) ) ;
611
614
ValueGetter < VBuffer < DvText > > keyValueGetter =
@@ -629,20 +632,51 @@ public static void ReconcileKeyValues(IHostEnvironment env, IDataView[] views, s
629
632
}
630
633
}
631
634
635
/// <summary>
/// Takes an array of data views and the name of a key column, and replaces that column in each data view
/// with a U4 key column whose cardinality is <paramref name="keyCount"/>. Unlike <see cref="ReconcileKeyValues"/>,
/// no key value names are read or merged: the raw key value of each row is carried over unchanged, and any
/// raw value greater than <paramref name="keyCount"/> is mapped to 0.
/// </summary>
/// <param name="env">The host environment, used to create the column mappers and to raise exceptions.</param>
/// <param name="views">The data views to process. Each element is replaced in place by its wrapped version.</param>
/// <param name="columnName">The name of the key column; the output column reuses the same name.</param>
/// <param name="keyCount">The cardinality of the output key type.</param>
public static void ReconcileKeyValuesWithNoNames(IHostEnvironment env, IDataView[] views, string columnName, int keyCount)
{
    Contracts.CheckNonEmpty(views, nameof(views));
    Contracts.CheckNonEmpty(columnName, nameof(columnName));

    var keyType = new KeyType(DataKind.U4, 0, keyCount);

    // The mapper only depends on keyCount, so one instance is shared by all the data views.
    // Values within the output range are copied as they are (adding 1 would shift every key and
    // push src == keyCount outside the range of keyType); out of range values become 0.
    // NOTE(review): 0 is presumably the "missing" representation of the key type - confirm.
    ValueMapper<uint, uint> mapper =
        (ref uint src, ref uint dst) =>
        {
            if (src > keyCount)
                dst = 0;
            else
                dst = src;
        };

    // For each input data view, create the converted key column by wrapping it in a LambdaColumnMapper.
    for (int i = 0; i < views.Length; i++)
    {
        if (!views[i].Schema.TryGetColumnIndex(columnName, out var index))
            throw env.Except($"Data view {i} doesn't contain a column '{columnName}'");

        views[i] = LambdaColumnMapper.Create(env, "ReconcileKeyValues", views[i], columnName, columnName,
            views[i].Schema.GetColumnType(index), keyType, mapper);
    }
}
665
+
632
666
/// <summary>
633
667
/// This method is similar to <see cref="ReconcileKeyValues"/>, but it reconciles the key values over vector
634
668
/// input columns.
635
669
/// </summary>
636
- public static void ReconcileVectorKeyValues ( IHostEnvironment env , IDataView [ ] views , string columnName )
670
+ public static void ReconcileVectorKeyValues ( IHostEnvironment env , IDataView [ ] views , string columnName , ColumnType keyValueType )
637
671
{
638
672
Contracts . CheckNonEmpty ( views , nameof ( views ) ) ;
639
673
Contracts . CheckNonEmpty ( columnName , nameof ( columnName ) ) ;
640
674
641
675
var dvCount = views . Length ;
642
676
643
- Dictionary < DvText , int > keyNames ;
644
- int [ ] columnIndices ;
645
- var keyValueMappers = MapKeys ( views . Select ( view => view . Schema ) . ToArray ( ) , columnName , true , out columnIndices , out keyNames ) ;
677
+ var keyNames = new Dictionary < DvText , int > ( ) ;
678
+ var columnIndices = new int [ dvCount ] ;
679
+ var keyValueMappers = Utils . MarshalInvoke ( MapKeys < int > , keyValueType . RawType , views . Select ( view => view . Schema ) . ToArray ( ) , columnName , true , columnIndices , keyNames ) ;
646
680
var keyType = new KeyType ( DataKind . U4 , 0 , keyNames . Count ) ;
647
681
var keyNamesVBuffer = new VBuffer < DvText > ( keyNames . Count , keyNames . Keys . ToArray ( ) ) ;
648
682
ValueGetter < VBuffer < DvText > > keyValueGetter =
@@ -736,7 +770,7 @@ public static IDataView[] ConcatenatePerInstanceDataViews(IHostEnvironment env,
736
770
var foldDataViews = perInstance . Select ( getPerInstance ) . ToArray ( ) ;
737
771
if ( collate )
738
772
{
739
- var combined = AppendPerInstanceDataViews ( env , foldDataViews , out variableSizeVectorColumnNames ) ;
773
+ var combined = AppendPerInstanceDataViews ( env , perInstance [ 0 ] . Schema . Label ? . Name , foldDataViews , out variableSizeVectorColumnNames ) ;
740
774
return new [ ] { combined } ;
741
775
}
742
776
else
@@ -767,7 +801,8 @@ public static IDataView ConcatenateOverallMetrics(IHostEnvironment env, IDataVie
767
801
return AppendRowsDataView . Create ( env , overallList [ 0 ] . Schema , overallList . ToArray ( ) ) ;
768
802
}
769
803
770
- private static IDataView AppendPerInstanceDataViews ( IHostEnvironment env , IEnumerable < IDataView > foldDataViews , out string [ ] variableSizeVectorColumnNames )
804
+ private static IDataView AppendPerInstanceDataViews ( IHostEnvironment env , string labelColName ,
805
+ IEnumerable < IDataView > foldDataViews , out string [ ] variableSizeVectorColumnNames )
771
806
{
772
807
Contracts . AssertValue ( env ) ;
773
808
env . AssertValue ( foldDataViews ) ;
@@ -776,7 +811,9 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
776
811
// This is a dictionary from the column name to its vector size.
777
812
var vectorSizes = new Dictionary < string , int > ( ) ;
778
813
var firstDvSlotNames = new Dictionary < string , VBuffer < DvText > > ( ) ;
779
- var firstDvKeyColumns = new List < string > ( ) ;
814
+ ColumnType labelColKeyValuesType = null ;
815
+ var firstDvKeyWithNamesColumns = new List < string > ( ) ;
816
+ var firstDvKeyNoNamesColumns = new Dictionary < string , int > ( ) ;
780
817
var firstDvVectorKeyColumns = new List < string > ( ) ;
781
818
var variableSizeVectorColumnNamesList = new List < string > ( ) ;
782
819
var list = new List < IDataView > ( ) ;
@@ -822,10 +859,20 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
822
859
else
823
860
vectorSizes . Add ( name , type . VectorSize ) ;
824
861
}
825
- else if ( dvNumber == 0 && dv . Schema . HasKeyNames ( i , type . KeyCount ) )
862
+ else if ( dvNumber == 0 && name == labelColName )
826
863
{
827
864
// The label column can be a key. Reconcile the key values, and wrap with a KeyToValue transform.
828
- firstDvKeyColumns . Add ( name ) ;
865
+ labelColKeyValuesType = dv . Schema . GetMetadataTypeOrNull ( MetadataUtils . Kinds . KeyValues , i ) ;
866
+ }
867
+ else if ( dvNumber == 0 && dv . Schema . HasKeyNames ( i , type . KeyCount ) )
868
+ firstDvKeyWithNamesColumns . Add ( name ) ;
869
+ else if ( type . KeyCount > 0 && name != labelColName )
870
+ {
871
+ // For any other key column (such as GroupId) we do not reconcile the key values, we only convert to U4.
872
+ if ( ! firstDvKeyNoNamesColumns . ContainsKey ( name ) )
873
+ firstDvKeyNoNamesColumns [ name ] = type . KeyCount ;
874
+ if ( firstDvKeyNoNamesColumns [ name ] < type . KeyCount )
875
+ firstDvKeyNoNamesColumns [ name ] = type . KeyCount ;
829
876
}
830
877
}
831
878
var idv = dv ;
@@ -839,26 +886,34 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, IEnume
839
886
list . Add ( idv ) ;
840
887
dvNumber ++ ;
841
888
}
842
-
843
889
variableSizeVectorColumnNames = variableSizeVectorColumnNamesList . ToArray ( ) ;
844
- if ( variableSizeVectorColumnNamesList . Count == 0 && firstDvKeyColumns . Count == 0 )
845
- return AppendRowsDataView . Create ( env , null , list . ToArray ( ) ) ;
846
890
847
891
var views = list . ToArray ( ) ;
848
- foreach ( var keyCol in firstDvKeyColumns )
849
- ReconcileKeyValues ( env , views , keyCol ) ;
892
+ foreach ( var keyCol in firstDvKeyWithNamesColumns )
893
+ ReconcileKeyValues ( env , views , keyCol , TextType . Instance ) ;
894
+ if ( labelColKeyValuesType != null )
895
+ ReconcileKeyValues ( env , views , labelColName , labelColKeyValuesType . ItemType ) ;
896
+ foreach ( var keyCol in firstDvKeyNoNamesColumns )
897
+ ReconcileKeyValuesWithNoNames ( env , views , keyCol . Key , keyCol . Value ) ;
850
898
foreach ( var vectorKeyCol in firstDvVectorKeyColumns )
851
- ReconcileVectorKeyValues ( env , views , vectorKeyCol ) ;
899
+ ReconcileVectorKeyValues ( env , views , vectorKeyCol , TextType . Instance ) ;
852
900
853
901
Func < IDataView , int , IDataView > keyToValue =
854
902
( idv , i ) =>
855
903
{
856
- foreach ( var keyCol in firstDvKeyColumns . Concat ( firstDvVectorKeyColumns ) )
904
+ foreach ( var keyCol in firstDvVectorKeyColumns . Prepend ( labelColName ) )
857
905
{
906
+ if ( keyCol == labelColName && labelColKeyValuesType == null )
907
+ continue ;
858
908
idv = new KeyToValueTransform ( env , new KeyToValueTransform . Arguments ( ) { Column = new [ ] { new KeyToValueTransform . Column ( ) { Name = keyCol } , } } , idv ) ;
859
909
var hidden = FindHiddenColumns ( idv . Schema , keyCol ) ;
860
910
idv = new ChooseColumnsByIndexTransform ( env , new ChooseColumnsByIndexTransform . Arguments ( ) { Drop = true , Index = hidden . ToArray ( ) } , idv ) ;
861
911
}
912
+ foreach ( var keyCol in firstDvKeyNoNamesColumns )
913
+ {
914
+ var hidden = FindHiddenColumns ( idv . Schema , keyCol . Key ) ;
915
+ idv = new ChooseColumnsByIndexTransform ( env , new ChooseColumnsByIndexTransform . Arguments ( ) { Drop = true , Index = hidden . ToArray ( ) } , idv ) ;
916
+ }
862
917
return idv ;
863
918
} ;
864
919
0 commit comments