Skip to content

Commit be7ccff

Browse files
committed
Merge branch 'master' of https://github.com/dotnet/machinelearning into timeseries
2 parents bb73184 + ff85a5c commit be7ccff

File tree

6 files changed

+266
-3
lines changed

6 files changed

+266
-3
lines changed

src/Microsoft.ML.Data/Model/Onnx/OnnxContext.cs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,56 @@ public abstract OnnxNode CreateNode(string opType, IEnumerable<string> inputs,
9898
/// <returns>A node added to the in-progress ONNX graph, that attributes can be set on</returns>
9999
public OnnxNode CreateNode(string opType, string input, string output, string name, string domain = null)
100100
=> CreateNode(opType, new[] { input }, new[] { output }, name, domain);
101+
102+
/// <summary>
103+
/// Call this function can declare a global float
104+
/// </summary>
105+
/// <param name="value">The float number which is going to be added</param>
106+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
107+
/// <returns>The initializer's ONNX name</returns>
108+
public abstract string AddInitializer(float value, string name = null);
109+
110+
/// <summary>
111+
/// Call this function can declare a global long
112+
/// </summary>
113+
/// <param name="value">The long number which is going to be added into the ONNX graph</param>
114+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
115+
/// <returns>The initializer's ONNX name</returns>
116+
public abstract string AddInitializer(long value, string name = null);
117+
118+
/// <summary>
119+
/// Call this function can declare a global string
120+
/// </summary>
121+
/// <param name="value">The string which is going to be added into the ONNX graph</param>
122+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
123+
/// <returns>The initializer's ONNX name</returns>
124+
public abstract string AddInitializer(string value, string name = null);
125+
126+
/// <summary>
127+
/// Call this function can declare a global float tensor
128+
/// </summary>
129+
/// <param name="values">The floats which are going to be added into the ONNX graph</param>
130+
/// <param name="dims">The shape that the floats</param>
131+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
132+
/// <returns>The initializer's ONNX name</returns>
133+
public abstract string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null);
134+
135+
/// <summary>
136+
/// Call this function can declare a global long tensor
137+
/// </summary>
138+
/// <param name="values">The longs which are going to be added into the ONNX graph</param>
139+
/// <param name="dims">The shape that the floats</param>
140+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
141+
/// <returns>The initializer's ONNX name</returns>
142+
public abstract string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null);
143+
144+
/// <summary>
145+
/// Call this function can declare a global string tensor
146+
/// </summary>
147+
/// <param name="values">The strings which are going to be added into the ONNX graph</param>
148+
/// <param name="dims">The shape that the strings</param>
149+
/// <param name="name">A string used as a seed to generate this initializer's name in the ONNX graph.</param>
150+
/// <returns>The initializer's ONNX name</returns>
151+
public abstract string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null);
101152
}
102153
}

src/Microsoft.ML.Onnx/AssemblyInfo.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
using System.Runtime.CompilerServices;
2+
3+
[assembly: InternalsVisibleTo("Microsoft.ML.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")]

src/Microsoft.ML.Onnx/OnnxContextImpl.cs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ internal sealed class OnnxContextImpl : OnnxContext
1818
private readonly List<NodeProto> _nodes;
1919
private readonly List<OnnxUtils.ModelArgs> _inputs;
2020
// The map from IDataView column names to variable names.
21+
private readonly List<TensorProto> _initializers;
2122
private readonly List<OnnxUtils.ModelArgs> _intermediateValues;
2223
private readonly List<OnnxUtils.ModelArgs> _outputs;
2324
private readonly Dictionary<string, string> _columnNameMap;
@@ -43,6 +44,7 @@ public OnnxContextImpl(IHostEnvironment env, string name, string producerName,
4344
_nodes = new List<NodeProto>();
4445
_intermediateValues = new List<OnnxUtils.ModelArgs>();
4546
_inputs = new List<OnnxUtils.ModelArgs>();
47+
_initializers = new List<TensorProto>();
4648
_outputs = new List<OnnxUtils.ModelArgs>();
4749
_columnNameMap = new Dictionary<string, string>();
4850
_variableNames = new HashSet<string>();
@@ -246,10 +248,67 @@ public void AddInputVariable(ColumnType type, string colName)
246248
_inputs.Add(OnnxUtils.GetModelArgs(type, colName));
247249
}
248250

