Skip to content

Commit 90316ec

Browse files
committed
Add tests covering LoadColumnNameAttribute
1 parent 344f600 commit 90316ec

File tree

1 file changed

+116
-0
lines changed

1 file changed

+116
-0
lines changed

test/Microsoft.ML.Tests/DatabaseLoaderTests.cs

+116
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,52 @@ public void IrisLightGbm()
7777
}).PredictedLabel);
7878
}
7979

80+
[LightGBMFact]
81+
public void IrisLightGbmWithLoadColumnName()
82+
{
83+
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
84+
{
85+
// https://github.com/dotnet/machinelearning/issues/4156
86+
return;
87+
}
88+
89+
var mlContext = new MLContext(seed: 1);
90+
91+
var connectionString = GetConnectionString(TestDatasets.irisDb.name);
92+
var commandText = $@"SELECT * FROM ""{TestDatasets.irisDb.trainFilename}""";
93+
94+
var loader = mlContext.Data.CreateDatabaseLoader<IrisDataWithLoadColumnName>();
95+
96+
var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText);
97+
98+
var trainingData = loader.Load(databaseSource);
99+
100+
IEstimator<ITransformer> pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
101+
.Append(mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"))
102+
.Append(mlContext.MulticlassClassification.Trainers.LightGbm())
103+
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
104+
105+
var model = pipeline.Fit(trainingData);
106+
107+
var engine = mlContext.Model.CreatePredictionEngine<IrisData, IrisPrediction>(model);
108+
109+
Assert.Equal(0, engine.Predict(new IrisData()
110+
{
111+
SepalLength = 4.5f,
112+
SepalWidth = 5.6f,
113+
PetalLength = 0.5f,
114+
PetalWidth = 0.5f,
115+
}).PredictedLabel);
116+
117+
Assert.Equal(1, engine.Predict(new IrisData()
118+
{
119+
SepalLength = 4.9f,
120+
SepalWidth = 2.4f,
121+
PetalLength = 3.3f,
122+
PetalWidth = 1.0f,
123+
}).PredictedLabel);
124+
}
125+
80126
[LightGBMFact]
81127
public void IrisVectorLightGbm()
82128
{
@@ -119,6 +165,48 @@ public void IrisVectorLightGbm()
119165
}).PredictedLabel);
120166
}
121167

168+
[LightGBMFact]
169+
public void IrisVectorLightGbmWithLoadColumnName()
170+
{
171+
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
172+
{
173+
// https://github.com/dotnet/machinelearning/issues/4156
174+
return;
175+
}
176+
177+
var mlContext = new MLContext(seed: 1);
178+
179+
var connectionString = GetConnectionString(TestDatasets.irisDb.name);
180+
var commandText = $@"SELECT * FROM ""{TestDatasets.irisDb.trainFilename}""";
181+
182+
var loader = mlContext.Data.CreateDatabaseLoader<IrisVectorDataWithLoadColumnName>();
183+
184+
var databaseSource = new DatabaseSource(SqlClientFactory.Instance, connectionString, commandText);
185+
186+
var trainingData = loader.Load(databaseSource);
187+
188+
IEstimator<ITransformer> pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
189+
.Append(mlContext.Transforms.Concatenate("Features", "SepalInfo", "PetalInfo"))
190+
.Append(mlContext.MulticlassClassification.Trainers.LightGbm())
191+
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
192+
193+
var model = pipeline.Fit(trainingData);
194+
195+
var engine = mlContext.Model.CreatePredictionEngine<IrisVectorData, IrisPrediction>(model);
196+
197+
Assert.Equal(0, engine.Predict(new IrisVectorData()
198+
{
199+
SepalInfo = new float[] { 4.5f, 5.6f },
200+
PetalInfo = new float[] { 0.5f, 0.5f },
201+
}).PredictedLabel);
202+
203+
Assert.Equal(1, engine.Predict(new IrisVectorData()
204+
{
205+
SepalInfo = new float[] { 4.9f, 2.4f },
206+
PetalInfo = new float[] { 3.3f, 1.0f },
207+
}).PredictedLabel);
208+
}
209+
122210
[Fact]
123211
public void IrisSdcaMaximumEntropy()
124212
{
@@ -189,6 +277,21 @@ public class IrisData
189277
public float PetalWidth;
190278
}
191279

280+
public class IrisDataWithLoadColumnName
281+
{
282+
[LoadColumnName("Label")]
283+
[ColumnName("Label")]
284+
public int Kind;
285+
286+
public float SepalLength;
287+
288+
public float SepalWidth;
289+
290+
public float PetalLength;
291+
292+
public float PetalWidth;
293+
}
294+
192295
public class IrisVectorData
193296
{
194297
public int Label;
@@ -202,6 +305,19 @@ public class IrisVectorData
202305
public float[] PetalInfo;
203306
}
204307

308+
public class IrisVectorDataWithLoadColumnName
309+
{
310+
public int Label;
311+
312+
[LoadColumnName("SepalLength", "SepalWidth")]
313+
[VectorType(2)]
314+
public float[] SepalInfo;
315+
316+
[LoadColumnName("PetalLength", "PetalWidth")]
317+
[VectorType(2)]
318+
public float[] PetalInfo;
319+
}
320+
205321
public class IrisPrediction
206322
{
207323
public int PredictedLabel;

0 commit comments

Comments
 (0)