Skip to content

Commit 8592d96

Browse files
author
Shahab Moradi
authored
OnnxTransform: Fix 3 bugbash bugs (#1080)
* (Fix issue #1050) Renamed namespace to Microsoft.ML.Transforms to conform with TF transform. * (Fix issue #1051) Added public create methods * (Fixed issue 1053) IDV names don't have to be the same as OnnxModel node names. * Added license header
1 parent eb87467 commit 8592d96

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

src/Microsoft.ML.OnnxTransform/OnnxTransform.cs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
using Microsoft.ML.Runtime.Internal.Utilities;
1414
using Microsoft.ML.Runtime.Model;
1515
using Microsoft.ML.Scoring;
16-
using Microsoft.ML.OnnxScoring;
16+
using Microsoft.ML.Transforms;
1717
using Microsoft.ML.StaticPipe;
1818
using Microsoft.ML.StaticPipe.Runtime;
1919
using Microsoft.ML.Core.Data;
@@ -32,7 +32,7 @@
3232

3333
[assembly: EntryPointModule(typeof(OnnxTransform))]
3434

35-
namespace Microsoft.ML.OnnxScoring
35+
namespace Microsoft.ML.Transforms
3636
{
3737
public sealed class OnnxTransform : ITransformer, ICanSaveModel
3838
{
@@ -73,8 +73,14 @@ private static VersionInfo GetVersionInfo()
7373
loaderAssemblyName: typeof(OnnxTransform).Assembly.FullName);
7474
}
7575

76+
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string inputColumn, string outputColumn)
77+
{
78+
var args = new Arguments { ModelFile = modelFile, InputColumn = inputColumn, OutputColumn = outputColumn };
79+
return Create(env, args, input);
80+
}
81+
7682
// Factory method for SignatureDataTransform
77-
private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
83+
public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
7884
{
7985
return new OnnxTransform(env, args).MakeDataTransform(input);
8086
}
@@ -122,10 +128,16 @@ private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes =
122128
else
123129
Model = OnnxModel.CreateFromBytes(modelBytes);
124130

131+
var modelInfo = Model.ModelInfo;
132+
if (modelInfo.InputsInfo.Length != 1)
133+
throw env.Except($"OnnxTransform supports Onnx models with one input. The provided model has ${modelInfo.InputsInfo.Length} input(s).");
134+
if (modelInfo.OutputsInfo.Length != 1)
135+
throw env.Except($"OnnxTransform supports Onnx models with one output. The provided model has ${modelInfo.OutputsInfo.Length} output(s).");
136+
125137
Input = args.InputColumn;
126138
Output = args.OutputColumn;
127139

128-
var outputNodeInfo = Model.GetOutputsInfo().Where(x => x.Name == args.OutputColumn).First();
140+
var outputNodeInfo = Model.ModelInfo.OutputsInfo[0];
129141
var type = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
130142
var shape = outputNodeInfo.Shape;
131143
var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select( x => (int) x ).ToArray() : new[] { 0 };
@@ -305,7 +317,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
305317
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
306318
if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
307319
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
308-
var inputNodeInfo = Transformer.Model.GetInputsInfo().Where(x => x.Name == input).First();
320+
var inputNodeInfo = Transformer.Model.ModelInfo.InputsInfo[0];
309321
var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
310322
if (col.ItemType != expectedType)
311323
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());

src/Microsoft.ML.OnnxTransform/OnnxUtils.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
15
using Microsoft.ML.Runtime;
26
using Microsoft.ML.Runtime.Data;
37
using Microsoft.ML.Runtime.Internal.Utilities;
@@ -10,7 +14,7 @@
1014

1115
using OnnxShape = System.Collections.Generic.List<long>;
1216

13-
namespace Microsoft.ML.OnnxScoring
17+
namespace Microsoft.ML.Transforms
1418
{
1519
/// <summary>
1620
/// IdvToTensorAdapter adapts an Idv (row-iterator interface) to a tensor-iterator interface.
@@ -206,14 +210,14 @@ public byte[] ToByteArray()
206210
return File.ReadAllBytes(_modelFile);
207211
}
208212

209-
public OnnxNodeInfo[] GetInputsInfo()
213+
private OnnxNodeInfo[] GetInputsInfo()
210214
{
211215
return DictToNodesInfo(
212216
_modelManager.GetInputTypeDict(_modelName, _ignoredVersion),
213217
_modelManager.GetInputShapesDict(_modelName, _ignoredVersion));
214218
}
215219

216-
public OnnxNodeInfo[] GetOutputsInfo()
220+
private OnnxNodeInfo[] GetOutputsInfo()
217221
{
218222
return DictToNodesInfo(
219223
_modelManager.GetOutputTypeDict(_modelName, _ignoredVersion),

test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.ML.Core.Data;
6-
using Microsoft.ML.OnnxScoring;
6+
using Microsoft.ML.Transforms;
77
using Microsoft.ML.Runtime.Api;
88
using Microsoft.ML.Runtime.Data;
99
using Microsoft.ML.Runtime.ImageAnalytics;

0 commit comments

Comments
 (0)