|
13 | 13 | using Microsoft.ML.Runtime.Internal.Utilities;
|
14 | 14 | using Microsoft.ML.Runtime.Model;
|
15 | 15 | using Microsoft.ML.Scoring;
|
16 |
| -using Microsoft.ML.OnnxScoring; |
| 16 | +using Microsoft.ML.Transforms; |
17 | 17 | using Microsoft.ML.StaticPipe;
|
18 | 18 | using Microsoft.ML.StaticPipe.Runtime;
|
19 | 19 | using Microsoft.ML.Core.Data;
|
|
32 | 32 |
|
33 | 33 | [assembly: EntryPointModule(typeof(OnnxTransform))]
|
34 | 34 |
|
35 |
| -namespace Microsoft.ML.OnnxScoring |
| 35 | +namespace Microsoft.ML.Transforms |
36 | 36 | {
|
37 | 37 | public sealed class OnnxTransform : ITransformer, ICanSaveModel
|
38 | 38 | {
|
@@ -73,8 +73,14 @@ private static VersionInfo GetVersionInfo()
|
73 | 73 | loaderAssemblyName: typeof(OnnxTransform).Assembly.FullName);
|
74 | 74 | }
|
75 | 75 |
|
| 76 | + public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string inputColumn, string outputColumn) |
| 77 | + { |
| 78 | + var args = new Arguments { ModelFile = modelFile, InputColumn = inputColumn, OutputColumn = outputColumn }; |
| 79 | + return Create(env, args, input); |
| 80 | + } |
| 81 | + |
76 | 82 | // Factory method for SignatureDataTransform
|
77 |
| - private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) |
| 83 | + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) |
78 | 84 | {
|
79 | 85 | return new OnnxTransform(env, args).MakeDataTransform(input);
|
80 | 86 | }
|
@@ -122,10 +128,16 @@ private OnnxTransform(IHostEnvironment env, Arguments args, byte[] modelBytes =
|
122 | 128 | else
|
123 | 129 | Model = OnnxModel.CreateFromBytes(modelBytes);
|
124 | 130 |
|
| 131 | + var modelInfo = Model.ModelInfo; |
| 132 | + if (modelInfo.InputsInfo.Length != 1) |
| 133 | + throw env.Except($"OnnxTransform supports Onnx models with one input. The provided model has ${modelInfo.InputsInfo.Length} input(s)."); |
| 134 | + if (modelInfo.OutputsInfo.Length != 1) |
| 135 | + throw env.Except($"OnnxTransform supports Onnx models with one output. The provided model has ${modelInfo.OutputsInfo.Length} output(s)."); |
| 136 | + |
125 | 137 | Input = args.InputColumn;
|
126 | 138 | Output = args.OutputColumn;
|
127 | 139 |
|
128 |
| - var outputNodeInfo = Model.GetOutputsInfo().Where(x => x.Name == args.OutputColumn).First(); |
| 140 | + var outputNodeInfo = Model.ModelInfo.OutputsInfo[0]; |
129 | 141 | var type = OnnxUtils.OnnxToMlNetType(outputNodeInfo.Type);
|
130 | 142 | var shape = outputNodeInfo.Shape;
|
131 | 143 | var dims = shape.Count > 0 ? shape.Skip(shape[0] < 0 ? 1 : 0).Select( x => (int) x ).ToArray() : new[] { 0 };
|
@@ -305,7 +317,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
|
305 | 317 | throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
|
306 | 318 | if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
|
307 | 319 | throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
|
308 |
| - var inputNodeInfo = Transformer.Model.GetInputsInfo().Where(x => x.Name == input).First(); |
| 320 | + var inputNodeInfo = Transformer.Model.ModelInfo.InputsInfo[0]; |
309 | 321 | var expectedType = OnnxUtils.OnnxToMlNetType(inputNodeInfo.Type);
|
310 | 322 | if (col.ItemType != expectedType)
|
311 | 323 | throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
|
|
0 commit comments