Skip to content

Commit 8e7a8ae

Browse files
codemzs and eerhardt
authored and committed
Prevent learning pipeline from adding null transform model to the pipeline (dotnet#154)
* Prevent learning pipeline from adding null transform model to the pipeline. * Add test.
1 parent b865ec0 commit 8e7a8ae

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

src/Microsoft.ML/LearningPipeline.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ public PredictionModel<TInput, TOutput> Train<TInput, TOutput>()
182182

183183
if (transformModels.Count > 0)
184184
{
185-
transformModels.Insert(0,lastTransformModel);
185+
if (lastTransformModel != null)
186+
transformModels.Insert(0, lastTransformModel);
187+
186188
var modelInput = new Transforms.ModelCombiner
187189
{
188190
Models = new ArrayVar<ITransformModel>(transformModels.ToArray())

test/Microsoft.ML.Tests/LearningPipelineTests.cs

+37
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.ML;
6+
using Microsoft.ML.Runtime.Api;
67
using Microsoft.ML.TestFramework;
8+
using Microsoft.ML.Transforms;
79
using System.Linq;
810
using Xunit;
911
using Xunit.Abstractions;
@@ -42,5 +44,40 @@ public void CanAddAndRemoveFromPipeline()
4244
pipeline.Add(new Trainers.StochasticDualCoordinateAscentRegressor());
4345
Assert.Equal(3, pipeline.Count);
4446
}
47+
48+
/// <summary>
/// Input row type used by the TransformOnlyPipeline test: a single string field, F1,
/// bound to the data file through the Column attribute with ordinal "1".
/// NOTE(review): presumably the ordinal is a zero-based column index into the text file - confirm.
/// </summary>
private class InputData
49+
{
50+
[Column(ordinal: "1")]
51+
public string F1;
52+
}
53+
54+
/// <summary>
/// Output row type used by the TransformOnlyPipeline test: exposes the pipeline's "F1"
/// column as a float vector through the ColumnName attribute.
/// Warning 649 (field is never assigned) is disabled around the field because no code
/// in this file assigns it; presumably the framework populates it - confirm.
/// </summary>
private class TransformedData
55+
{
56+
#pragma warning disable 649
57+
[ColumnName("F1")]
58+
public float[] TransformedF1;
59+
#pragma warning restore 649
60+
}
61+
62+
/// <summary>
/// Verifies that a pipeline made of only a loader and a transform (no trainer) can be
/// trained and then used for prediction, and that the categorical hash transform
/// produces the expected vector for the input value "5".
/// </summary>
[Fact]
public void TransformOnlyPipeline()
{
    // Build the relative path with Path.Combine so the directory separator is correct
    // on every OS; hard-coded backslashes only resolve on Windows.
    string dataPath = System.IO.Path.Combine("..", "..", "Data", "breast-cancer.txt");

    // HashBits = 10 gives a 2^10 = 1024 slot output vector.
    const int expectedLength = 1024;

    // The only slot this test expects to be set for the input "5" with the seed below.
    const int expectedHotSlot = 265;

    var pipeline = new LearningPipeline();
    pipeline.Add(new TextLoader<InputData>(dataPath, useHeader: false));
    pipeline.Add(new CategoricalHashOneHotVectorizer("F1") { HashBits = 10, Seed = 314489979, OutputKind = CategoricalTransformOutputKind.Bag });

    var model = pipeline.Train<InputData, TransformedData>();
    var prediction = model.Predict(new InputData() { F1 = "5" });

    Assert.NotNull(prediction);
    Assert.NotNull(prediction.TransformedF1);
    Assert.Equal(expectedLength, prediction.TransformedF1.Length);

    // Every slot must be 0 except the single hashed slot, which must be 1.
    for (int index = 0; index < expectedLength; index++)
    {
        float expected = index == expectedHotSlot ? 1 : 0;
        Assert.Equal(expected, prediction.TransformedF1[index]);
    }
}
4582
}
4683
}

0 commit comments

Comments
 (0)