3
3
// See the LICENSE file in the project root for more information.
4
4
5
5
using System ;
6
+ using System . Collections . Generic ;
6
7
using System . Linq ;
7
8
using Microsoft . ML ;
8
9
using Microsoft . ML . Data ;
@@ -1009,5 +1010,154 @@ public void MatrixFactorization()
1009
1010
// Naive test. Just make sure the pipeline runs.
1010
1011
Assert . InRange ( metrics . L2 , 0 , 0.5 ) ;
1011
1012
}
1013
+
1014
+ /// <summary>
1015
+ /// Number of features per example used in <see cref="MultiClassLightGbmStaticPipelineWithInMemoryData"/>.
1016
+ /// </summary>
1017
+ private const int _featureVectorLength = 10 ;
1018
+
1019
+ /// <summary>
1020
+ /// Data point used in <see cref="MultiClassLightGbmStaticPipelineWithInMemoryData"/>. A data set there
1021
+ /// is a collection of <see cref="NativeExample"/>.
1022
+ /// </summary>
1023
+ private class NativeExample
1024
+ {
1025
+ [ VectorType ( _featureVectorLength ) ]
1026
+ public float [ ] Features ;
1027
+ [ ColumnName ( "Label" ) ]
1028
+ // One of "AA", "BB", "CC", and "DD".
1029
+ public string Label ;
1030
+ public uint LabelIndex ;
1031
+ // One of "AA", "BB", "CC", and "DD".
1032
+ public string PredictedLabel ;
1033
+ [ VectorType ( 4 ) ]
1034
+ // The probabilities of being "AA", "BB", "CC", and "DD".
1035
+ public float [ ] Scores ;
1036
+
1037
+ public NativeExample ( )
1038
+ {
1039
+ Features = new float [ _featureVectorLength ] ;
1040
+ }
1041
+ }
1042
+
1043
+ /// <summary>
1044
+ /// Helper function used to generate <see cref="NativeExample"/>s.
1045
+ /// </summary>
1046
+ private static List < NativeExample > GenerateRandomExamples ( int count )
1047
+ {
1048
+ var examples = new List < NativeExample > ( ) ;
1049
+ var rnd = new Random ( 0 ) ;
1050
+ for ( int i = 0 ; i < count ; ++ i )
1051
+ {
1052
+ var example = new NativeExample ( ) ;
1053
+ var res = i % 4 ;
1054
+ // Generate random float feature values.
1055
+ for ( int j = 0 ; j < _featureVectorLength ; ++ j )
1056
+ {
1057
+ var value = ( float ) rnd . NextDouble ( ) + res * 0.2f ;
1058
+ example . Features [ j ] = value ;
1059
+ }
1060
+
1061
+ // Generate label based on feature sum.
1062
+ if ( res == 0 )
1063
+ example . Label = "AA" ;
1064
+ else if ( res == 1 )
1065
+ example . Label = "BB" ;
1066
+ else if ( res == 2 )
1067
+ example . Label = "CC" ;
1068
+ else
1069
+ example . Label = "DD" ;
1070
+
1071
+ // The following three attributes are just placeholder for storing prediction results.
1072
+ example . LabelIndex = default ;
1073
+ example . PredictedLabel = null ;
1074
+ example . Scores = new float [ 4 ] ;
1075
+
1076
+ examples . Add ( example ) ;
1077
+ }
1078
+ return examples ;
1079
+ }
1080
+
1081
+ [ Fact ( ) ]
1082
+ public void MultiClassLightGbmStaticPipelineWithInMemoryData ( )
1083
+ {
1084
+ // Create a general context for ML.NET operations. It can be used for exception tracking and logging,
1085
+ // as a catalog of available operations and as the source of randomness.
1086
+ var mlContext = new MLContext ( seed : 1 , conc : 1 ) ;
1087
+
1088
+ // Context for calling static classifiers. It contains constructors of classifiers and evaluation utilities.
1089
+ var ctx = new MulticlassClassificationContext ( mlContext ) ;
1090
+
1091
+ // Create in-memory examples as C# native class.
1092
+ var examples = GenerateRandomExamples ( 1000 ) ;
1093
+
1094
+ // Convert native C# class to IDataView, a consumble format to ML.NET functions.
1095
+ var dataView = ComponentCreation . CreateDataView ( mlContext , examples ) ;
1096
+
1097
+ // IDataView is the data format used in dynamic-typed pipeline. To use static-typed pipeline, we need to convert
1098
+ // IDataView to DataView by calling AssertStatic(...). The basic idea is to specify the static type for each column
1099
+ // in IDataView in a lambda function.
1100
+ var staticDataView = dataView . AssertStatic ( mlContext , c => (
1101
+ Features : c . R4 . Vector ,
1102
+ Label : c . Text . Scalar ) ) ;
1103
+
1104
+ // Create static pipeline. First, we make an estimator out of static DataView as the starting of a pipeline.
1105
+ // Then, we append necessary transforms and a classifier to the starting estimator.
1106
+ var pipe = staticDataView . MakeNewEstimator ( )
1107
+ . Append ( mapper : r => (
1108
+ r . Label ,
1109
+ // Train multi-class LightGBM. The trained model maps Features to Label and probability of each class.
1110
+ // The call of ToKey() is needed to convert string labels to integer indexes.
1111
+ Predictions : ctx . Trainers . LightGbm ( r . Label . ToKey ( ) , r . Features )
1112
+ ) )
1113
+ . Append ( r => (
1114
+ // Actual label.
1115
+ r . Label ,
1116
+ // Labels are converted to keys when training LightGBM so we convert it here again for calling evaluation function.
1117
+ LabelIndex : r . Label . ToKey ( ) ,
1118
+ // Instance of ClassificationVectorData returned
1119
+ r . Predictions ,
1120
+ // ToValue() is used to get the original out from the class indexes computed by ToKey().
1121
+ // For example, if label "AA" is maped to index 0 via ToKey(), then ToValue() produces "AA" from 0.
1122
+ PredictedLabel : r . Predictions . predictedLabel . ToValue ( ) ,
1123
+ // Assign a new name to class probabilities.
1124
+ Scores : r . Predictions . score
1125
+ ) ) ;
1126
+
1127
+ // Split the static-typed data into training and test sets. Only training set is used in fitting
1128
+ // the created pipeline. Metrics are computed on the test.
1129
+ var ( trainingData , testingData ) = ctx . TrainTestSplit ( staticDataView , testFraction : 0.5 ) ;
1130
+
1131
+ // Train the model.
1132
+ var model = pipe . Fit ( trainingData ) ;
1133
+
1134
+ // Do prediction on the test set.
1135
+ var prediction = model . Transform ( testingData ) ;
1136
+
1137
+ // Evaluate the trained model is the test set.
1138
+ var metrics = ctx . Evaluate ( prediction , r => r . LabelIndex , r => r . Predictions ) ;
1139
+
1140
+ // Check if metrics are resonable.
1141
+ Assert . Equal ( 0.863482146891263 , metrics . AccuracyMacro , 6 ) ;
1142
+ Assert . Equal ( 0.86309523809523814 , metrics . AccuracyMicro , 6 ) ;
1143
+
1144
+ // Convert prediction in ML.NET format to native C# class.
1145
+ var nativePredictions = new List < NativeExample > ( prediction . AsDynamic . AsEnumerable < NativeExample > ( mlContext , false ) ) ;
1146
+
1147
+ // Check predicted label and class probabilities of second-first example.
1148
+ // If you see a label with LabelIndex 1, its means its probability is the 1st element in the Scores field.
1149
+ // For example, if "AA" is indexed by 1, "BB" indexed by 2, "CC" indexed by 3, and "DD" indexed by 4, Scores is
1150
+ // ["AA" probability, "BB" probability, "CC" probability, "DD" probability].
1151
+ var nativePrediction = nativePredictions [ 2 ] ;
1152
+ var probAA = nativePrediction . Scores [ 0 ] ;
1153
+ var probBB = nativePrediction . Scores [ 1 ] ;
1154
+ var probCC = nativePrediction . Scores [ 2 ] ;
1155
+ var probDD = nativePrediction . Scores [ 3 ] ;
1156
+
1157
+ Assert . Equal ( 0.922597349 , probAA , 6 ) ;
1158
+ Assert . Equal ( 0.07508608 , probBB , 6 ) ;
1159
+ Assert . Equal ( 0.00221699756 , probCC , 6 ) ;
1160
+ Assert . Equal ( 9.95488E-05 , probDD , 6 ) ;
1161
+ }
1012
1162
}
1013
1163
}
0 commit comments