@@ -131,6 +131,14 @@ private class BreastCancerMulticlassExample
131
131
[ LoadColumn ( 2 , 9 ) , VectorType ( 8 ) ]
132
132
public float [ ] Features ;
133
133
}
134
+ private class BreastCancerBinaryClassification
135
+ {
136
+ [ LoadColumn ( 0 ) ]
137
+ public bool Label ;
138
+
139
+ [ LoadColumn ( 2 , 9 ) , VectorType ( 8 ) ]
140
+ public float [ ] Features ;
141
+ }
134
142
135
143
[ LessThanNetCore30OrNotNetCoreFact ( "netcoreapp3.0 output differs from Baseline. Tracked by https://github.com/dotnet/machinelearning/issues/2087" ) ]
136
144
public void KmeansOnnxConversionTest ( )
@@ -187,6 +195,55 @@ public void KmeansOnnxConversionTest()
187
195
Done ( ) ;
188
196
}
189
197
198
+ [ Fact ]
199
+ public void binaryClassificationTrainersOnnxConversionTest ( )
200
+ {
201
+ var mlContext = new MLContext ( seed : 1 ) ;
202
+ string dataPath = GetDataPath ( "breast-cancer.txt" ) ;
203
+ // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
204
+ var dataView = mlContext . Data . LoadFromTextFile < BreastCancerBinaryClassification > ( dataPath , separatorChar : '\t ' , hasHeader : true ) ;
205
+ IEstimator < ITransformer > [ ] estimators = {
206
+ mlContext . BinaryClassification . Trainers . SymbolicSgdLogisticRegression ( ) ,
207
+ mlContext . BinaryClassification . Trainers . SgdCalibrated ( ) ,
208
+ mlContext . BinaryClassification . Trainers . AveragedPerceptron ( ) ,
209
+ mlContext . BinaryClassification . Trainers . FastForest ( ) ,
210
+ mlContext . BinaryClassification . Trainers . LinearSvm ( ) ,
211
+ mlContext . BinaryClassification . Trainers . SdcaNonCalibrated ( ) ,
212
+ mlContext . BinaryClassification . Trainers . SgdNonCalibrated ( ) ,
213
+ mlContext . BinaryClassification . Trainers . FastTree ( ) ,
214
+ mlContext . BinaryClassification . Trainers . LbfgsLogisticRegression ( ) ,
215
+ mlContext . BinaryClassification . Trainers . LightGbm ( ) ,
216
+ mlContext . BinaryClassification . Trainers . SdcaLogisticRegression ( ) ,
217
+ mlContext . BinaryClassification . Trainers . SgdCalibrated ( ) ,
218
+ mlContext . BinaryClassification . Trainers . SymbolicSgdLogisticRegression ( ) ,
219
+ } ;
220
+ var initialPipeline = mlContext . Transforms . ReplaceMissingValues ( "Features" ) .
221
+ Append ( mlContext . Transforms . NormalizeMinMax ( "Features" ) ) ;
222
+ foreach ( var estimator in estimators )
223
+ {
224
+ var pipeline = initialPipeline . Append ( estimator ) ;
225
+ var model = pipeline . Fit ( dataView ) ;
226
+ var transformedData = model . Transform ( dataView ) ;
227
+ var onnxModel = mlContext . Model . ConvertToOnnxProtobuf ( model , dataView ) ;
228
+ // Compare model scores produced by ML.NET and ONNX's runtime.
229
+ if ( IsOnnxRuntimeSupported ( ) )
230
+ {
231
+ var onnxFileName = $ "{ estimator . ToString ( ) } .onnx";
232
+ var onnxModelPath = GetOutputPath ( onnxFileName ) ;
233
+ SaveOnnxModel ( onnxModel , onnxModelPath , null ) ;
234
+ // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
235
+ string [ ] inputNames = onnxModel . Graph . Input . Select ( valueInfoProto => valueInfoProto . Name ) . ToArray ( ) ;
236
+ string [ ] outputNames = onnxModel . Graph . Output . Select ( valueInfoProto => valueInfoProto . Name ) . ToArray ( ) ;
237
+ var onnxEstimator = mlContext . Transforms . ApplyOnnxModel ( outputNames , inputNames , onnxModelPath ) ;
238
+ var onnxTransformer = onnxEstimator . Fit ( dataView ) ;
239
+ var onnxResult = onnxTransformer . Transform ( dataView ) ;
240
+ CompareSelectedR4ScalarColumns ( transformedData . Schema [ 5 ] . Name , outputNames [ 3 ] , transformedData , onnxResult , 3 ) ;
241
+ CompareSelectedScalarColumns < Boolean > ( transformedData . Schema [ 4 ] . Name , outputNames [ 2 ] , transformedData , onnxResult ) ;
242
+ }
243
+
244
+ }
245
+ Done ( ) ;
246
+ }
190
247
private class DataPoint
191
248
{
192
249
[ VectorType ( 3 ) ]
@@ -853,7 +910,8 @@ private void CreateDummyExamplesToMakeComplierHappy()
853
910
var dummyExample = new BreastCancerFeatureVector ( ) { Features = null } ;
854
911
var dummyExample1 = new BreastCancerCatFeatureExample ( ) { Label = false , F1 = 0 , F2 = "Amy" } ;
855
912
var dummyExample2 = new BreastCancerMulticlassExample ( ) { Label = "Amy" , Features = null } ;
856
- var dummyExample3 = new SmallSentimentExample ( ) { Tokens = null } ;
913
+ var dummyExample3 = new BreastCancerBinaryClassification ( ) { Label = false , Features = null } ;
914
+ var dummyExample4 = new SmallSentimentExample ( ) { Tokens = null } ;
857
915
}
858
916
859
917
private void CompareResults ( string leftColumnName , string rightColumnName , IDataView left , IDataView right )
@@ -984,7 +1042,34 @@ private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightC
984
1042
985
1043
// Scalar such as R4 (float) is converted to [1, 1]-tensor in ONNX format for consitency of making batch prediction.
986
1044
Assert . Equal ( 1 , actual . Length ) ;
987
- Assert . Equal ( expected , actual . GetItemOrDefault ( 0 ) , precision ) ;
1045
+ CompareNumbersWithTolerance ( expected , actual . GetItemOrDefault ( 0 ) , null , precision ) ;
1046
+ }
1047
+ }
1048
+ }
1049
+ private void CompareSelectedScalarColumns < T > ( string leftColumnName , string rightColumnName , IDataView left , IDataView right )
1050
+ {
1051
+ var leftColumn = left . Schema [ leftColumnName ] ;
1052
+ var rightColumn = right . Schema [ rightColumnName ] ;
1053
+
1054
+ using ( var expectedCursor = left . GetRowCursor ( leftColumn ) )
1055
+ using ( var actualCursor = right . GetRowCursor ( rightColumn ) )
1056
+ {
1057
+ T expected = default ;
1058
+ VBuffer < T > actual = default ;
1059
+ var expectedGetter = expectedCursor . GetGetter < T > ( leftColumn ) ;
1060
+ var actualGetter = actualCursor . GetGetter < VBuffer < T > > ( rightColumn ) ;
1061
+ while ( expectedCursor . MoveNext ( ) && actualCursor . MoveNext ( ) )
1062
+ {
1063
+ expectedGetter ( ref expected ) ;
1064
+ actualGetter ( ref actual ) ;
1065
+ var actualVal = actual . GetItemOrDefault ( 0 ) ;
1066
+
1067
+ Assert . Equal ( 1 , actual . Length ) ;
1068
+
1069
+ if ( typeof ( T ) == typeof ( ReadOnlyMemory < Char > ) )
1070
+ Assert . Equal ( expected . ToString ( ) , actualVal . ToString ( ) ) ;
1071
+ else
1072
+ Assert . Equal ( expected , actualVal ) ;
988
1073
}
989
1074
}
990
1075
}
0 commit comments