Skip to content

Commit 1bb1249

Browse files
authored
Export to ONNX and cross-platform command-line tool to script ML.NET training and inference (#248)
* Export to ONNX and Maml cross-platform executable.
1 parent 5730685 commit 1bb1249

File tree

20 files changed

+5446
-3347
lines changed

20 files changed

+5446
-3347
lines changed

Microsoft.ML.sln

+17-7
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ MinimumVisualStudioVersion = 10.0.40219.1
55
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Core", "src\Microsoft.ML.Core\Microsoft.ML.Core.csproj", "{A6CA6CC6-5D7C-4D7F-A0F5-35E14B383B0A}"
66
EndProject
77
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{09EADF06-BE25-4228-AB53-95AE3E15B530}"
8+
ProjectSection(SolutionItems) = preProject
9+
src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj = src\Microsoft.ML.Commands\Microsoft.ML.Commands.csproj
10+
EndProjectSection
811
EndProject
912
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{AED9C836-31E3-4F3F-8ABC-929555D3F3C4}"
1013
EndProject
@@ -30,8 +33,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.KMeansClusteri
3033
EndProject
3134
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.PCA", "src\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj", "{58E06735-1129-4DD5-86E0-6BBFF049AAD9}"
3235
EndProject
33-
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj", "{D956E291-F6E5-4474-9023-91793F45ABEB}"
34-
EndProject
3536
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Api", "src\Microsoft.ML.Api\Microsoft.ML.Api.csproj", "{2F636A2C-062C-49F4-85F3-60DCADAB6A43}"
3637
EndProject
3738
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Tests", "test\Microsoft.ML.Tests\Microsoft.ML.Tests.csproj", "{64BC22D3-1E76-41EF-94D8-C79E471FF2DD}"
@@ -104,6 +105,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.Parquet", "Mic
104105
EndProject
105106
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Benchmarks", "test\Microsoft.ML.Benchmarks\Microsoft.ML.Benchmarks.csproj", "{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}"
106107
EndProject
108+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Maml", "src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj", "{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}"
109+
EndProject
110+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Console", "src\Microsoft.ML.Console\Microsoft.ML.Console.csproj", "{362A98CF-FBF7-4EBB-A11B-990BBF845B15}"
111+
EndProject
107112
Global
108113
GlobalSection(SolutionConfigurationPlatforms) = preSolution
109114
Debug|Any CPU = Debug|Any CPU
@@ -158,10 +163,6 @@ Global
158163
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Debug|Any CPU.Build.0 = Debug|Any CPU
159164
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.ActiveCfg = Release|Any CPU
160165
{58E06735-1129-4DD5-86E0-6BBFF049AAD9}.Release|Any CPU.Build.0 = Release|Any CPU
161-
{D956E291-F6E5-4474-9023-91793F45ABEB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
162-
{D956E291-F6E5-4474-9023-91793F45ABEB}.Debug|Any CPU.Build.0 = Debug|Any CPU
163-
{D956E291-F6E5-4474-9023-91793F45ABEB}.Release|Any CPU.ActiveCfg = Release|Any CPU
164-
{D956E291-F6E5-4474-9023-91793F45ABEB}.Release|Any CPU.Build.0 = Release|Any CPU
165166
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
166167
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Debug|Any CPU.Build.0 = Debug|Any CPU
167168
{2F636A2C-062C-49F4-85F3-60DCADAB6A43}.Release|Any CPU.ActiveCfg = Release|Any CPU
@@ -202,6 +203,14 @@ Global
202203
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Debug|Any CPU.Build.0 = Debug|Any CPU
203204
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.ActiveCfg = Release|Any CPU
204205
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F}.Release|Any CPU.Build.0 = Release|Any CPU
206+
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
207+
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Debug|Any CPU.Build.0 = Debug|Any CPU
208+
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.ActiveCfg = Release|Any CPU
209+
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4}.Release|Any CPU.Build.0 = Release|Any CPU
210+
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
211+
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Debug|Any CPU.Build.0 = Debug|Any CPU
212+
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.ActiveCfg = Release|Any CPU
213+
{362A98CF-FBF7-4EBB-A11B-990BBF845B15}.Release|Any CPU.Build.0 = Release|Any CPU
205214
EndGlobalSection
206215
GlobalSection(SolutionProperties) = preSolution
207216
HideSolutionNode = FALSE
@@ -219,7 +228,6 @@ Global
219228
{7288C084-11C0-43BE-AC7F-45DCFEAEEBF6} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
220229
{F1CAE3AB-4F86-4BC0-BBA8-C4A58E7E8A4A} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
221230
{58E06735-1129-4DD5-86E0-6BBFF049AAD9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
222-
{D956E291-F6E5-4474-9023-91793F45ABEB} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
223231
{2F636A2C-062C-49F4-85F3-60DCADAB6A43} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
224232
{64BC22D3-1E76-41EF-94D8-C79E471FF2DD} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
225233
{FDA2FD2C-A708-43AC-A941-4D941B0853BF} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
@@ -236,6 +244,8 @@ Global
236244
{DEC8F776-49F7-4D87-836C-FE4DC057D08C} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
237245
{6C95FC87-F5F2-4EEF-BB97-567F2F5DD141} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
238246
{7A9DB75F-2CA5-4184-9EF5-1F17EB39483F} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
247+
{64F40A0D-D4C2-4AA7-8470-E9CC437827E4} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
248+
{362A98CF-FBF7-4EBB-A11B-990BBF845B15} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
239249
EndGlobalSection
240250
GlobalSection(ExtensibilityGlobals) = postSolution
241251
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}

