Skip to content

Commit c96d690

Browse files
committed
fix for binary classification trainers export to onnx
1 parent 8100364 commit c96d690

File tree

3 files changed

+111
-8
lines changed

3 files changed

+111
-8
lines changed

src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs

+23-5
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,32 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx)
197197
for (int iinfo = 0; iinfo < Bindings.InfoCount; ++iinfo)
198198
outColumnNames[iinfo] = Bindings.GetColumnName(Bindings.MapIinfoToCol(iinfo));
199199

200-
//Check if "Probability" column was generated by the base class, only then
201-
//label can be predicted.
200+
/* If the probability column was generated, then the classification threshold is set to 0.5. Otherwise,
201+
the predicted label is based on the sign of the score.
202+
REVIEW: Binarizer should always have at least two output columns?
203+
*/
204+
string opType = "Binarizer";
205+
var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true);
206+
202207
if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2]))
203208
{
204-
string opType = "Binarizer";
205-
var node = ctx.CreateNode(opType, new[] { ctx.GetVariableName(outColumnNames[2]) },
206-
new[] { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType));
209+
var node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[2]), binarizerOutput, ctx.GetNodeName(opType));
207210
node.AddAttribute("threshold", 0.5);
211+
212+
opType = "Cast";
213+
node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), "");
214+
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
215+
node.AddAttribute("to", t);
216+
}
217+
else if (Bindings.InfoCount == 2)
218+
{
219+
var node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[1]), binarizerOutput, ctx.GetNodeName(opType));
220+
node.AddAttribute("threshold", 0.0);
221+
222+
opType = "Cast";
223+
node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), "");
224+
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
225+
node.AddAttribute("to", t);
208226
}
209227
}
210228

src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema
320320
if (!ctx.ContainsColumn(featName))
321321
return false;
322322
Contracts.Assert(ctx.ContainsColumn(featName));
323-
return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName));
323+
return mapper.SaveAsOnnx(ctx, new[] { outputNames[1] }, ctx.GetVariableName(featName));
324324
}
325325

326326
private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) =>

test/Microsoft.ML.Tests/OnnxConversionTest.cs

+87-2
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,14 @@ private class BreastCancerMulticlassExample
131131
[LoadColumn(2, 9), VectorType(8)]
132132
public float[] Features;
133133
}
134+
/// <summary>
/// Row type used to load the breast-cancer text file for the binary
/// classification ONNX conversion test.
/// </summary>
private class BreastCancerBinaryClassification
{
    /// <summary>Boolean label, read from column 0 of the input file.</summary>
    [LoadColumn(0)]
    public bool Label;

    /// <summary>Feature vector of 8 floats, read from columns 2 through 9 of the input file.</summary>
    [LoadColumn(2, 9)]
    [VectorType(8)]
    public float[] Features;
}
134142

135143
[LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline. Tracked by https://github.com/dotnet/machinelearning/issues/2087")]
136144
public void KmeansOnnxConversionTest()
@@ -187,6 +195,55 @@ public void KmeansOnnxConversionTest()
187195
Done();
188196
}
189197