251+
/// <summary>
252+
/// Adds constant tensors into the graph.
253+
/// </summary>
254+
public override string AddInitializer(float value, string name = null)
255+
{
256+
name = AddVariable(name ?? "float");
257+
_initializers.Add(OnnxUtils.MakeFloat(name, value));
258+
return name;
259+
}
260+
261+
public override string AddInitializer(string value, string name = null)
262+
{
263+
name = AddVariable(name ?? "string");
264+
_initializers.Add(OnnxUtils.MakeString(name, value));
265+
return name;
266+
}
267+
268+
public override string AddInitializer(long value, string name = null)
269+
{
270+
name = AddVariable(name ?? "int64");
271+
_initializers.Add(OnnxUtils.MakeInt64(name, value));
272+
return name;
273+
}
274+
275+
public override string AddInitializer(IEnumerable<float> values, IEnumerable<long> dims, string name = null)
276+
{
277+
_host.CheckValue(values, nameof(values));
278+
if (dims != null)
279+
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
280+
281+
name = AddVariable(name ?? "floats");
282+
_initializers.Add(OnnxUtils.MakeFloats(name, values, dims));
283+
return name;
284+
}
285+
286+
public override string AddInitializer(IEnumerable<long> values, IEnumerable<long> dims, string name = null)
287+
{
288+
_host.CheckValue(values, nameof(values));
289+
if (dims != null)
290+
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
291+
292+
name = AddVariable(name ?? "int64s");
293+
_initializers.Add(OnnxUtils.MakeInt64s(name, values, dims));
294+
return name;
295+
}
296+
297+
public override string AddInitializer(IEnumerable<string> values, IEnumerable<long> dims, string name = null)
298+
{
299+
_host.CheckValue(values, nameof(values));
300+
if (dims != null)
301+
_host.Check(dims.Aggregate((x, y) => x * y) == values.Count(), "Number of elements doesn't match tensor size");
302+
303+
name = AddVariable(name ?? "strings");
304+
_initializers.Add(OnnxUtils.MakeStrings(name, values, dims));
305+
return name;
306+
}
307+
249308
/// <summary>
250309
/// Makes the ONNX model based on the context.
251310
/// </summary>
252311
public ModelProto MakeModel()
253-
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues);
312+
=> OnnxUtils.MakeModel(_nodes, _producerName, _name, _domain, _producerVersion, _modelVersion, _inputs, _outputs, _intermediateValues, _initializers);
254313
}
255314
}

src/Microsoft.ML.Onnx/OnnxUtils.cs

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,12 +238,13 @@ public ModelArgs(string name, TensorProto.Types.DataType dataType, List<long> di
238238

239239
public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, string name,
240240
string domain, string producerVersion, long modelVersion, List<ModelArgs> inputs,
241-
List<ModelArgs> outputs, List<ModelArgs> intermediateValues)
241+
List<ModelArgs> outputs, List<ModelArgs> intermediateValues, List<TensorProto> initializers)
242242
{
243243
Contracts.CheckValue(nodes, nameof(nodes));
244244
Contracts.CheckValue(inputs, nameof(inputs));
245245
Contracts.CheckValue(outputs, nameof(outputs));
246-
Contracts.CheckValue(outputs, nameof(intermediateValues));
246+
Contracts.CheckValue(intermediateValues, nameof(intermediateValues));
247+
Contracts.CheckValue(initializers, nameof(initializers));
247248
Contracts.CheckNonEmpty(producerName, nameof(producerName));
248249
Contracts.CheckNonEmpty(name, nameof(name));
249250
Contracts.CheckNonEmpty(domain, nameof(domain));
@@ -282,6 +283,8 @@ public static ModelProto MakeModel(List<NodeProto> nodes, string producerName, s
282283
MakeValue(val, arg.Name, arg.DataType, arg.Dims, arg.DimParams);
283284
}
284285

286+
graph.Initializer.AddRange(initializers);
287+
285288
return model;
286289
}
287290