src/Microsoft.ML.Console/Console.cs

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
namespace Microsoft.ML.Runtime.Tools.Console
6+
{
7+
public static class Console
8+
{
9+
public static int Main(string[] args) => Maml.Main(args);
10+
}
11+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
5+
<DefineConstants>CORECLR</DefineConstants>
6+
<IncludeInPackage>Microsoft.ML</IncludeInPackage>
7+
<TargetFramework>netcoreapp2.0</TargetFramework>
8+
<OutputType>Exe</OutputType>
9+
<AssemblyName>MML</AssemblyName>
10+
<StartupObject>Microsoft.ML.Runtime.Tools.Console.Console</StartupObject>
11+
</PropertyGroup>
12+
13+
<ItemGroup>
14+
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
15+
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
16+
<ProjectReference Include="..\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
17+
<ProjectReference Include="..\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
18+
</ItemGroup>
19+
20+
</Project>

src/Microsoft.ML.Data/Commands/DataCommand.cs

+9-9
Original file line numberDiff line numberDiff line change
@@ -20,38 +20,38 @@ public static class DataCommand
2020
{
2121
public abstract class ArgumentsBase
2222
{
23-
[Argument(ArgumentType.Multiple, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>")]
23+
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The data loader", ShortName = "loader", SortOrder = 1, NullName = "<Auto>")]
2424
public SubComponent<IDataLoader, SignatureDataLoader> Loader;
2525

2626
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "The data file", ShortName = "data", SortOrder = 0)]
2727
public string DataFile;
2828

29-
[Argument(ArgumentType.AtMostOnce, HelpText = "Model file to save", ShortName = "out")]
29+
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Model file to save", ShortName = "out")]
3030
public string OutputModelFile;
3131

32-
[Argument(ArgumentType.AtMostOnce, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)]
32+
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, IsInputFileName = true, HelpText = "Model file to load", ShortName = "in", SortOrder = 90)]
3333
public string InputModelFile;
3434

35-
[Argument(ArgumentType.Multiple, HelpText = "Load transforms from model file?", ShortName = "loadTrans", SortOrder = 91)]
35+
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Load transforms from model file?", ShortName = "loadTrans", SortOrder = 91)]
3636
public bool? LoadTransforms;
3737

38-
[Argument(ArgumentType.AtMostOnce, HelpText = "Random seed", ShortName = "seed", SortOrder = 101)]
38+
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Random seed", ShortName = "seed", SortOrder = 101)]
3939
public int? RandomSeed;
4040

41-
[Argument(ArgumentType.AtMostOnce, HelpText = "Verbose?", ShortName = "v", Hide = true)]
41+
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Verbose?", ShortName = "v", Hide = true)]
4242
public bool? Verbose;
4343

44-
[Argument(ArgumentType.AtMostOnce, HelpText = "The web server to publish the RESTful API", Hide = true)]
44+
[Argument(ArgumentType.AtMostOnce, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "The web server to publish the RESTful API", Hide = true)]
4545
public ServerChannel.IServerFactory Server;
4646

4747
// This is actually an advisory value. The implementations themselves are responsible for
4848
// determining what they consider appropriate, and the actual heuristics is a bit more
4949
// complex than just this.
50-
[Argument(ArgumentType.LastOccurenceWins,
50+
[Argument(ArgumentType.LastOccurenceWins, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly,
5151
HelpText = "Desired degree of parallelism in the data pipeline", ShortName = "n")]
5252
public int? Parallel;
5353

54-
[Argument(ArgumentType.Multiple, HelpText = "Transform", ShortName = "xf")]
54+
[Argument(ArgumentType.Multiple, Visibility = ArgumentAttribute.VisibilityType.CmdLineOnly, HelpText = "Transform", ShortName = "xf")]
5555
public KeyValuePair<string, SubComponent<IDataTransform, SignatureDataTransform>>[] Transform;
5656
}
5757

src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs

+9-2
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@ public sealed class OnnxContext
2323
private readonly HashSet<string> _variableMap;
2424
private readonly HashSet<string> _nodeNames;
2525
private readonly string _name;
26+
private readonly string _producerName;
2627
private readonly IHost _host;
2728
private readonly string _domain;
29+
private readonly string _producerVersion;
30+
private readonly long _modelVersion;
2831

29-
public OnnxContext(IHostEnvironment env, string name, string domain)
32+
public OnnxContext(IHostEnvironment env, string name, string producerName,
33+
string producerVersion, long modelVersion, string domain)
3034
{
3135
Contracts.CheckValue(env, nameof(env));
3236
Contracts.CheckValue(name, nameof(name));
@@ -41,6 +45,9 @@ public OnnxContext(IHostEnvironment env, string name, string domain)
4145
_variableMap = new HashSet<string>();
4246
_nodeNames = new HashSet<string>();
4347
_name = name;
48+
_producerName = producerName;
49+
_producerVersion = producerVersion;
50+
_modelVersion = modelVersion;
4451
_domain = domain;
4552
}
4653

@@ -234,6 +241,6 @@ public void AddInputVariable(ColumnType type, string colName)
234241
/// Makes the ONNX model based on the context.
235242
/// </summary>
236243
public ModelProto MakeModel()
237-
=> OnnxUtils.MakeModel(_nodes, _name, _name, _domain, _inputs, _outputs, _intermediateValues);
244+
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
238245
}
239246
}

