Skip to content

Commit 52cc874

Browse files
authored
Isolate ONNX implementations in separate DLL and NuGet (#462)
Abstraction of ONNX exporting to interfaces, and isolation of actual implementation to separate DLL. Creation of a new NuGet to isolate Protobuf dependency.
1 parent 4d574d6 commit 52cc874

35 files changed

+528
-338
lines changed

Microsoft.ML.sln

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InferenceTesti
1818
EndProject
1919
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Data", "src\Microsoft.ML.Data\Microsoft.ML.Data.csproj", "{AD92D96B-0E96-4F22-8DCE-892E13B1F282}"
2020
EndProject
21-
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.UniversalModelFormat", "src\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
21+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Onnx", "src\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj", "{65D0603E-B96C-4DFC-BDD1-705891B88C18}"
2222
EndProject
2323
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StandardLearners", "src\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj", "{707BB22C-7E5F-497A-8C2F-74578F675705}"
2424
EndProject
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netstandard2.0</TargetFramework>
5+
<PackageDescription>ML.NET component for exporting ONNX Models</PackageDescription>
6+
</PropertyGroup>
7+
8+
<ItemGroup>
9+
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
10+
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
11+
</ItemGroup>
12+
13+
</Project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<Project DefaultTargets="Pack">
2+
3+
<Import Project="Microsoft.ML.Onnx.nupkgproj" />
4+
5+
</Project>

pkg/Microsoft.ML/Microsoft.ML.nupkgproj

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
</PropertyGroup>
77

88
<ItemGroup>
9-
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
109
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
1110
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="$(SystemReflectionEmitLightweightPackageVersion)" />
1211
<PackageReference Include="System.Threading.Tasks.Dataflow" Version="$(SystemThreadingTasksDataflowPackageVersion)" />

src/Microsoft.ML.Console/Microsoft.ML.Console.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
<ProjectReference Include="..\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
2020
<ProjectReference Include="..\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
2121
<ProjectReference Include="..\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
22+
<ProjectReference Include="..\Microsoft.ML.Onnx\Microsoft.ML.Onnx.csproj" />
2223
<ProjectReference Include="..\Microsoft.ML.PCA\Microsoft.ML.PCA.csproj" />
2324
<ProjectReference Include="..\Microsoft.ML.PipelineInference\Microsoft.ML.PipelineInference.csproj" />
2425
<ProjectReference Include="..\Microsoft.ML.ResultProcessor\Microsoft.ML.ResultProcessor.csproj" />
2526
<ProjectReference Include="..\Microsoft.ML.StandardLearners\Microsoft.ML.StandardLearners.csproj" />
2627
<ProjectReference Include="..\Microsoft.ML.Sweeper\Microsoft.ML.Sweeper.csproj" />
2728
<ProjectReference Include="..\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
28-
<ProjectReference Include="..\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj" />
2929

3030
<NativeAssemblyReference Include="FastTreeNative" />
3131
<NativeAssemblyReference Include="CpuMathNative" />

src/Microsoft.ML.Data/DataLoadSave/CompositeDataLoader.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ private static VersionInfo GetVersionInfo()
8585
/// Returns the underlying data view of the composite loader.
8686
/// This can be used to programmatically explore the chain of transforms that's inside the composite loader.
8787
/// </summary>
88-
internal IDataView View { get; }
88+
public IDataView View { get; }
8989

9090
/// <summary>
9191
/// Creates a loader according to the specified <paramref name="args"/>.

src/Microsoft.ML.Data/Microsoft.ML.Data.csproj

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
<ItemGroup>
1616
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1717
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
18-
<ProjectReference Include="..\Microsoft.ML.UniversalModelFormat\Microsoft.ML.UniversalModelFormat.csproj" />
1918
</ItemGroup>
2019

2120
</Project>

src/Microsoft.ML.Data/Model/Onnx/ICanSaveOnnx.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public interface ICanSaveOnnx
1919
}
2020

