Skip to content

Commit 5666dd1

Browse files
authored
Handle inputs with unknown shapes in TensorFlow (#857)
* Enable scoring Inception model and SSD model * Add a unit test for pipeline API. * Update after merge with master * Address PR comments * Add a unit test, and fix a bug in the 'getter' creation method. * Change new unit test to use LogisticRegression instead of LightGBM, so we don't need a separate nuget for it to work. * Address PR comment.
1 parent f0f04ef commit 5666dd1

File tree

8 files changed

+333
-53
lines changed

8 files changed

+333
-53
lines changed

src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,9 @@ public TFTensor(long[] data) : base(SetupTensor(TFDataType.Int64, data, size: 8)
410410
public TFTensor(Complex[] data) : base(SetupTensor(TFDataType.Complex128, data, size: 16)) { }
411411

412412
// Convenience function to factor out the setup of a new tensor from an array
413-
internal static IntPtr SetupTensor(TFDataType dt, long[] dims, Array data, int size)
413+
internal static IntPtr SetupTensor(TFDataType dt, long[] dims, Array data, int count, int size)
414414
{
415-
return SetupTensor(dt, dims, data, start: 0, count: data.Length, size: size);
415+
return SetupTensor(dt, dims, data, 0, count, size);
416416
}
417417

418418
// Convenience function to factor out the setup of a new tensor from an array
@@ -422,7 +422,7 @@ internal static IntPtr SetupTensor(TFDataType dt, Array data, int size)
422422
for (int i = 0; i < dims.Length; i++)
423423
dims[i] = data.GetLength(i);
424424

425-
return SetupTensor(dt, dims, data, start: 0, count: data.Length, size: size);
425+
return SetupTensor(dt, dims, data, 0, data.Length, size);
426426
}
427427

428428
// Use for single dimension arrays

src/Microsoft.ML.TensorFlow/TensorFlow/TensorGeneric.cs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,60 +76,61 @@ public static TFTensor CreateScalar<T>(T data)
7676
/// </summary>
7777
/// <typeparam name="T[]">.NET type of tensor to create</typeparam>
7878
/// <param name="data">value of tensor</param>
79+
/// <param name="count">The number of elements in the tensor</param>
7980
/// <param name="shape">shape of tensor</param>
80-
public static TFTensor Create<T>(T[] data, TFShape shape)
81+
public static TFTensor Create<T>(T[] data, int count, TFShape shape)
8182
{
8283
if (typeof(T) == typeof(System.Boolean))
8384
{
84-
return new TFTensor(SetupTensor(TFDataType.Bool, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4));
85+
return new TFTensor(SetupTensor(TFDataType.Bool, shape, (Array)(object)data, 0, count, 4));
8586
}
8687
else if (typeof(T) == typeof(System.Byte))
8788
{
88-
return new TFTensor(SetupTensor(TFDataType.UInt8, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 1));
89+
return new TFTensor(SetupTensor(TFDataType.UInt8, shape, (Array)(object)data, 0, count, 1));
8990
}
9091
else if (typeof(T) == typeof(System.Char))
9192
{
92-
return new TFTensor(SetupTensor(TFDataType.UInt8, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 1));
93+
return new TFTensor(SetupTensor(TFDataType.UInt8, shape, (Array)(object)data, 0, count, 1));
9394
}
9495
else if (typeof(T) == typeof(System.Numerics.Complex))
9596
{
96-
return new TFTensor(SetupTensor(TFDataType.Complex128, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 16));
97+
return new TFTensor(SetupTensor(TFDataType.Complex128, shape, (Array)(object)data, 0, count, 16));
9798
}
9899
else if (typeof(T) == typeof(System.Double))
99100
{
100-
return new TFTensor(SetupTensor(TFDataType.Double, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 8));
101+
return new TFTensor(SetupTensor(TFDataType.Double, shape, (Array)(object)data, 0, count, 8));
101102
}
102103
else if (typeof(T) == typeof(System.Single))
103104
{
104-
return new TFTensor(SetupTensor(TFDataType.Float, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4));
105+
return new TFTensor(SetupTensor(TFDataType.Float, shape, (Array)(object)data, 0, count, 4));
105106
}
106107
else if (typeof(T) == typeof(System.Int32))
107108
{
108-
return new TFTensor(SetupTensor(TFDataType.Int32, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4));
109+
return new TFTensor(SetupTensor(TFDataType.Int32, shape, (Array)(object)data, 0, count, 4));
109110
}
110111
else if (typeof(T) == typeof(System.Int64))
111112
{
112-
return new TFTensor(SetupTensor(TFDataType.Int64, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 8));
113+
return new TFTensor(SetupTensor(TFDataType.Int64, shape, (Array)(object)data, 0, count, 8));
113114
}
114115
else if (typeof(T) == typeof(System.SByte))
115116
{
116-
return new TFTensor(SetupTensor(TFDataType.Int8, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 1));
117+
return new TFTensor(SetupTensor(TFDataType.Int8, shape, (Array)(object)data, 0, count, 1));
117118
}
118119
else if (typeof(T) == typeof(System.Int16))
119120
{
120-
return new TFTensor(SetupTensor(TFDataType.Int16, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 2));
121+
return new TFTensor(SetupTensor(TFDataType.Int16, shape, (Array)(object)data, 0, count, 2));
121122
}
122123
else if (typeof(T) == typeof(System.UInt32))
123124
{
124-
return new TFTensor(SetupTensor(TFDataType.UInt32, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4));
125+
return new TFTensor(SetupTensor(TFDataType.UInt32, shape, (Array)(object)data, 0, count, 4));
125126
}
126127
else if (typeof(T) == typeof(System.UInt64))
127128
{
128-
return new TFTensor(SetupTensor(TFDataType.UInt64, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 8));
129+
return new TFTensor(SetupTensor(TFDataType.UInt64, shape, (Array)(object)data, 0, count, 8));
129130
}
130131
else if (typeof(T) == typeof(System.UInt16))
131132
{
132-
return new TFTensor(SetupTensor(TFDataType.UInt16, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 2));
133+
return new TFTensor(SetupTensor(TFDataType.UInt16, shape, (Array)(object)data, 0, count, 2));
133134
}
134135
// note that we will get here for jagged arrays, which is intententional since we'd need to copy them.
135136
throw new NotSupportedException($"Unsupported type {typeof(T)}");

src/Microsoft.ML.TensorFlow/TensorFlow/TensorGeneric.tt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,14 @@ namespace Microsoft.ML.Transforms.TensorFlow
3737
/// </summary>
3838
/// <typeparam name="T[]">.NET type of tensor to create</typeparam>
3939
/// <param name="data">value of tensor</param>
40+
/// <param name="count">The number of elements in the tensor</param>
4041
/// <param name="shape">shape of tensor</param>
41-
public static TFTensor Create<T>(T[] data, TFShape shape)
42+
public static TFTensor Create<T>(T[] data, int count, TFShape shape)
4243
{
4344
<# foreach (TypeConfiguration type in typeConfiguration) { #>
4445
<#=GenerateIfStatementHeader(type)#>
4546
{
46-
return new TFTensor(SetupTensor(TFDataType.<#=type.TFDataType#>, shape, (Array)(object)data, 0, ((Array)(object)data).Length, <#=type.Size#>));
47+
return new TFTensor(SetupTensor(TFDataType.<#=type.TFDataType#>, shape, (Array)(object)data, 0, count, <#=type.Size#>));
4748
}
4849
<# } #>
4950
// note that we will get here for jagged arrays, which is intententional since we'd need to copy them.

src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
using System;
66
using System.Runtime.InteropServices;
7-
using System.Text;
8-
using System.Globalization;
97
using System.Linq;
108

119
// We use this TF_Xxx as the native "TF_Xxx *" as those are opaque
@@ -24,9 +22,7 @@
2422
using TF_DeviceList = System.IntPtr;
2523

2624
using size_t = System.UIntPtr;
27-
using System.Numerics;
2825
using System.Collections.Generic;
29-
using System.Linq.Expressions;
3026

3127
#pragma warning disable MSML_GeneralName
3228
#pragma warning disable MSML_PrivateFieldName

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
7+
using System.Linq;
68
using System.Runtime.InteropServices;
9+
using Microsoft.ML.Runtime;
710
using Microsoft.ML.Runtime.Data;
811
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
912

@@ -30,6 +33,10 @@ internal static PrimitiveType Tf2MlNetType(TFDataType type)
3033
return NumberType.R4;
3134
case TFDataType.Double:
3235
return NumberType.R8;
36+
case TFDataType.UInt16:
37+
return NumberType.U2;
38+
case TFDataType.UInt8:
39+
return NumberType.U1;
3340
case TFDataType.UInt32:
3441
return NumberType.U4;
3542
case TFDataType.UInt64:
@@ -57,6 +64,10 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
5764
{
5865
case TFDataType.Float:
5966
case TFDataType.Double:
67+
case TFDataType.UInt8:
68+
case TFDataType.UInt16:
69+
case TFDataType.UInt32:
70+
case TFDataType.UInt64:
6071
return true;
6172
default:
6273
return false;

0 commit comments

Comments
 (0)