src/Microsoft.ML.Data/Model/Onnx/OnnxUtils.cs

+10-3
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ private static AttributeProto MakeAttribute(string key, IEnumerable<GraphProto>
153153

154154
private static AttributeProto MakeAttribute(string key, bool value) => MakeAttribute(key, value ? 1 : 0);
155155

156-
public static NodeProto MakeNode(string opType, List<string> inputs, List<string> outputs, string name)
156+
public static NodeProto MakeNode(string opType, List<string> inputs, List<string> outputs, string name, string domain = null)
157157
{
158158
Contracts.CheckNonEmpty(opType, nameof(opType));
159159
Contracts.CheckValue(inputs, nameof(inputs));
@@ -165,7 +165,7 @@ public static NodeProto MakeNode(string opType, List<string> inputs, List<string
165165
node.Input.Add(inputs);
166166
node.Output.Add(outputs);
167167
node.Name = name;
168-
node.Domain = "ai.onnx.ml";
168+
node.Domain = domain ?? "ai.onnx.ml";
169169
return node;
170170
}
171171

@@ -251,7 +251,8 @@ public NodeProtoWrapper(NodeProto node)
251251
}
252252
}
253253

254-
public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name, string domain, List<ModelArgs> inputs,
254+
public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name,
255+
string domain, string producerVersion, long modelVersion, List<ModelArgs> inputs,
255256
List<ModelArgs> outputs, List<ModelArgs> intermediateValues)
256257
{
257258
Contracts.CheckValue(nodes, nameof(nodes));
@@ -261,10 +262,16 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
261262
Contracts.CheckNonEmpty(producerName, nameof(producerName));
262263
Contracts.CheckNonEmpty(name, nameof(name));
263264
Contracts.CheckNonEmpty(domain, nameof(domain));
265+
Contracts.CheckNonEmpty(producerVersion, nameof(producerVersion));
264266

265267
var model = new ModelProto();
266268
model.Domain = domain;
267269
model.ProducerName = producerName;
270+
model.ProducerVersion = producerVersion;
271+
model.IrVersion = (long)UniversalModelFormat.Onnx.Version.IrVersion;
272+
model.ModelVersion = modelVersion;
273+
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx.ml", Version = 1 });
274+
model.OpsetImport.Add(new OperatorSetIdProto() { Domain = "ai.onnx", Version = 6 });
268275
model.Graph = new GraphProto();
269276
var graph = model.Graph;
270277
graph.Node.Add(nodes);

0 commit comments

Comments
 (0)