2121
/// <summary>
22-
/// This data model component is savable as Onnx.
22+
/// This data model component is savable as ONNX.
2323
/// </summary>
2424
public interface ITransformCanSaveOnnx: ICanSaveOnnx, IDataTransform
2525
{

src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs

+64-208
Original file line numberDiff line numberDiff line change
@@ -2,245 +2,101 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using System.Collections.Generic;
7-
using System.Linq;
8-
using Microsoft.ML.Runtime.UniversalModelFormat.Onnx;
96
using Microsoft.ML.Runtime.Data;
107

118
namespace Microsoft.ML.Runtime.Model.Onnx
129
{
1310
/// <summary>
14-
/// A context for defining a ONNX output.
11+
/// A context for defining an ONNX output. The context internally contains the model-in-progress being built. This
12+
/// same context object is iteratively given to exportable components via the <see cref="ICanSaveOnnx"/> interface
13+
/// and subinterfaces, that attempt to express their operations as ONNX nodes, if they can. At the point that it is
14+
/// given to a component, all other components up to that component have already attempted to express themselves in
15+
/// this context, with their outputs possibly available in the ONNX graph.
1516
/// </summary>
16-
public sealed class OnnxContext
17+
public abstract class OnnxContext
1718
{
18-
private readonly List<NodeProto> _nodes;
19-
private readonly List<OnnxUtils.ModelArgs> _inputs;
20-
private readonly List<OnnxUtils.ModelArgs> _intermediateValues;
21-
private readonly List<OnnxUtils.ModelArgs> _outputs;
22-
private readonly Dictionary<string, string> _columnNameMap;
23-
private readonly HashSet<string> _variableMap;
24-
private readonly HashSet<string> _nodeNames;
25-
private readonly string _name;
26-
private readonly string _producerName;
27-
private readonly IHost _host;
28-
private readonly string _domain;
29-
private readonly string _producerVersion;
30-
private readonly long _modelVersion;
31-
32-
public OnnxContext(IHostEnvironment env, string name, string producerName,
33-
string producerVersion, long modelVersion, string domain)
34-
{
35-
Contracts.CheckValue(env, nameof(env));
36-
Contracts.CheckValue(name, nameof(name));
37-
Contracts.CheckValue(name, nameof(domain));
38-
39-
_host = env.Register(nameof(OnnxContext));
40-
_nodes = new List<NodeProto>();
41-
_intermediateValues = new List<OnnxUtils.ModelArgs>();
42-
_inputs = new List<OnnxUtils.ModelArgs>();
43-
_outputs = new List<OnnxUtils.ModelArgs>();
44-
_columnNameMap = new Dictionary<string, string>();
45-
_variableMap = new HashSet<string>();
46-
_nodeNames = new HashSet<string>();
47-
_name = name;
48-
_producerName = producerName;
49-
_producerVersion = producerVersion;
50-
_modelVersion = modelVersion;
51-
_domain = domain;
52-
}
53-
54-
public bool ContainsColumn(string colName) => _columnNameMap.ContainsKey(colName);
55-
56-
/// <summary>
57-
/// Stops tracking a column. If removeVariable is true then it also removes the
58-
/// variable associated with it, this is useful in the event where an output variable is
59-
/// created before realizing the transform cannot actually save as ONNX.
60-
/// </summary>
61-
/// <param name="colName">IDataView column name to stop tracking</param>
62-
/// <param name="removeVariable">Remove associated ONNX variable at the time.</param>
63-
public void RemoveColumn(string colName, bool removeVariable)
64-
{
65-
66-
if (removeVariable)
67-
{
68-
foreach (var val in _intermediateValues)
69-
{
70-
if (val.Name == _columnNameMap[colName])
71-
{
72-
_intermediateValues.Remove(val);
73-
break;
74-
}
75-
}
76-
}
77-
78-
if (_columnNameMap.ContainsKey(colName))
79-
_columnNameMap.Remove(colName);
80-
}
81-
82-
/// <summary>
83-
/// Removes an ONNX variable. If removeColumn is true then it also removes the
84-
/// IDataView column associated with it.
85-
/// </summary>
86-
/// <param name="variableName">ONNX variable to remove.</param>
87-
/// <param name="removeColumn">IDataView column to stop tracking</param>
88-
public void RemoveVariable(string variableName, bool removeColumn)
89-
{
90-
_host.Assert(_columnNameMap.ContainsValue(variableName));
91-
if (removeColumn)
92-
{
93-
foreach (var val in _intermediateValues)
94-
{
95-
if (val.Name == variableName)
96-
{
97-
_intermediateValues.Remove(val);
98-
break;
99-
}
100-
}
101-
}
102-
103-
string columnName = _columnNameMap.Single(kvp => string.Compare(kvp.Value, variableName) == 0).Key;
104-
105-
Contracts.Assert(_variableMap.Contains(columnName));
106-
107-
_columnNameMap.Remove(columnName);
108-
_variableMap.Remove(columnName);
109-
}
110-
11119
/// <summary>
11220
/// Generates a unique name for the node based on a prefix.
11321
/// </summary>
114-
public string GetNodeName(string prefix)
115-
{
116-
_host.CheckValue(prefix, nameof(prefix));
117-
return GetUniqueName(prefix, c => _nodeNames.Contains(c));
118-
}
22+
/// <param name="prefix">The prefix for the node</param>
23+
/// <returns>A name that has not yet been returned from this function, starting with <paramref name="prefix"/></returns>
24+
public abstract string GetNodeName(string prefix);
11925

12026
/// <summary>
121-
/// Adds a node to the node list of the graph.
27+
/// Looks up whether a given data view column has a mapping in the ONNX context. Once confirmed, callers can
28+
/// safely call <see cref="GetVariableName(string)"/>.
12229
/// </summary>
123-
/// <param name="node"></param>
124-
public void AddNode(NodeProto node)
125-
{
126-
_host.CheckValue(node, nameof(node));
127-
_host.Assert(!_nodeNames.Contains(node.Name));
128-
129-
_nodeNames.Add(node.Name);
130-
_nodes.Add(node);
131-
}
30+
/// <param name="colName">The data view column name</param>
31+
/// <returns>Whether the column is mapped in this context</returns>
32+
public abstract bool ContainsColumn(string colName);
13233

13334
/// <summary>
134-
/// Generates a unique name based on a prefix.
35+
/// Stops tracking a column.
13536
/// </summary>
136-
private string GetUniqueName(string prefix, Func<string, bool> pred)
137-
{
138-
_host.CheckValue(prefix, nameof(prefix));
139-
_host.CheckValue(pred, nameof(pred));
140-
141-
if (!pred(prefix))
142-
return prefix;
143-
144-
int count = 0;
145-
while (pred(prefix + count++)) ;
146-
return prefix + --count;
147-
}
37+
/// <param name="colName">Column name to stop tracking</param>
38+
/// <param name="removeVariable">Remove associated ONNX variable. This is useful in the event where an output
39+
/// variable is created through <see cref="AddIntermediateVariable(ColumnType, string, bool)"/> before realizing
40+
/// the transform cannot actually save as ONNX.</param>
41+
public abstract void RemoveColumn(string colName, bool removeVariable = false);
14842

14943
/// <summary>
150-
/// Retrieves the variable name that maps to the IDataView column name at a
151-
/// given point in the pipeline execution.
44+
/// Removes an ONNX variable. If removeColumn is true then it also removes the tracking for the <see
45+
/// cref="IDataView"/> column associated with it.
15246
/// </summary>
153-
/// <returns>Column Name mapping.</returns>
154-
public string GetVariableName(string colName)
155-
{
156-
_host.CheckValue(colName, nameof(colName));
157-
_host.Assert(_columnNameMap.ContainsKey(colName));
158-
159-
return _columnNameMap[colName];
160-
}
161-
162-
/// <summary>
163-
/// Retrieves the variable name that maps to the IDataView column name at a
164-
/// given point in the pipeline execution.
165-
/// </summary>
166-
/// <returns>Column Name mapping.</returns>
167-
public string TryGetVariableName(string colName)
168-
{
169-
if (_columnNameMap.ContainsKey(colName))
170-
return GetVariableName(colName);
171-
172-
return null;
173-
}
174-
175-
/// <summary>
176-
/// Generates a unique column name based on the IDataView column name if
177-
/// there is a collision between names in the pipeline at any point.
178-
/// </summary>
179-
/// <param name="colName">IDataView column name.</param>
180-
/// <returns>Unique variable name.</returns>
181-
private string AddVariable(string colName)
182-
{
183-
_host.CheckValue(colName, nameof(colName));
184-
185-
if (!_columnNameMap.ContainsKey(colName))
186-
_columnNameMap.Add(colName, colName);
187-
else
188-
_columnNameMap[colName] = GetUniqueName(colName, s => _variableMap.Contains(s));
189-
190-
_variableMap.Add(_columnNameMap[colName]);
191-
return _columnNameMap[colName];
192-
}
47+
/// <param name="variableName">ONNX variable to remove. Note that this is an ONNX variable name, not an <see
48+
/// cref="IDataView"/> column name</param>
49+
/// <param name="removeColumn">Whether to also stop tracking the <see cref="IDataView"/> column associated with the variable</param>
50+
public abstract void RemoveVariable(string variableName, bool removeColumn);
19351

19452
/// <summary>
195-
/// Adds an intermediate column to the list.
53+
/// ONNX variables are referred to by name. At each stage of a ML.NET pipeline, the corresponding
54+
/// <see cref="IDataView"/>'s column names will map to a variable in the ONNX graph if the intermediate steps
55+
/// used to calculate that value are things we knew how to save as ONNX. Retrieves the variable name that maps
56+
/// to the <see cref="IDataView"/> column name at a given point in the pipeline execution. Callers should
57+
/// probably confirm with <see cref="ContainsColumn(string)"/> whether a mapping for that data view column
58+
/// already exists.
19659
/// </summary>
197-
public string AddIntermediateVariable(ColumnType type, string colName, bool skip = false)
198-
{
199-
200-
colName = AddVariable(colName);
201-
202-
//Let the runtime figure the shape.
203-
if (!skip)
204-
{
205-
_host.CheckValue(type, nameof(type));
206-
207-
_intermediateValues.Add(OnnxUtils.GetModelArgs(type, colName));
208-
}
209-
210-
return colName;
211-
}
60+
/// <param name="colName">The data view column name</param>
61+
/// <returns>The ONNX variable name corresponding to that data view column</returns>
62+
public abstract string GetVariableName(string colName);
21263

21364
/// <summary>
214-
/// Adds an output variable to the list.
65+
/// Establishes a new mapping from a data view column in the context, if necessary generates a unique name, and
66+
/// returns that newly allocated name.
21567
/// </summary>
216-
public string AddOutputVariable(ColumnType type, string colName, List<long> dim = null)
217-
{
218-
_host.CheckValue(type, nameof(type));
219-
220-
if (!ContainsColumn(colName))
221-
AddVariable(colName);
222-
223-
colName = GetVariableName(colName);
224-
_outputs.Add(OnnxUtils.GetModelArgs(type, colName, dim));
225-
return colName;
226-
}
68+
/// <param name="type">The data view type associated with this column name</param>
69+
/// <param name="colName">The data view column name</param>
70+
/// <param name="skip">Whether we should skip the process of establishing the mapping from data view column to
71+
/// ONNX variable name.</param>
72+
/// <returns>The name of the ONNX variable corresponding to the data view column</returns>
73+
public abstract string AddIntermediateVariable(ColumnType type, string colName, bool skip = false);
22774

22875
/// <summary>
229-
/// Adds an input variable to the list.
76+
/// Creates an ONNX node
23077
/// </summary>
231-
public void AddInputVariable(ColumnType type, string colName)
232-
{
233-
_host.CheckValue(type, nameof(type));
234-
_host.CheckValue(colName, nameof(colName));
235-
236-
colName = AddVariable(colName);
237-
_inputs.Add(OnnxUtils.GetModelArgs(type, colName));
238-
}
78+
/// <param name="opType">The name of the ONNX operator to apply</param>
79+
/// <param name="inputs">The names of the variables as inputs</param>
80+
/// <param name="outputs">The names of the variables to create as outputs,
81+
/// which ought to have been something returned from <see cref="AddIntermediateVariable(ColumnType, string, bool)"/></param>
82+
/// <param name="name">The name of the operator, which ought to be something returned from <see cref="GetNodeName(string)"/></param>
83+
/// <param name="domain">The domain of the ONNX operator, if non-default</param>
84+
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
85+
public abstract OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
86+
IEnumerable<string> outputs, string name, string domain = null);
23987

24088
/// <summary>
241-
/// Makes the ONNX model based on the context.
89+
/// Convenience alternative to <see cref="CreateNode(string, IEnumerable{string}, IEnumerable{string}, string, string)"/>
90+
/// for the case where there is exactly one input and output.
24291
/// </summary>
243-
public ModelProto MakeModel()
244-
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
92+
/// <param name="opType">The name of the ONNX operator to apply</param>
93+
/// <param name="input">The name of the variable as input</param>
94+
/// <param name="output">The name of the variable as output,
95+
/// which ought to have been something returned from <see cref="OnnxContext.AddIntermediateVariable(ColumnType, string, bool)"/></param>
96+
/// <param name="name">The name of the operator, which ought to be something returned from <see cref="OnnxContext.GetNodeName(string)"/></param>
97+
/// <param name="domain">The domain of the ONNX operator, if non-default</param>
98+
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
99+
public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
100+
=> CreateNode(opType, new[] { input }, new[] { output }, name, domain);
245101
}
246102
}

0 commit comments

Comments
 (0)