@@ -74,6 +74,117 @@ public void TensorFlowTransformMatrixMultiplicationTest()
74
74
}
75
75
}
76
76
77
+ private class ShapeData
78
+ {
79
+ // Data will be passed as 1-D vector.
80
+ // Intended data shape [5], model shape [None]
81
+ [ VectorType ( 5 ) ]
82
+ public float [ ] OneDim ;
83
+
84
+ // Data will be passed as flat vector.
85
+ // Intended data shape [2,2], model shape [2, None]
86
+ [ VectorType ( 4 ) ]
87
+ public float [ ] TwoDim ;
88
+
89
+ // Data will be passed as 3-D vector.
90
+ // Intended data shape [1, 2, 2], model shape [1, None, 2]
91
+ [ VectorType ( 1 , 2 , 2 ) ]
92
+ public float [ ] ThreeDim ;
93
+
94
+ // Data will be passed as flat vector.
95
+ // Intended data shape [1, 2, 2, 3], model shape [1, None, None, 3]
96
+ [ VectorType ( 12 ) ]
97
+ public float [ ] FourDim ;
98
+
99
+ // Data will be passed as 4-D vector.
100
+ // Intended data shape [2, 2, 2, 2], model shape [2, 2, 2, 2]
101
+ [ VectorType ( 2 , 2 , 2 , 2 ) ]
102
+ public float [ ] FourDimKnown ;
103
+ }
104
+
105
+ private List < ShapeData > GetShapeData ( )
106
+ {
107
+ return new List < ShapeData > ( new ShapeData [ ] {
108
+ new ShapeData ( ) { OneDim = new [ ] { 0.1f , 0.2f , 0.3f , 0.4f , 0.5f } ,
109
+ TwoDim = new [ ] { 1.0f , 2.0f , 3.0f , 4.0f } ,
110
+ ThreeDim = new [ ] { 11.0f , 12.0f , 13.0f , 14.0f } ,
111
+ FourDim = new [ ] { 21.0f , 22.0f , 23.0f , 24.0f , 25.0f , 26.0f ,
112
+ 27.0f , 28.0f , 29.0f , 30.0f , 31.0f , 32.0f } ,
113
+ FourDimKnown = new [ ] { 41.0f , 42.0f , 43.0f , 44.0f , 45.0f , 46.0f , 47.0f , 48.0f ,
114
+ 49.0f , 50.0f , 51.0f , 52.0f , 53.0f , 54.0f , 55.0f , 56.0f }
115
+ } ,
116
+ new ShapeData ( ) { OneDim = new [ ] { 100.1f , 100.2f , 100.3f , 100.4f , 100.5f } ,
117
+ TwoDim = new [ ] { 101.0f , 102.0f , 103.0f , 104.0f } ,
118
+ ThreeDim = new [ ] { 111.0f , 112.0f , 113.0f , 114.0f } ,
119
+ FourDim = new [ ] { 121.0f , 122.0f , 123.0f , 124.0f , 125.0f , 126.0f ,
120
+ 127.0f , 128.0f , 129.0f , 130.0f , 131.0f , 132.0f } ,
121
+ FourDimKnown = new [ ] { 141.0f , 142.0f , 143.0f , 144.0f , 145.0f , 146.0f , 147.0f , 148.0f ,
122
+ 149.0f , 150.0f , 151.0f , 152.0f , 153.0f , 154.0f , 155.0f , 156.0f }
123
+ }
124
+ } ) ;
125
+ }
126
+
127
+ [ ConditionalFact ( typeof ( Environment ) , nameof ( Environment . Is64BitProcess ) ) ] // TensorFlow is 64-bit only
128
+ public void TensorFlowTransformInputShapeTest ( )
129
+ {
130
+ var modelLocation = "model_shape_test" ;
131
+ var mlContext = new MLContext ( seed : 1 , conc : 1 ) ;
132
+ var data = GetShapeData ( ) ;
133
+ // Pipeline
134
+ var loader = mlContext . Data . ReadFromEnumerable ( data ) ;
135
+ var inputs = new string [ ] { "OneDim" , "TwoDim" , "ThreeDim" , "FourDim" , "FourDimKnown" } ;
136
+ var outputs = new string [ ] { "o_OneDim" , "o_TwoDim" , "o_ThreeDim" , "o_FourDim" , "o_FourDimKnown" } ;
137
+
138
+ var trans = mlContext . Transforms . ScoreTensorFlowModel ( modelLocation , outputs , inputs ) . Fit ( loader ) . Transform ( loader ) ;
139
+
140
+ using ( var cursor = trans . GetRowCursorForAllColumns ( ) )
141
+ {
142
+ int outColIndex = 5 ;
143
+ var oneDimgetter = cursor . GetGetter < VBuffer < float > > ( outColIndex ) ;
144
+ var twoDimgetter = cursor . GetGetter < VBuffer < float > > ( outColIndex + 1 ) ;
145
+ var threeDimgetter = cursor . GetGetter < VBuffer < float > > ( outColIndex + 2 ) ;
146
+ var fourDimgetter = cursor . GetGetter < VBuffer < float > > ( outColIndex + 3 ) ;
147
+ var fourDimKnowngetter = cursor . GetGetter < VBuffer < float > > ( outColIndex + 4 ) ;
148
+
149
+ VBuffer < float > oneDim = default ;
150
+ VBuffer < float > twoDim = default ;
151
+ VBuffer < float > threeDim = default ;
152
+ VBuffer < float > fourDim = default ;
153
+ VBuffer < float > fourDimKnown = default ;
154
+ foreach ( var sample in data )
155
+ {
156
+ Assert . True ( cursor . MoveNext ( ) ) ;
157
+
158
+ oneDimgetter ( ref oneDim ) ;
159
+ twoDimgetter ( ref twoDim ) ;
160
+ threeDimgetter ( ref threeDim ) ;
161
+ fourDimgetter ( ref fourDim ) ;
162
+ fourDimKnowngetter ( ref fourDimKnown ) ;
163
+
164
+ var oneDimValues = oneDim . GetValues ( ) ;
165
+ Assert . Equal ( sample . OneDim . Length , oneDimValues . Length ) ;
166
+ Assert . True ( oneDimValues . SequenceEqual ( sample . OneDim ) ) ;
167
+
168
+ var twoDimValues = twoDim . GetValues ( ) ;
169
+ Assert . Equal ( sample . TwoDim . Length , twoDimValues . Length ) ;
170
+ Assert . True ( twoDimValues . SequenceEqual ( sample . TwoDim ) ) ;
171
+
172
+ var threeDimValues = threeDim . GetValues ( ) ;
173
+ Assert . Equal ( sample . ThreeDim . Length , threeDimValues . Length ) ;
174
+ Assert . True ( threeDimValues . SequenceEqual ( sample . ThreeDim ) ) ;
175
+
176
+ var fourDimValues = fourDim . GetValues ( ) ;
177
+ Assert . Equal ( sample . FourDim . Length , fourDimValues . Length ) ;
178
+ Assert . True ( fourDimValues . SequenceEqual ( sample . FourDim ) ) ;
179
+
180
+ var fourDimKnownValues = fourDimKnown . GetValues ( ) ;
181
+ Assert . Equal ( sample . FourDimKnown . Length , fourDimKnownValues . Length ) ;
182
+ Assert . True ( fourDimKnownValues . SequenceEqual ( sample . FourDimKnown ) ) ;
183
+ }
184
+ Assert . False ( cursor . MoveNext ( ) ) ;
185
+ }
186
+ }
187
+
77
188
private class TypesData
78
189
{
79
190
[ VectorType ( 2 ) ]
@@ -142,7 +253,7 @@ public void TensorFlowTransformInputOutputTypesTest()
142
253
143
254
var loader = mlContext . Data . ReadFromEnumerable ( data ) ;
144
255
145
- var inputs = new string [ ] { "f64" , "f32" , "i64" , "i32" , "i16" , "i8" , "u64" , "u32" , "u16" , "u8" , "b" } ;
256
+ var inputs = new string [ ] { "f64" , "f32" , "i64" , "i32" , "i16" , "i8" , "u64" , "u32" , "u16" , "u8" , "b" } ;
146
257
var outputs = new string [ ] { "o_f64" , "o_f32" , "o_i64" , "o_i32" , "o_i16" , "o_i8" , "o_u64" , "o_u32" , "o_u16" , "o_u8" , "o_b" } ;
147
258
var trans = mlContext . Transforms . ScoreTensorFlowModel ( model_location , outputs , inputs ) . Fit ( loader ) . Transform ( loader ) ; ;
148
259
@@ -160,7 +271,7 @@ public void TensorFlowTransformInputOutputTypesTest()
160
271
var u8getter = cursor . GetGetter < VBuffer < byte > > ( 20 ) ;
161
272
var boolgetter = cursor . GetGetter < VBuffer < bool > > ( 21 ) ;
162
273
163
-
274
+
164
275
VBuffer < double > f64 = default ;
165
276
VBuffer < float > f32 = default ;
166
277
VBuffer < long > i64 = default ;
@@ -449,7 +560,7 @@ public void TensorFlowTransformMNISTLRTrainingTest()
449
560
ReTrain = true
450
561
} ) )
451
562
. Append ( mlContext . Transforms . Concatenate ( "Features" , "Prediction" ) )
452
- . Append ( mlContext . Transforms . Conversion . MapValueToKey ( "KeyLabel" , "Label" , maxNumKeys : 10 ) )
563
+ . Append ( mlContext . Transforms . Conversion . MapValueToKey ( "KeyLabel" , "Label" , maxNumKeys : 10 ) )
453
564
. Append ( mlContext . MulticlassClassification . Trainers . LightGbm ( "KeyLabel" , "Features" ) ) ;
454
565
455
566
var trainedModel = pipe . Fit ( trainData ) ;
0 commit comments