@@ -349,5 +352,77 @@ public static ModelArgs GetModelArgs(ColumnType type, string colName,
349352

350353
return new ModelArgs(name, dataType, dimsLocal, dimsParamLocal);
351354
}
355+
356+
// Make long scalar in ONNX from native C# number
357+
public static TensorProto MakeInt64(string name, long value)
358+
{
359+
var tensor = new TensorProto();
360+
tensor.Name = name;
361+
tensor.DataType = TensorProto.Types.DataType.Int64;
362+
tensor.Int64Data.Add(value);
363+
return tensor;
364+
}
365+
366+
// Make long vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
367+
public static TensorProto MakeInt64s(string name, IEnumerable<long> values, IEnumerable<long> dims = null)
368+
{
369+
var tensor = new TensorProto();
370+
tensor.Name = name;
371+
tensor.DataType = TensorProto.Types.DataType.Int64;
372+
tensor.Int64Data.AddRange(values);
373+
if (dims != null)
374+
tensor.Dims.AddRange(dims);
375+
else
376+
tensor.Dims.Add(values.Count());
377+
return tensor;
378+
}
379+
380+
// Make float scalar in ONNX from native C# number
381+
public static TensorProto MakeFloat(string name, float value)
382+
{
383+
var tensor = new TensorProto();
384+
tensor.Name = name;
385+
tensor.DataType = TensorProto.Types.DataType.Float;
386+
tensor.FloatData.Add(value);
387+
return tensor;
388+
}
389+
390+
// Make float vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
391+
public static TensorProto MakeFloats(string name, IEnumerable<float> values, IEnumerable<long> dims = null)
392+
{
393+
var tensor = new TensorProto();
394+
tensor.Name = name;
395+
tensor.DataType = TensorProto.Types.DataType.Float;
396+
tensor.FloatData.AddRange(values);
397+
if (dims != null)
398+
tensor.Dims.AddRange(dims);
399+
else
400+
tensor.Dims.Add(values.Count());
401+
return tensor;
402+
}
403+
404+
// Make string scalar in ONNX from native C# number
405+
public static TensorProto MakeString(string name, string value)
406+
{
407+
var tensor = new TensorProto();
408+
tensor.Name = name;
409+
tensor.DataType = TensorProto.Types.DataType.String;
410+
tensor.StringData.Add(StringToByteString(value));
411+
return tensor;
412+
}
413+
414+
// Make string vector (i.e., 1-D tensor) with dims=null. Otherwise, dims is used as the shape of the produced tensor.
415+
public static TensorProto MakeStrings(string name, IEnumerable<string> values, IEnumerable<long> dims = null)
416+
{
417+
var tensor = new TensorProto();
418+
tensor.Name = name;
419+
tensor.DataType = TensorProto.Types.DataType.String;
420+
tensor.StringData.AddRange(StringToByteString(values));
421+
if (dims != null)
422+
tensor.Dims.AddRange(dims);
423+
else
424+
tensor.Dims.Add(values.Count());
425+
return tensor;
426+
}
352427
}
353428
}

test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
<Project Sdk="Microsoft.NET.Sdk">
22

3+
<PropertyGroup>
4+
<AssemblyName>Microsoft.ML.Tests</AssemblyName>
5+
</PropertyGroup>
6+
37
<ItemGroup>
48
<Compile Remove="Scenarios\Api\AspirationalExamples.cs" />
59
</ItemGroup>

test/Microsoft.ML.Tests/OnnxTests.cs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
using Microsoft.ML.Legacy.Transforms;
99
using Microsoft.ML.Runtime.Api;
1010
using Microsoft.ML.Runtime.Data;
11+
using Microsoft.ML.Runtime.Model.Onnx;
1112
using Microsoft.ML.Runtime.RunTests;
1213
using System;
14+
using System.Collections.Generic;
1315
using System.IO;
1416
using System.Text.RegularExpressions;
1517
using Xunit;
@@ -51,6 +53,75 @@ public class BreastCancerMCPrediction
5153
public float[] Scores;
5254
}
5355

