26
26
[ assembly: LoadableClass ( ConcatTransform . Summary , typeof ( IDataTransform ) , typeof ( ConcatTransform ) , null , typeof ( SignatureLoadDataTransform ) ,
27
27
ConcatTransform . UserName , ConcatTransform . LoaderSignature , ConcatTransform . LoaderSignatureOld ) ]
28
28
29
+ [ assembly: LoadableClass ( typeof ( ConcatTransform ) , null , typeof ( SignatureLoadModel ) ,
30
+ ConcatTransform . UserName , ConcatTransform . LoaderSignature ) ]
31
+
32
+ [ assembly: LoadableClass ( typeof ( IRowMapper ) , typeof ( ConcatTransform ) , null , typeof ( SignatureLoadRowMapper ) ,
33
+ ConcatTransform . UserName , ConcatTransform . LoaderSignature ) ]
34
+
29
35
namespace Microsoft . ML . Runtime . Data
30
36
{
31
37
using PfaType = PfaUtils . Type ;
@@ -250,6 +256,9 @@ public void Save(ModelSaveContext ctx)
250
256
col . Save ( ctx ) ;
251
257
}
252
258
259
+ /// <summary>
260
+ /// Constructor for SignatureLoadModel.
261
+ /// </summary>
253
262
public ConcatTransform ( IHostEnvironment env , ModelLoadContext ctx )
254
263
{
255
264
Contracts . CheckValue ( env , nameof ( env ) ) ;
@@ -770,11 +779,17 @@ private IDataTransform MakeDataTransform(IDataView input)
770
779
public IRowMapper MakeRowMapper ( ISchema inputSchema ) => new Mapper ( this , inputSchema ) ;
771
780
772
781
/// <summary>
773
- /// Factory method for SignatureLoadDataTransform
782
+ /// Factory method for SignatureLoadDataTransform.
774
783
/// </summary>
775
784
public static IDataTransform Create ( IHostEnvironment env , ModelLoadContext ctx , IDataView input )
776
785
=> new ConcatTransform ( env , ctx ) . MakeDataTransform ( input ) ;
777
786
787
+ /// <summary>
788
+ /// Factory method for SignatureLoadRowMapper.
789
+ /// </summary>
790
+ public static IRowMapper Create ( IHostEnvironment env , ModelLoadContext ctx , ISchema inputSchema )
791
+ => new ConcatTransform ( env , ctx ) . MakeRowMapper ( inputSchema ) ;
792
+
778
793
public ISchema GetOutputSchema ( ISchema inputSchema )
779
794
{
780
795
_host . CheckValue ( inputSchema , nameof ( inputSchema ) ) ;
@@ -933,16 +948,16 @@ public RowMapperColumnInfo MakeColumnInfo()
933
948
934
949
var metadata = new ColumnMetadataInfo ( _columnInfo . Output ) ;
935
950
if ( _isNormalized )
936
- metadata . Add ( MetadataUtils . Kinds . IsNormalized , new MetadataInfo < bool > ( BoolType . Instance , GetIsNormalized ) ) ;
951
+ metadata . Add ( MetadataUtils . Kinds . IsNormalized , new MetadataInfo < DvBool > ( BoolType . Instance , GetIsNormalized ) ) ;
937
952
if ( _hasSlotNames )
938
- metadata . Add ( MetadataUtils . Kinds . SlotNames , new MetadataInfo < VBuffer < DvText > > ( TextType . Instance , GetSlotNames ) ) ;
953
+ metadata . Add ( MetadataUtils . Kinds . SlotNames , new MetadataInfo < VBuffer < DvText > > ( _slotNamesType , GetSlotNames ) ) ;
939
954
if ( _hasCategoricals )
940
- metadata . Add ( MetadataUtils . Kinds . CategoricalSlotRanges , new MetadataInfo < VBuffer < DvInt4 > > ( TextType . Instance , GetCategoricalSlotRanges ) ) ;
955
+ metadata . Add ( MetadataUtils . Kinds . CategoricalSlotRanges , new MetadataInfo < VBuffer < DvInt4 > > ( _categoricalRangeType , GetCategoricalSlotRanges ) ) ;
941
956
942
957
return new RowMapperColumnInfo ( _columnInfo . Output , OutputType , metadata ) ;
943
958
}
944
959
945
- private void GetIsNormalized ( int col , ref bool value ) => value = _isNormalized ;
960
+ private void GetIsNormalized ( int col , ref DvBool value ) => value = _isNormalized ;
946
961
947
962
private void GetCategoricalSlotRanges ( int iiinfo , ref VBuffer < DvInt4 > dst )
948
963
{
@@ -1025,17 +1040,18 @@ private void GetSlotNames(int iinfo, ref VBuffer<DvText> dst)
1025
1040
public Delegate MakeGetter ( IRow input )
1026
1041
{
1027
1042
if ( _isIdentity )
1028
- {
1029
- Contracts . Assert ( SrcIndices . Length == 1 ) ;
1030
- Func < Delegate > getSrcGetter = ( ) => input . GetGetter < int > ( SrcIndices [ 0 ] ) ;
1031
- return Utils . MarshalInvoke ( getSrcGetter , _srcTypes [ 0 ] . RawType ) ;
1032
- }
1043
+ return Utils . MarshalInvoke ( MakeIdentityGetter < int > , OutputType . RawType , input ) ;
1033
1044
1034
- Func < IRow , ValueGetter < VBuffer < int > > > del = MakeGetter < int > ;
1035
- return Utils . MarshalInvoke ( MakeGetter < int > , _srcTypes [ 0 ] . RawType , input ) ;
1045
+ return Utils . MarshalInvoke ( MakeGetter < int > , OutputType . ItemType . RawType , input ) ;
1046
+ }
1047
+
1048
+ private Delegate MakeIdentityGetter < T > ( IRow input )
1049
+ {
1050
+ Contracts . Assert ( SrcIndices . Length == 1 ) ;
1051
+ return input . GetGetter < T > ( SrcIndices [ 0 ] ) ;
1036
1052
}
1037
1053
1038
- private ValueGetter < VBuffer < T > > MakeGetter < T > ( IRow input )
1054
+ private Delegate MakeGetter < T > ( IRow input )
1039
1055
{
1040
1056
var srcGetterOnes = new ValueGetter < T > [ SrcIndices . Length ] ;
1041
1057
var srcGetterVecs = new ValueGetter < VBuffer < T > > [ SrcIndices . Length ] ;
@@ -1049,109 +1065,109 @@ private ValueGetter<VBuffer<T>> MakeGetter<T>(IRow input)
1049
1065
1050
1066
T tmp = default ( T ) ;
1051
1067
VBuffer < T > [ ] tmpBufs = new VBuffer < T > [ SrcIndices . Length ] ;
1052
- return
1053
- ( ref VBuffer < T > dst ) =>
1068
+ ValueGetter < VBuffer < T > > result = ( ref VBuffer < T > dst ) =>
1069
+ {
1070
+ int dstLength = 0 ;
1071
+ int dstCount = 0 ;
1072
+ for ( int i = 0 ; i < SrcIndices . Length ; i ++ )
1054
1073
{
1055
- int dstLength = 0 ;
1056
- int dstCount = 0 ;
1057
- for ( int i = 0 ; i < SrcIndices . Length ; i ++ )
1074
+ var type = _srcTypes [ i ] ;
1075
+ if ( type . IsVector )
1058
1076
{
1059
- var type = _srcTypes [ i ] ;
1060
- if ( type . IsVector )
1061
- {
1062
- srcGetterVecs [ i ] ( ref tmpBufs [ i ] ) ;
1063
- if ( type . VectorSize != 0 && type . VectorSize != tmpBufs [ i ] . Length )
1064
- {
1065
- throw Contracts . Except ( "Column '{0}': expected {1} slots, but got {2}" ,
1066
- input . Schema . GetColumnName ( SrcIndices [ i ] ) , type . VectorSize , tmpBufs [ i ] . Length )
1067
- . MarkSensitive ( MessageSensitivity . Schema ) ;
1068
- }
1069
- dstLength = checked ( dstLength + tmpBufs [ i ] . Length ) ;
1070
- dstCount = checked ( dstCount + tmpBufs [ i ] . Count ) ;
1071
- }
1072
- else
1077
+ srcGetterVecs [ i ] ( ref tmpBufs [ i ] ) ;
1078
+ if ( type . VectorSize != 0 && type . VectorSize != tmpBufs [ i ] . Length )
1073
1079
{
1074
- dstLength = checked ( dstLength + 1 ) ;
1075
- dstCount = checked ( dstCount + 1 ) ;
1080
+ throw Contracts . Except ( "Column '{0}': expected {1} slots, but got {2}" ,
1081
+ input . Schema . GetColumnName ( SrcIndices [ i ] ) , type . VectorSize , tmpBufs [ i ] . Length )
1082
+ . MarkSensitive ( MessageSensitivity . Schema ) ;
1076
1083
}
1084
+ dstLength = checked ( dstLength + tmpBufs [ i ] . Length ) ;
1085
+ dstCount = checked ( dstCount + tmpBufs [ i ] . Count ) ;
1077
1086
}
1087
+ else
1088
+ {
1089
+ dstLength = checked ( dstLength + 1 ) ;
1090
+ dstCount = checked ( dstCount + 1 ) ;
1091
+ }
1092
+ }
1078
1093
1079
- var values = dst . Values ;
1080
- var indices = dst . Indices ;
1081
- if ( dstCount <= dstLength / 2 )
1094
+ var values = dst . Values ;
1095
+ var indices = dst . Indices ;
1096
+ if ( dstCount <= dstLength / 2 )
1097
+ {
1098
+ // Concatenate into a sparse representation.
1099
+ if ( Utils . Size ( values ) < dstCount )
1100
+ values = new T [ dstCount ] ;
1101
+ if ( Utils . Size ( indices ) < dstCount )
1102
+ indices = new int [ dstCount ] ;
1103
+
1104
+ int offset = 0 ;
1105
+ int count = 0 ;
1106
+ for ( int j = 0 ; j < SrcIndices . Length ; j ++ )
1082
1107
{
1083
- // Concatenate into a sparse representation.
1084
- if ( Utils . Size ( values ) < dstCount )
1085
- values = new T [ dstCount ] ;
1086
- if ( Utils . Size ( indices ) < dstCount )
1087
- indices = new int [ dstCount ] ;
1088
-
1089
- int offset = 0 ;
1090
- int count = 0 ;
1091
- for ( int j = 0 ; j < SrcIndices . Length ; j ++ )
1108
+ Contracts . Assert ( offset < dstLength ) ;
1109
+ if ( _srcTypes [ j ] . IsVector )
1092
1110
{
1093
- Contracts . Assert ( offset < dstLength ) ;
1094
- if ( _srcTypes [ j ] . IsVector )
1111
+ var buffer = tmpBufs [ j ] ;
1112
+ Contracts . Assert ( buffer . Count <= dstCount - count ) ;
1113
+ Contracts . Assert ( buffer . Length <= dstLength - offset ) ;
1114
+ if ( buffer . IsDense )
1095
1115
{
1096
- var buffer = tmpBufs [ j ] ;
1097
- Contracts . Assert ( buffer . Count <= dstCount - count ) ;
1098
- Contracts . Assert ( buffer . Length <= dstLength - offset ) ;
1099
- if ( buffer . IsDense )
1116
+ for ( int i = 0 ; i < buffer . Length ; i ++ )
1100
1117
{
1101
- for ( int i = 0 ; i < buffer . Length ; i ++ )
1102
- {
1103
- values [ count ] = buffer . Values [ i ] ;
1104
- indices [ count ++ ] = offset + i ;
1105
- }
1118
+ values [ count ] = buffer . Values [ i ] ;
1119
+ indices [ count ++ ] = offset + i ;
1106
1120
}
1107
- else
1108
- {
1109
- for ( int i = 0 ; i < buffer . Count ; i ++ )
1110
- {
1111
- values [ count ] = buffer . Values [ i ] ;
1112
- indices [ count ++ ] = offset + buffer . Indices [ i ] ;
1113
- }
1114
- }
1115
- offset += buffer . Length ;
1116
1121
}
1117
1122
else
1118
1123
{
1119
- Contracts . Assert ( count < dstCount ) ;
1120
- srcGetterOnes [ j ] ( ref tmp ) ;
1121
- values [ count ] = tmp ;
1122
- indices [ count ++ ] = offset ;
1123
- offset ++ ;
1124
+ for ( int i = 0 ; i < buffer . Count ; i ++ )
1125
+ {
1126
+ values [ count ] = buffer . Values [ i ] ;
1127
+ indices [ count ++ ] = offset + buffer . Indices [ i ] ;
1128
+ }
1124
1129
}
1130
+ offset += buffer . Length ;
1131
+ }
1132
+ else
1133
+ {
1134
+ Contracts . Assert ( count < dstCount ) ;
1135
+ srcGetterOnes [ j ] ( ref tmp ) ;
1136
+ values [ count ] = tmp ;
1137
+ indices [ count ++ ] = offset ;
1138
+ offset ++ ;
1125
1139
}
1126
- Contracts . Assert ( count <= dstCount ) ;
1127
- Contracts . Assert ( offset == dstLength ) ;
1128
- dst = new VBuffer < T > ( dstLength , count , values , indices ) ;
1129
1140
}
1130
- else
1131
- {
1132
- // Concatenate into a dense representation.
1133
- if ( Utils . Size ( values ) < dstLength )
1134
- values = new T [ dstLength ] ;
1141
+ Contracts . Assert ( count <= dstCount ) ;
1142
+ Contracts . Assert ( offset == dstLength ) ;
1143
+ dst = new VBuffer < T > ( dstLength , count , values , indices ) ;
1144
+ }
1145
+ else
1146
+ {
1147
+ // Concatenate into a dense representation.
1148
+ if ( Utils . Size ( values ) < dstLength )
1149
+ values = new T [ dstLength ] ;
1135
1150
1136
- int offset = 0 ;
1137
- for ( int j = 0 ; j < SrcIndices . Length ; j ++ )
1151
+ int offset = 0 ;
1152
+ for ( int j = 0 ; j < SrcIndices . Length ; j ++ )
1153
+ {
1154
+ Contracts . Assert ( tmpBufs [ j ] . Length <= dstLength - offset ) ;
1155
+ if ( _srcTypes [ j ] . IsVector )
1138
1156
{
1139
- Contracts . Assert ( tmpBufs [ j ] . Length <= dstLength - offset ) ;
1140
- if ( _srcTypes [ j ] . IsVector )
1141
- {
1142
- tmpBufs [ j ] . CopyTo ( values , offset ) ;
1143
- offset += tmpBufs [ j ] . Length ;
1144
- }
1145
- else
1146
- {
1147
- srcGetterOnes [ j ] ( ref tmp ) ;
1148
- values [ offset ++ ] = tmp ;
1149
- }
1157
+ tmpBufs [ j ] . CopyTo ( values , offset ) ;
1158
+ offset += tmpBufs [ j ] . Length ;
1159
+ }
1160
+ else
1161
+ {
1162
+ srcGetterOnes [ j ] ( ref tmp ) ;
1163
+ values [ offset ++ ] = tmp ;
1150
1164
}
1151
- Contracts . Assert ( offset == dstLength ) ;
1152
- dst = new VBuffer < T > ( dstLength , values , indices ) ;
1153
1165
}
1154
- } ;
1166
+ Contracts . Assert ( offset == dstLength ) ;
1167
+ dst = new VBuffer < T > ( dstLength , values , indices ) ;
1168
+ }
1169
+ } ;
1170
+ return result ;
1155
1171
}
1156
1172
1157
1173
public KeyValuePair < string , JToken > SavePfaInfo ( BoundPfaContext ctx )
0 commit comments