@@ -77,6 +77,52 @@ public void IrisLightGbm()
77
77
} ) . PredictedLabel ) ;
78
78
}
79
79
80
+ [ LightGBMFact ]
81
+ public void IrisLightGbmWithLoadColumnName ( )
82
+ {
83
+ if ( ! RuntimeInformation . IsOSPlatform ( OSPlatform . Windows ) )
84
+ {
85
+ // https://github.com/dotnet/machinelearning/issues/4156
86
+ return ;
87
+ }
88
+
89
+ var mlContext = new MLContext ( seed : 1 ) ;
90
+
91
+ var connectionString = GetConnectionString ( TestDatasets . irisDb . name ) ;
92
+ var commandText = $@ "SELECT * FROM ""{ TestDatasets . irisDb . trainFilename } """;
93
+
94
+ var loader = mlContext.Data.CreateDatabaseLoader<IrisDataWithLoadColumnName>();
95
+
96
+ var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText);
97
+
98
+ var trainingData = loader.Load(databaseSource);
99
+
100
+ IEstimator<ITransformer> pipeline = mlContext.Transforms.Conversion.MapValueToKey(" Label")
101
+ . Append ( mlContext . Transforms . Concatenate ( "Features" , "SepalLength" , "SepalWidth" , "PetalLength" , "PetalWidth" ) )
102
+ . Append ( mlContext . MulticlassClassification . Trainers . LightGbm ( ) )
103
+ . Append ( mlContext . Transforms . Conversion . MapKeyToValue ( "PredictedLabel" ) ) ;
104
+
105
+ var model = pipeline . Fit ( trainingData ) ;
106
+
107
+ var engine = mlContext . Model . CreatePredictionEngine < IrisData , IrisPrediction > ( model ) ;
108
+
109
+ Assert . Equal ( 0 , engine . Predict ( new IrisData ( )
110
+ {
111
+ SepalLength = 4.5f ,
112
+ SepalWidth = 5.6f ,
113
+ PetalLength = 0.5f ,
114
+ PetalWidth = 0.5f ,
115
+ } ) . PredictedLabel ) ;
116
+
117
+ Assert . Equal ( 1 , engine . Predict ( new IrisData ( )
118
+ {
119
+ SepalLength = 4.9f ,
120
+ SepalWidth = 2.4f ,
121
+ PetalLength = 3.3f ,
122
+ PetalWidth = 1.0f ,
123
+ } ) . PredictedLabel ) ;
124
+ }
125
+
80
126
[ LightGBMFact ]
81
127
public void IrisVectorLightGbm ( )
82
128
{
@@ -119,6 +165,48 @@ public void IrisVectorLightGbm()
119
165
} ) . PredictedLabel ) ;
120
166
}
121
167
168
+ [ LightGBMFact ]
169
+ public void IrisVectorLightGbmWithLoadColumnName ( )
170
+ {
171
+ if ( ! RuntimeInformation . IsOSPlatform ( OSPlatform . Windows ) )
172
+ {
173
+ // https://github.com/dotnet/machinelearning/issues/4156
174
+ return ;
175
+ }
176
+
177
+ var mlContext = new MLContext ( seed : 1 ) ;
178
+
179
+ var connectionString = GetConnectionString ( TestDatasets . irisDb . name ) ;
180
+ var commandText = $@ "SELECT * FROM ""{ TestDatasets . irisDb . trainFilename } """;
181
+
182
+ var loader = mlContext.Data.CreateDatabaseLoader<IrisVectorDataWithLoadColumnName>();
183
+
184
+ var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText);
185
+
186
+ var trainingData = loader.Load(databaseSource);
187
+
188
+ IEstimator<ITransformer> pipeline = mlContext.Transforms.Conversion.MapValueToKey(" Label")
189
+ . Append ( mlContext . Transforms . Concatenate ( "Features" , "SepalInfo" , "PetalInfo" ) )
190
+ . Append ( mlContext . MulticlassClassification . Trainers . LightGbm ( ) )
191
+ . Append ( mlContext . Transforms . Conversion . MapKeyToValue ( "PredictedLabel" ) ) ;
192
+
193
+ var model = pipeline . Fit ( trainingData ) ;
194
+
195
+ var engine = mlContext . Model . CreatePredictionEngine < IrisVectorData , IrisPrediction > ( model ) ;
196
+
197
+ Assert . Equal ( 0 , engine . Predict ( new IrisVectorData ( )
198
+ {
199
+ SepalInfo = new float [ ] { 4.5f , 5.6f } ,
200
+ PetalInfo = new float [ ] { 0.5f , 0.5f } ,
201
+ } ) . PredictedLabel ) ;
202
+
203
+ Assert . Equal ( 1 , engine . Predict ( new IrisVectorData ( )
204
+ {
205
+ SepalInfo = new float [ ] { 4.9f , 2.4f } ,
206
+ PetalInfo = new float [ ] { 3.3f , 1.0f } ,
207
+ } ) . PredictedLabel ) ;
208
+ }
209
+
122
210
[ Fact ]
123
211
public void IrisSdcaMaximumEntropy ( )
124
212
{
@@ -189,6 +277,21 @@ public class IrisData
189
277
public float PetalWidth ;
190
278
}
191
279
280
+ public class IrisDataWithLoadColumnName
281
+ {
282
+ [ LoadColumnName ( "Label" ) ]
283
+ [ ColumnName ( "Label" ) ]
284
+ public int Kind ;
285
+
286
+ public float SepalLength ;
287
+
288
+ public float SepalWidth ;
289
+
290
+ public float PetalLength ;
291
+
292
+ public float PetalWidth ;
293
+ }
294
+
192
295
public class IrisVectorData
193
296
{
194
297
public int Label ;
@@ -202,6 +305,19 @@ public class IrisVectorData
202
305
public float [ ] PetalInfo ;
203
306
}
204
307
308
+ public class IrisVectorDataWithLoadColumnName
309
+ {
310
+ public int Label ;
311
+
312
+ [ LoadColumnName ( "SepalLength" , "SepalWidth" ) ]
313
+ [ VectorType ( 2 ) ]
314
+ public float [ ] SepalInfo ;
315
+
316
+ [ LoadColumnName ( "PetalLength" , "PetalWidth" ) ]
317
+ [ VectorType ( 2 ) ]
318
+ public float [ ] PetalInfo ;
319
+ }
320
+
205
321
public class IrisPrediction
206
322
{
207
323
public int PredictedLabel ;
0 commit comments