56+
[Fact]
57+
public void InitializerCreationTest()
58+
{
59+
using (var env = new ConsoleEnvironment())
60+
{
61+
// Create the actual implementation
62+
var ctxImpl = new OnnxContextImpl(env, "model", "ML.NET", "0", 0, "com.test");
63+
64+
// Use implementation as in the actual conversion code
65+
var ctx = ctxImpl as OnnxContext;
66+
ctx.AddInitializer(9.4f, "float");
67+
ctx.AddInitializer(17L, "int64");
68+
ctx.AddInitializer("36", "string");
69+
ctx.AddInitializer(new List<float> { 9.4f, 1.7f, 3.6f }, new List<long> { 1, 3 }, "floats");
70+
ctx.AddInitializer(new List<long> { 94L, 17L, 36L }, new List<long> { 1, 3 }, "int64s");
71+
ctx.AddInitializer(new List<string> { "94" , "17", "36" }, new List<long> { 1, 3 }, "strings");
72+
73+
var model = ctxImpl.MakeModel();
74+
75+
var floatScalar = model.Graph.Initializer[0];
76+
Assert.True(floatScalar.Name == "float");
77+
Assert.True(floatScalar.Dims.Count == 0);
78+
Assert.True(floatScalar.FloatData.Count == 1);
79+
Assert.True(floatScalar.FloatData[0] == 9.4f);
80+
81+
var int64Scalar = model.Graph.Initializer[1];
82+
Assert.True(int64Scalar.Name == "int64");
83+
Assert.True(int64Scalar.Dims.Count == 0);
84+
Assert.True(int64Scalar.Int64Data.Count == 1);
85+
Assert.True(int64Scalar.Int64Data[0] == 17L);
86+
87+
var stringScalar = model.Graph.Initializer[2];
88+
Assert.True(stringScalar.Name == "string");
89+
Assert.True(stringScalar.Dims.Count == 0);
90+
Assert.True(stringScalar.StringData.Count == 1);
91+
Assert.True(stringScalar.StringData[0].ToStringUtf8() == "36");
92+
93+
var floatsTensor = model.Graph.Initializer[3];
94+
Assert.True(floatsTensor.Name == "floats");
95+
Assert.True(floatsTensor.Dims.Count == 2);
96+
Assert.True(floatsTensor.Dims[0] == 1);
97+
Assert.True(floatsTensor.Dims[1] == 3);
98+
Assert.True(floatsTensor.FloatData.Count == 3);
99+
Assert.True(floatsTensor.FloatData[0] == 9.4f);
100+
Assert.True(floatsTensor.FloatData[1] == 1.7f);
101+
Assert.True(floatsTensor.FloatData[2] == 3.6f);
102+
103+
var int64sTensor = model.Graph.Initializer[4];
104+
Assert.True(int64sTensor.Name == "int64s");
105+
Assert.True(int64sTensor.Dims.Count == 2);
106+
Assert.True(int64sTensor.Dims[0] == 1);
107+
Assert.True(int64sTensor.Dims[1] == 3);
108+
Assert.True(int64sTensor.Int64Data.Count == 3);
109+
Assert.True(int64sTensor.Int64Data[0] == 94L);
110+
Assert.True(int64sTensor.Int64Data[1] == 17L);
111+
Assert.True(int64sTensor.Int64Data[2] == 36L);
112+
113+
var stringsTensor = model.Graph.Initializer[5];
114+
Assert.True(stringsTensor.Name == "strings");
115+
Assert.True(stringsTensor.Dims.Count == 2);
116+
Assert.True(stringsTensor.Dims[0] == 1);
117+
Assert.True(stringsTensor.Dims[1] == 3);
118+
Assert.True(stringsTensor.StringData.Count == 3);
119+
Assert.True(stringsTensor.StringData[0].ToStringUtf8() == "94");
120+
Assert.True(stringsTensor.StringData[1].ToStringUtf8() == "17");
121+
Assert.True(stringsTensor.StringData[2].ToStringUtf8() == "36");
122+
}
123+
}
124+
54125
[Fact]
55126
public void BinaryClassificationFastTreeSaveModelToOnnxTest()
56127
{

0 commit comments

Comments
 (0)