diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 903e1c7398..cc03c5583c 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -49,18 +49,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Mi EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Console", "src\Microsoft.ML.Console\Microsoft.ML.Console.csproj", "{362A98CF-FBF7-4EBB-A11B-990BBF845B15}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "build", "build", "{487213C9-E8A9-4F94-85D7-28A05DBBFE3A}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "netstandard2.0", "netstandard2.0", "{9252A8EB-ABFB-440C-AB4D-1D562753CE0F}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Sweeper.Tests", "test\Microsoft.ML.Sweeper.Tests\Microsoft.ML.Sweeper.Tests.csproj", "{3DEB504D-7A07-48CE-91A2-8047461CB3D4}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.LightGbm", "src\Microsoft.ML.LightGbm\Microsoft.ML.LightGbm.csproj", "{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Ensemble", "src\Microsoft.ML.Ensemble\Microsoft.ML.Ensemble.csproj", "{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.CpuMath", "Microsoft.ML.CpuMath", "{BF66A305-DF10-47E4-8D81-42049B149D2B}" -EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools-local", "tools-local", "{7F13E156-3EBA-4021-84A5-CD56BA72F99E}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InternalCodeAnalyzer", "tools-local\Microsoft.ML.InternalCodeAnalyzer\Microsoft.ML.InternalCodeAnalyzer.csproj", "{B4E55B2D-2A92-46E7-B72F-E76D6FD83440}" @@ -109,76 +103,32 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.DnnImageFeatur EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.EntryPoints", "src\Microsoft.ML.EntryPoints\Microsoft.ML.EntryPoints.csproj", "{7504D46F-E4B3-43CB-9B1C-82F3131F1C99}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Mkl.Components", "Microsoft.ML.Mkl.Components", "{63006A14-B924-48C5-83C9-CFE9DA22B01F}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.ImageAnalytics", "Microsoft.ML.ImageAnalytics", "{1229F799-37F0-4282-B9F0-74BFA97CC362}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.LightGbm", "Microsoft.ML.LightGbm", "{DE95FE65-9FF7-4233-93DF-7A8F2805624A}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Mkl.Redist", "Microsoft.ML.Mkl.Redist", "{4CF8095E-B4A3-4326-A550-43098E447288}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.OnnxConverter", "Microsoft.ML.OnnxConverter", "{19AC192B-75FE-45D5-B219-898E401D5904}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.OnnxTransformer", "Microsoft.ML.OnnxTransformer", "{93FF16AA-635E-421D-96C1-008818C143A2}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Recommender", "Microsoft.ML.Recommender", "{320AF46A-4809-486E-8F9E-A00C8AE47751}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.TensorFlow", "Microsoft.ML.TensorFlow", "{11894B4A-78B4-4523-A6DD-4495722E244F}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.TimeSeries", "Microsoft.ML.TimeSeries", "{B836F712-7FB6-4B75-A3EB-FB05F8E0D15E}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.DnnImageFeaturizer.AlexNet", "Microsoft.ML.DnnImageFeaturizer.AlexNet", "{B00098E4-771E-41DF-A3AA-A606AAB334B7}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.DnnImageFeaturizer.ResNet18", "Microsoft.ML.DnnImageFeaturizer.ResNet18", "{BD93C0F3-3CED-4BE8-9389-4234250FBFB1}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.DnnImageFeaturizer.ResNet50", "Microsoft.ML.DnnImageFeaturizer.ResNet50", "{8EDFB7E5-7E7E-411D-99C5-7A4895D0F9CB}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.DnnImageFeaturizer.ResNet101", "Microsoft.ML.DnnImageFeaturizer.ResNet101", "{9E689AD4-F908-493C-B882-B1B33E8F7696}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.EntryPoints", "Microsoft.ML.EntryPoints", "{8D8CC016-0020-40EC-BD8E-73F1CE0F9662}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "common", "common", "{A84717CB-F11A-41C5-A74D-C0F1D47B7431}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.DataView", "src\Microsoft.ML.DataView\Microsoft.ML.DataView.csproj", "{85D0CAFD-2FE8-496A-88C7-585D35B94243}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.DataView", "Microsoft.ML.DataView", "{31D38B21-102B-41C0-9E0A-2FE0BF68D123}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "RemoteExecutorConsoleApp", "test\RemoteExecutorConsoleApp\RemoteExecutorConsoleApp.csproj", "{5E920CAC-5A28-42FB-936E-49C472130953}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Ensemble", "Microsoft.ML.Ensemble", "{AD7058C9-5608-49A8-BE23-58C33A74EE91}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Experimental", "src\Microsoft.ML.Experimental\Microsoft.ML.Experimental.csproj", "{E02DA82D-3FEE-4C60-BD80-9EC3C3448DFC}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.FastTree", "Microsoft.ML.FastTree", "{B1B3F284-FA3D-4D76-A712-FF04495D244B}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.ML", "src\Microsoft.Extensions.ML\Microsoft.Extensions.ML.csproj", "{D6741C37-B5E6-4050-BCBA-9715809EA15B}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Extensions.ML.Tests", "test\Microsoft.Extensions.ML.Tests\Microsoft.Extensions.ML.Tests.csproj", "{21CAD3A1-5E1F-42C1-BB73-46B6E67F4206}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.Extensions.ML", "Microsoft.Extensions.ML", "{AE4F7569-26F3-4160-8A8B-7A57D0DA3350}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StableApi", "tools-local\Microsoft.ML.StableApi\Microsoft.ML.StableApi.csproj", "{F308DC6B-7E59-40D7-A581-834E8CD99CFE}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML.Tests", "test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj", "{C2652287-CD6D-40FB-B042-95FB56D09DB8}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML", "src\Microsoft.ML.AutoML\Microsoft.ML.AutoML.csproj", "{E48285BF-F49A-4EA3-AED0-1BDDBF77EB80}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.AutoML", "Microsoft.ML.AutoML", "{F5D11F71-2D61-4AE9-99D7-0F0B54649B15}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML.Samples", "docs\samples\Microsoft.ML.AutoML.Samples\Microsoft.ML.AutoML.Samples.csproj", "{A6924919-9E37-4023-8B7F-E85C8E3CC9B3}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Samples.GPU", "docs\samples\Microsoft.ML.Samples.GPU\Microsoft.ML.Samples.GPU.csproj", "{3C8F910B-7F23-4D25-B521-6D5AC9570ADD}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Featurizers", "src\Microsoft.ML.Featurizers\Microsoft.ML.Featurizers.csproj", "{E2DD0721-5B0F-4606-8182-4C7EFB834518}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Featurizers", "Microsoft.ML.Featurizers", "{1BA5C784-52E8-4A87-8525-26B2452F2882}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeGenerator", "src\Microsoft.ML.CodeGenerator\Microsoft.ML.CodeGenerator.csproj", "{56CB0850-7341-4D71-9AE4-9EFC472D93DD}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeGenerator.Tests", "test\Microsoft.ML.CodeGenerator.Tests\Microsoft.ML.CodeGenerator.Tests.csproj", "{46CC5637-3DDF-4100-93FC-44BB87B2DB81}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.CodeGenerator", "Microsoft.ML.CodeGenerator", "{3817A875-278C-4140-BF66-3C4A8CA55F0D}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Vision", "src\Microsoft.ML.Vision\Microsoft.ML.Vision.csproj", "{419F93D5-4135-4DA0-A76E-EFC23E04093D}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TestFrameworkCommon", "test\Microsoft.ML.TestFrameworkCommon\Microsoft.ML.TestFrameworkCommon.csproj", "{A22FAD27-77E8-4460-8B92-EC7090B7173A}" @@ -187,10 +137,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.NightlyBuild.T EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.NugetPackageVersionUpdater", "test\Microsoft.ML.NugetPackageVersionUpdater\Microsoft.ML.NugetPackageVersionUpdater.csproj", "{C8DB58DC-6434-4431-A81F-263D86E2A5F3}" EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "build", "build", "{C91F81E3-B900-4968-A6DF-F53B515E97E1}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "netstandard2.0", "netstandard2.0", "{027DBA48-85B6-46F1-9487-0B49B5057FC0}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML", "src\Microsoft.ML\Microsoft.ML.csproj", "{6CF88209-69DB-4B36-9604-3ECD9F163E96}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Mkl.Redist", "src\Microsoft.ML.Mkl.Redist\Microsoft.ML.Mkl.Redist.csproj", "{4584326B-C5B3-4CAE-B98A-34C5F5AA16F3}" @@ -1759,7 +1705,6 @@ Global {7A9DB75F-2CA5-4184-9EF5-1F17EB39483F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {64F40A0D-D4C2-4AA7-8470-E9CC437827E4} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {362A98CF-FBF7-4EBB-A11B-990BBF845B15} = {09EADF06-BE25-4228-AB53-95AE3E15B530} - {9252A8EB-ABFB-440C-AB4D-1D562753CE0F} = {487213C9-E8A9-4F94-85D7-28A05DBBFE3A} {3DEB504D-7A07-48CE-91A2-8047461CB3D4} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27} = {09EADF06-BE25-4228-AB53-95AE3E15B530} @@ -1802,8 +1747,6 @@ Global {A22FAD27-77E8-4460-8B92-EC7090B7173A} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {A1CAC86F-F4BB-4B6D-9D18-E9AE15B3C66E} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {C8DB58DC-6434-4431-A81F-263D86E2A5F3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} - {C91F81E3-B900-4968-A6DF-F53B515E97E1} = {BF66A305-DF10-47E4-8D81-42049B149D2B} - {027DBA48-85B6-46F1-9487-0B49B5057FC0} = {C91F81E3-B900-4968-A6DF-F53B515E97E1} {6CF88209-69DB-4B36-9604-3ECD9F163E96} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {4584326B-C5B3-4CAE-B98A-34C5F5AA16F3} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection diff --git a/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs b/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs index 659829abe4..a151dd4ff3 100644 --- a/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs +++ b/src/Microsoft.ML.OnnxConverter/OnnxExportExtensions.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; +using System.Linq; using Google.Protobuf; using Microsoft.ML.Data; using Microsoft.ML.Model.OnnxConverter; @@ -14,30 +15,43 @@ namespace Microsoft.ML { public static class OnnxExportExtensions { - private static ModelProto ConvertToOnnxProtobufCore(IHostEnvironment env, OnnxContextImpl ctx, ITransformer transform, IDataView inputData) + private static ModelProto ConvertToOnnxProtobufCore(IHostEnvironment env, OnnxContextImpl ctx, ITransformer transform, IDataView inputData, string[] outputColumnNamesToKeep = null) { var outputData = transform.Transform(inputData); LinkedList transforms = null; using (var ch = env.Start("ONNX conversion")) { SaveOnnxCommand.GetPipe(ctx, ch, outputData, out IDataView root, out IDataView sink, out transforms); - return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, null); + // We pass in the output names to keep, but this next call expects a list of ones to drop. Invert the list. + var outputColumnNamesToDrop = new HashSet(); + if (outputColumnNamesToKeep != null) + { + for (int i = 0; i < sink.Schema.Count; ++i) + { + if (!outputColumnNamesToKeep.Contains(sink.Schema[i].Name)) + { + outputColumnNamesToDrop.Add(sink.Schema[i].Name); + } + } + } + return SaveOnnxCommand.ConvertTransformListToOnnxModel(ctx, ch, root, sink, transforms, null, outputColumnNamesToDrop); } } /// /// Convert the specified to ONNX format. Note that ONNX uses Google's Protobuf so the returned value is a Protobuf object. /// - /// The class that attached to. + /// The class that attached to. /// The that will be converted into ONNX format. /// The input of the specified transform. + /// List of output columns we want to keep. /// An ONNX model equivalent to the converted ML.NET model. [BestFriend] - internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData) + internal static ModelProto ConvertToOnnxProtobuf(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, string[] outputColumns=null) { var env = catalog.GetEnvironment(); var ctx = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "machinelearning.dotnet", OnnxVersion.Stable); - return ConvertToOnnxProtobufCore(env, ctx, transform, inputData); + return ConvertToOnnxProtobufCore(env, ctx, transform, inputData, outputColumns); } /// @@ -78,5 +92,17 @@ public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransform /// An ONNX model equivalent to the converted ML.NET model. public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, int opSetVersion, Stream stream) => ConvertToOnnxProtobuf(catalog, transform, inputData, opSetVersion).WriteTo(stream); + + /// + /// Convert the specified to ONNX format and writes to a stream. + /// + /// The class that attached to. + /// The that will be converted into ONNX format. + /// The input of the specified transform. + /// The stream to write the protobuf model to. + /// List of output columns we want to keep. + /// An ONNX model equivalent to the converted ML.NET model. + public static void ConvertToOnnx(this ModelOperationsCatalog catalog, ITransformer transform, IDataView inputData, Stream stream, params string[] outputColumns) => + ConvertToOnnxProtobuf(catalog, transform, inputData, outputColumns).WriteTo(stream); } } diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs index 9932019e7a..c0900b5ca6 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs @@ -553,7 +553,7 @@ public void Dispose() } } - private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamedOnnxValueGetters, string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCache) + private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamedOnnxValueGetters, List activeOutputColNames, OnnxRuntimeOutputCacher outputCache) { if (outputCache.Position != position) { @@ -565,7 +565,7 @@ private void UpdateCacheIfNeeded(long position, INamedOnnxValueGetter[] srcNamed } outputCache.OutputOnnxValues?.Dispose(); - outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues); + outputCache.OutputOnnxValues = _parent.Model.Run(inputNameOnnxValues, activeOutputColNames); Contracts.Assert(outputCache.OutputOnnxValues.Count > 0); foreach (var outputNameOnnxValue in outputCache.OutputOnnxValues) @@ -580,9 +580,10 @@ private Delegate MakeTensorGetter(DataViewRow input, int iinfo, INamedOnnxVal string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher) { Host.AssertValue(input); + var listActiveOutputColumns = activeOutputColNames.ToList(); ValueGetter> valueGetter = (ref VBuffer dst) => { - UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); + UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, listActiveOutputColumns, outputCacher); var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]]; var tensor = namedOnnxValue.AsTensor() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor; if (tensor == null) @@ -598,10 +599,11 @@ private Delegate MakeStringTensorGetter(DataViewRow input, int iinfo, INamedOnnx string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher) { Host.AssertValue(input); + var listActiveOutputColumns = activeOutputColNames.ToList(); ValueGetter>> valueGetter = (ref VBuffer> dst) => { - UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); + UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, listActiveOutputColumns, outputCacher); var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]]; var tensor = namedOnnxValue.AsTensor() as Microsoft.ML.OnnxRuntime.Tensors.DenseTensor; if (tensor == null) @@ -621,10 +623,11 @@ private Delegate MakeObjectGetter(DataViewRow input, int iinfo, INamedOnnxVal string[] activeOutputColNames, OnnxRuntimeOutputCacher outputCacher) { Host.AssertValue(input); + var listActiveOutputColumns = activeOutputColNames.ToList(); ValueGetter valueGetter = (ref T dst) => { - UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, activeOutputColNames, outputCacher); + UpdateCacheIfNeeded(input.Position, srcNamedValueGetters, listActiveOutputColumns, outputCacher); var namedOnnxValue = outputCacher.Outputs[_parent.Outputs[iinfo]]; var trueValue = namedOnnxValue.AsEnumerable().Select(value => value.AsDictionary()); var caster = _parent.Model.ModelInfo.OutputsInfo[_parent.MapDataViewColumnToOnnxOutputTensor(iinfo)].Caster; diff --git a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs index 4adf18fa40..cdba65a4b1 100644 --- a/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs +++ b/src/Microsoft.ML.OnnxTransformer/OnnxUtils.cs @@ -358,10 +358,11 @@ public static OnnxModel CreateFromBytes(byte[] modelBytes, int? gpuDeviceId = nu /// Uses an open session to score a list of NamedOnnxValues. /// /// The NamedOnnxValues to score. + /// The active output columns. /// Resulting output NamedOnnxValues list. - public IDisposableReadOnlyCollection Run(List inputNamedOnnxValues) + public IDisposableReadOnlyCollection Run(List inputNamedOnnxValues, List outputColumns) { - return _session.Run(inputNamedOnnxValues); + return _session.Run(inputNamedOnnxValues, outputColumns); } /// diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 89b838c82e..f09d02a09a 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -91,7 +91,7 @@ public void SimpleEndToEndOnnxConversionTest() // Step 3: Check ONNX model's text format. This test will be not necessary if Step 2 can run on Linux and // Mac to support cross-platform tests. - + CheckEquality(subDir, onnxTextName, digitsOfPrecision: 3); Done(); @@ -139,7 +139,7 @@ private class BreastCancerBinaryClassification [Fact] public void KmeansOnnxConversionTest() { - // Create a new context for ML.NET operations. It can be used for exception tracking and logging, + // Create a new context for ML.NET operations. It can be used for exception tracking and logging, // as a catalog of available operations and as the source of randomness. var mlContext = new MLContext(seed: 1); @@ -384,7 +384,7 @@ public void TextNormalizingOnnxConversionTest() new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.Upper, columns: new[] { ("UpperText", "text") })).Append( new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.None, columns: new[] { ("OriginalText", "text") })); var onnxFileName = $"TextNormalizing.onnx"; - + TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("NormText"), new ColumnComparison("UpperText"), new ColumnComparison("OriginalText") }); Done(); @@ -1154,7 +1154,7 @@ public void IndicateMissingValuesOnnxConversionTest() // IsNaN outputs a binary tensor. Support for this has been added in the latest version // of Onnxruntime, but that hasn't been released yet. - // So we need to convert its type to Int32 until then. + // So we need to convert its type to Int32 until then. // ConvertType part of the pipeline can be removed once we pick up a new release of the Onnx runtime var pipeline = mlContext.Transforms.IndicateMissingValues(new[] { new InputOutputColumnPair("MissingIndicator", "Features"), }) @@ -1544,6 +1544,46 @@ public void CopyColumnsOnnxTest() Done(); } + [Fact] + public void SelectiveExportOnnxTest() + { + var mlContext = new MLContext(seed: 1); + + var trainDataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); + var dataView = mlContext.Data.LoadFromTextFile(trainDataPath, + separatorChar: ';', + hasHeader: true); + + var mlpipeline = mlContext.Transforms.CopyColumns("Target1", "Target"); + var onnxFileName = "copycolumns.onnx"; + + var mlmodel = mlpipeline.Fit(dataView); + + var onnxModelPath = GetOutputPath(onnxFileName); + using (var stream = File.Create(onnxModelPath)) + { + mlContext.Model.ConvertToOnnx(mlmodel, dataView, stream, "Target1"); + } + + var model = new OnnxCSharpToProtoWrapper.ModelProto(); + using (var modelStream = File.OpenRead(onnxModelPath)) + using (var codedStream = Google.Protobuf.CodedInputStream.CreateWithLimits(modelStream, Int32.MaxValue, 10)) + model = OnnxCSharpToProtoWrapper.ModelProto.Parser.ParseFrom(codedStream); + + Assert.True(model.Graph.Output.Count == 1); + Assert.Equal("Target1.output", model.Graph.Output[0].Name); + + // Make sure that even though the column wasn't passed to ONNX, that it can still be used directly from ML.Net + var pipeline = mlContext.Transforms.ApplyOnnxModel(onnxModelPath); + var loadedModel = pipeline.Fit(dataView); + + // Getting the preview will cause an issue if there is an error since ONNX is no longer exporting that column. + var loadedData = loadedModel.Transform(dataView).Preview(1); + Assert.Equal((Single)140.66, loadedData.ColumnView[1].Values[0]); + + Done(); + } + [Fact] public void UseKeyDataViewTypeAsUInt32InOnnxInput() { @@ -1806,7 +1846,7 @@ public void NonDefaultColNamesMultiClassificationOnnxConversionTest() } Done(); } - + [Fact] public void OneHotHashEncodingOnnxConversionWithCustomOpSetVersionTest() {