Skip to content

Commit 243ff02

Browse files
authored
Fixing ONNX test (#3253)
Fixes #2981 * When adding an ONNX transform to an ML.NET pipeline, an exception would occur if the input type was not a variable vector or vector type. This is not needed as we do support converting basic types to equivalent ONNX tensor type. Therefore the check was modified to throw if the type is a variable vector.
1 parent 4728a74 commit 243ff02

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
329329
var col = inputSchema.GetColumnOrNull(_parent.Inputs[i]);
330330
if (!col.HasValue)
331331
throw Host.ExceptSchemaMismatch( nameof(inputSchema),"input", _parent.Inputs[i]);
332+
332333
_inputColIndices[i] = col.Value.Index;
333334

334335
var type = inputSchema[_inputColIndices[i]].Type;
@@ -564,7 +565,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
564565
var input = Transformer.Inputs[i];
565566
if (!inputSchema.TryFindColumn(input, out var col))
566567
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
567-
if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
568+
if (col.Kind == SchemaShape.Column.VectorKind.VariableVector)
568569
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, "vector", col.GetTypeString());
569570

570571
var inputsInfo = Transformer.Model.ModelInfo.InputsInfo;

test/Microsoft.ML.Functional.Tests/ONNX.cs

+12-8
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System.IO;
6+
using Microsoft.ML.Data;
67
using Microsoft.ML.Functional.Tests.Datasets;
78
using Microsoft.ML.RunTests;
89
using Microsoft.ML.TestFramework;
@@ -48,18 +49,21 @@ public void SaveOnnxModelLoadAndScoreFastTree()
4849
mlContext.Model.ConvertToOnnx(model, data, file);
4950

5051
// Load the model as a transform.
51-
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(modelPath);
52+
// Note that when saving an ML.NET model as an ONNX model, the column types and column names will
53+
// change. The name changes as ONNX doesn't not allow the same name for an input and output within the ONNX model.
54+
// Therefore names maintained but have a number appended to the end of the name. In this case, Score0 is the output
55+
// of the ONNX model. We are renaming Score0 to Score using Copy Columns.
56+
// ONNX also uses tensors and will return an output of a tensor with the dimension of [1,1] for a single float.
57+
// Therefore the VectorScoreColumn class (which contains a float [] field called Score) is used for the return
58+
// type on the Prediction engine.
59+
// See #2980 and #2981 for more information.
60+
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(modelPath)
61+
.Append(mlContext.Transforms.CopyColumns("Score", "Score0"));
5262
var onnxModel = onnxEstimator.Fit(data);
5363

54-
// TODO #2980: ONNX outputs don't match the outputs of the model, so we must hand-correct this for now.
55-
// TODO #2981: ONNX models cannot be fit as part of a pipeline, so we must use a workaround like this.
56-
var onnxWorkaroundPipeline = onnxModel.Append(
57-
mlContext.Transforms.CopyColumns("Score", "Score0").Fit(onnxModel.Transform(data)));
58-
5964
// Create prediction engine and test predictions.
6065
var originalPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, ScoreColumn>(model);
61-
// TODO #2982: ONNX produces vector types and not the original output type.
62-
var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, VectorScoreColumn>(onnxWorkaroundPipeline);
66+
var onnxPredictionEngine = mlContext.Model.CreatePredictionEngine<HousingRegression, VectorScoreColumn>(onnxModel);
6367

6468
// Take a handful of examples out of the dataset and compute predictions.
6569
var dataEnumerator = mlContext.Data.CreateEnumerable<HousingRegression>(mlContext.Data.TakeRows(data, 5), false);

0 commit comments

Comments
 (0)