198+
/// <summary>
/// Trains each binary classification trainer on the breast-cancer data set, exports the
/// fitted pipeline to ONNX and, when the ONNX runtime is available, checks that the ONNX
/// model reproduces the columns produced by the ML.NET model.
/// </summary>
[Fact]
public void binaryClassificationTrainersOnnxConversionTest()
{
    var mlContext = new MLContext(seed: 1);
    string dataPath = GetDataPath("breast-cancer.txt");
    // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
    var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);

    // Each trainer is listed exactly once. The ONNX file name below is derived from
    // estimator.ToString(), so a repeated trainer would only redo the same work and
    // overwrite the file written by its first occurrence.
    IEstimator<ITransformer>[] estimators = {
        mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
        mlContext.BinaryClassification.Trainers.SgdCalibrated(),
        mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
        mlContext.BinaryClassification.Trainers.FastForest(),
        mlContext.BinaryClassification.Trainers.LinearSvm(),
        mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
        mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
        mlContext.BinaryClassification.Trainers.FastTree(),
        mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
        mlContext.BinaryClassification.Trainers.LightGbm(),
        mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
    };
    var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
        Append(mlContext.Transforms.NormalizeMinMax("Features"));
    foreach (var estimator in estimators)
    {
        var pipeline = initialPipeline.Append(estimator);
        var model = pipeline.Fit(dataView);
        var transformedData = model.Transform(dataView);
        var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
        // Compare model scores produced by ML.NET and ONNX's runtime.
        if (IsOnnxRuntimeSupported())
        {
            var onnxFileName = $"{estimator}.onnx";
            var onnxModelPath = GetOutputPath(onnxFileName);
            SaveOnnxModel(onnxModel, onnxModelPath, null);
            // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
            string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
            string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
            var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
            var onnxTransformer = onnxEstimator.Fit(dataView);
            var onnxResult = onnxTransformer.Transform(dataView);
            // NOTE(review): schema index 5 / ONNX output 3 is presumably the float score column and
            // schema index 4 / ONNX output 2 the boolean predicted label - confirm against the scorer's output order.
            CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3);
            CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult);
        }
    }
    Done();
}
190247
private class DataPoint
191248
{
192249
[VectorType(3)]
@@ -853,7 +910,8 @@ private void CreateDummyExamplesToMakeComplierHappy()
853910
var dummyExample = new BreastCancerFeatureVector() { Features = null };
854911
var dummyExample1 = new BreastCancerCatFeatureExample() { Label = false, F1 = 0, F2 = "Amy" };
855912
var dummyExample2 = new BreastCancerMulticlassExample() { Label = "Amy", Features = null };
856-
var dummyExample3 = new SmallSentimentExample() { Tokens = null };
913+
var dummyExample3 = new BreastCancerBinaryClassification() { Label = false, Features = null };
914+
var dummyExample4 = new SmallSentimentExample() { Tokens = null };
857915
}
858916

859917
private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
@@ -984,7 +1042,34 @@ private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightC
9841042

9851043
// Scalar such as R4 (float) is converted to a [1, 1]-tensor in ONNX format for consistency when making batch predictions.
9861044
Assert.Equal(1, actual.Length);
987-
Assert.Equal(expected, actual.GetItemOrDefault(0), precision);
1045+
CompareNumbersWithTolerance(expected, actual.GetItemOrDefault(0), null, precision);
1046+
}
1047+
}
1048+
}
1049+
/// <summary>
/// Asserts that a scalar column of <paramref name="left"/> matches, row by row, a column of
/// <paramref name="right"/> that is read as a length-1 <see cref="VBuffer{T}"/>.
/// Fails if the two views yield a different number of rows.
/// </summary>
/// <typeparam name="T">Item type of both columns.</typeparam>
/// <param name="leftColumnName">Name of the expected (scalar) column in <paramref name="left"/>.</param>
/// <param name="rightColumnName">Name of the actual (vector) column in <paramref name="right"/>.</param>
/// <param name="left">Data view holding the expected values.</param>
/// <param name="right">Data view holding the actual values.</param>
private void CompareSelectedScalarColumns<T>(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
{
    var leftColumn = left.Schema[leftColumnName];
    var rightColumn = right.Schema[rightColumnName];

    using (var expectedCursor = left.GetRowCursor(leftColumn))
    using (var actualCursor = right.GetRowCursor(rightColumn))
    {
        T expected = default;
        VBuffer<T> actual = default;
        var expectedGetter = expectedCursor.GetGetter<T>(leftColumn);
        var actualGetter = actualCursor.GetGetter<VBuffer<T>>(rightColumn);
        while (true)
        {
            // Advance both cursors unconditionally: combining the two MoveNext calls with a
            // short-circuiting && would end the loop quietly when one side runs out of rows first.
            bool hasExpected = expectedCursor.MoveNext();
            bool hasActual = actualCursor.MoveNext();
            Assert.Equal(hasExpected, hasActual);
            if (!hasExpected)
                break;

            expectedGetter(ref expected);
            actualGetter(ref actual);

            // Verify the buffer holds exactly one item before reading it, so a wrong
            // shape is reported as such rather than as a value mismatch.
            Assert.Equal(1, actual.Length);
            var actualVal = actual.GetItemOrDefault(0);

            // Text values are compared through their string representation.
            if (typeof(T) == typeof(ReadOnlyMemory<char>))
                Assert.Equal(expected.ToString(), actualVal.ToString());
            else
                Assert.Equal(expected, actualVal);
        }
    }
}

0 commit comments

Comments
 (0)