Skip to content

Commit 68841ad

Browse files
authored
Adding a test to show supported data types in TensorFlowTransform (#2101)
1 parent 3721d61 commit 68841ad

File tree

5 files changed

+200
-28
lines changed

5 files changed

+200
-28
lines changed

src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -632,22 +632,26 @@ public static Type TypeFromTensorType(TFDataType type)
632632
return typeof(float);
633633
case TFDataType.Double:
634634
return typeof(double);
635+
case TFDataType.Int8:
636+
return typeof(sbyte);
637+
case TFDataType.Int16:
638+
return typeof(short);
635639
case TFDataType.Int32:
636640
return typeof(int);
641+
case TFDataType.Int64:
642+
return typeof(long);
637643
case TFDataType.UInt8:
638644
return typeof(byte);
639-
case TFDataType.Int16:
640-
return typeof(short);
641-
case TFDataType.Int8:
642-
return typeof(sbyte);
645+
case TFDataType.UInt16:
646+
return typeof(ushort);
647+
case TFDataType.UInt32:
648+
return typeof(uint);
649+
case TFDataType.UInt64:
650+
return typeof(ulong);
643651
case TFDataType.String:
644652
throw new NotSupportedException();
645-
case TFDataType.Int64:
646-
return typeof(long);
647653
case TFDataType.Bool:
648654
return typeof(bool);
649-
case TFDataType.UInt16:
650-
return typeof(ushort);
651655
case TFDataType.Complex128:
652656
return typeof(Complex);
653657
default:
@@ -666,22 +670,26 @@ public static TFDataType TensorTypeFromType(Type type)
666670
return TFDataType.Float;
667671
if (type == typeof(double))
668672
return TFDataType.Double;
673+
if (type == typeof(sbyte))
674+
return TFDataType.Int8;
675+
if (type == typeof(short))
676+
return TFDataType.Int16;
669677
if (type == typeof(int))
670678
return TFDataType.Int32;
679+
if (type == typeof(long))
680+
return TFDataType.Int64;
671681
if (type == typeof(byte))
672682
return TFDataType.UInt8;
673-
if (type == typeof(short))
674-
return TFDataType.Int16;
675-
if (type == typeof(sbyte))
676-
return TFDataType.Int8;
683+
if (type == typeof(ushort))
684+
return TFDataType.UInt16;
685+
if (type == typeof(uint))
686+
return TFDataType.UInt32;
687+
if (type == typeof(ulong))
688+
return TFDataType.UInt64;
677689
if (type == typeof(string))
678690
return TFDataType.String;
679-
if (type == typeof(long))
680-
return TFDataType.Int64;
681691
if (type == typeof(bool))
682692
return TFDataType.Bool;
683-
if (type == typeof(ushort))
684-
return TFDataType.UInt16;
685693
if (type == typeof(Complex))
686694
return TFDataType.Complex128;
687695

@@ -696,22 +704,26 @@ private static unsafe object FetchSimple(TFDataType dt, IntPtr data)
696704
return *(float*)data;
697705
case TFDataType.Double:
698706
return *(double*)data;
707+
case TFDataType.Int8:
708+
return *(sbyte*)data;
709+
case TFDataType.Int16:
710+
return *(short*)data;
699711
case TFDataType.Int32:
700712
return *(int*)data;
713+
case TFDataType.Int64:
714+
return *(long*)data;
701715
case TFDataType.UInt8:
702716
return *(byte*)data;
703-
case TFDataType.Int16:
704-
return *(short*)data;
705-
case TFDataType.Int8:
706-
return *(sbyte*)data;
717+
case TFDataType.UInt16:
718+
return *(ushort*)data;
719+
case TFDataType.UInt32:
720+
return *(uint*)data;
721+
case TFDataType.UInt64:
722+
return *(ulong*)data;
707723
case TFDataType.String:
708724
throw new NotImplementedException();
709-
case TFDataType.Int64:
710-
return *(long*)data;
711725
case TFDataType.Bool:
712726
return *(bool*)data;
713-
case TFDataType.UInt16:
714-
return *(ushort*)data;
715727
case TFDataType.Complex128:
716728
return *(Complex*)data;
717729
default:

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,24 @@ private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type)
151151
return NumberType.R4;
152152
case TFDataType.Double:
153153
return NumberType.R8;
154-
case TFDataType.UInt16:
155-
return NumberType.U2;
156154
case TFDataType.UInt8:
157155
return NumberType.U1;
156+
case TFDataType.UInt16:
157+
return NumberType.U2;
158158
case TFDataType.UInt32:
159159
return NumberType.U4;
160160
case TFDataType.UInt64:
161161
return NumberType.U8;
162+
case TFDataType.Int8:
163+
return NumberType.I1;
162164
case TFDataType.Int16:
163165
return NumberType.I2;
164166
case TFDataType.Int32:
165167
return NumberType.I4;
166168
case TFDataType.Int64:
167169
return NumberType.I8;
170+
case TFDataType.Bool:
171+
return BoolType.Instance;
168172
default:
169173
return null;
170174
}
@@ -363,9 +367,11 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
363367
case TFDataType.UInt16:
364368
case TFDataType.UInt32:
365369
case TFDataType.UInt64:
370+
case TFDataType.Int8:
366371
case TFDataType.Int16:
367372
case TFDataType.Int32:
368373
case TFDataType.Int64:
374+
case TFDataType.Bool:
369375
return true;
370376
default:
371377
return false;

src/Microsoft.ML.TensorFlow/doc.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
<description>The name of each output column should match one of the operations in the TensorFlow graph.</description>
4949
</item>
5050
<item>
51-
<description>Currently, float, double, int, long, uint, ulong are the acceptable data types for input/output.</description>
51+
<description>Currently, double, float, long, int, short, sbyte, ulong, uint, ushort, byte and bool are the acceptable data types for input/output.</description>
5252
</item>
5353
<item>
5454
<description>Upon success, the transform will introduce a new column in <see cref="IDataView"/> corresponding to each output column specified.</description>

test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
<NativeAssemblyReference Condition="'$(OS)' != 'Windows_NT'" Include="tensorflow_framework" />
4747
</ItemGroup>
4848
<ItemGroup>
49-
<PackageReference Include="Microsoft.ML.TensorFlow.TestModels" Version="0.0.4-test" />
49+
<PackageReference Include="Microsoft.ML.TensorFlow.TestModels" Version="0.0.6-test" />
5050
<PackageReference Include="Microsoft.ML.Onnx.TestModels" Version="0.0.2-test" />
5151
</ItemGroup>
5252
</Project>

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Linq;
78
using System.IO;
89
using Microsoft.ML.Data;
910
using Microsoft.ML.ImageAnalytics;
@@ -69,6 +70,159 @@ public void TensorFlowTransformMatrixMultiplicationTest()
6970
}
7071
}
7172

73+
private class TypesData
74+
{
75+
[VectorType(2)]
76+
public double[] f64;
77+
[VectorType(2)]
78+
public float[] f32;
79+
[VectorType(2)]
80+
public long[] i64;
81+
[VectorType(2)]
82+
public int[] i32;
83+
[VectorType(2)]
84+
public short[] i16;
85+
[VectorType(2)]
86+
public sbyte[] i8;
87+
[VectorType(2)]
88+
public ulong[] u64;
89+
[VectorType(2)]
90+
public uint[] u32;
91+
[VectorType(2)]
92+
public ushort[] u16;
93+
[VectorType(2)]
94+
public byte[] u8;
95+
[VectorType(2)]
96+
public bool[] b;
97+
}
98+
99+
/// <summary>
100+
/// Test to ensure the supported datatypes can passed to TensorFlow .
101+
/// </summary>
102+
[ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
103+
public void TensorFlowTransformInputOutputTypesTest()
104+
{
105+
// This an identity model which returns the same output as input.
106+
var model_location = "model_types_test";
107+
108+
//Data
109+
var data = new List<TypesData>(
110+
new TypesData[] {
111+
new TypesData() { f64 = new[] { -1.0, 2.0 },
112+
f32 = new[] { -1.0f, 2.0f },
113+
i64 = new[] { -1L, 2 },
114+
i32 = new[] { -1, 2 },
115+
i16 = new short[] { -1, 2 },
116+
i8 = new sbyte[] { -1, 2 },
117+
u64 = new ulong[] { 1, 2 },
118+
u32 = new uint[] { 1, 2 },
119+
u16 = new ushort[] { 1, 2 },
120+
u8 = new byte[] { 1, 2 },
121+
b = new bool[] { true, true },
122+
},
123+
new TypesData() { f64 = new[] { -3.0, 4.0 },
124+
f32 = new[] { -3.0f, 4.0f },
125+
i64 = new[] { -3L, 4 },
126+
i32 = new[] { -3, 4 },
127+
i16 = new short[] { -3, 4 },
128+
i8 = new sbyte[] { -3, 4 },
129+
u64 = new ulong[] { 3, 4 },
130+
u32 = new uint[] { 3, 4 },
131+
u16 = new ushort[] { 3, 4 },
132+
u8 = new byte[] { 3, 4 },
133+
b = new bool[] { false, false },
134+
} });
135+
136+
var mlContext = new MLContext(seed: 1, conc: 1);
137+
// Pipeline
138+
139+
var loader = ComponentCreation.CreateDataView(mlContext,data);
140+
141+
var inputs = new string[]{"f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8","b"};
142+
var outputs = new string[] { "o_f64", "o_f32", "o_i64", "o_i32", "o_i16", "o_i8", "o_u64", "o_u32", "o_u16", "o_u8", "o_b" };
143+
var trans = new TensorFlowTransformer(mlContext, model_location, inputs, outputs).Transform(loader); ;
144+
145+
using (var cursor = trans.GetRowCursor(a => true))
146+
{
147+
var f64getter = cursor.GetGetter<VBuffer<double>>(11);
148+
var f32getter = cursor.GetGetter<VBuffer<float>>(12);
149+
var i64getter = cursor.GetGetter<VBuffer<long>>(13);
150+
var i32getter = cursor.GetGetter<VBuffer<int>>(14);
151+
var i16getter = cursor.GetGetter<VBuffer<short>>(15);
152+
var i8getter = cursor.GetGetter<VBuffer<sbyte>>(16);
153+
var u64getter = cursor.GetGetter<VBuffer<ulong>>(17);
154+
var u32getter = cursor.GetGetter<VBuffer<uint>>(18);
155+
var u16getter = cursor.GetGetter<VBuffer<ushort>>(19);
156+
var u8getter = cursor.GetGetter<VBuffer<byte>>(20);
157+
var boolgetter = cursor.GetGetter<VBuffer<bool>>(21);
158+
159+
160+
VBuffer<double> f64 = default;
161+
VBuffer<float> f32 = default;
162+
VBuffer<long> i64 = default;
163+
VBuffer<int> i32 = default;
164+
VBuffer<short> i16 = default;
165+
VBuffer<sbyte> i8 = default;
166+
VBuffer<ulong> u64 = default;
167+
VBuffer<uint> u32 = default;
168+
VBuffer<ushort> u16 = default;
169+
VBuffer<byte> u8 = default;
170+
VBuffer<bool> b = default;
171+
foreach (var sample in data)
172+
{
173+
Assert.True(cursor.MoveNext());
174+
175+
f64getter(ref f64);
176+
f32getter(ref f32);
177+
i64getter(ref i64);
178+
i32getter(ref i32);
179+
i16getter(ref i16);
180+
i8getter(ref i8);
181+
u64getter(ref u64);
182+
u32getter(ref u32);
183+
u16getter(ref u16);
184+
u8getter(ref u8);
185+
u8getter(ref u8);
186+
boolgetter(ref b);
187+
188+
var f64Values = f64.GetValues();
189+
Assert.Equal(2, f64Values.Length);
190+
Assert.True(f64Values.SequenceEqual(sample.f64));
191+
var f32Values = f32.GetValues();
192+
Assert.Equal(2, f32Values.Length);
193+
Assert.True(f32Values.SequenceEqual(sample.f32));
194+
var i64Values = i64.GetValues();
195+
Assert.Equal(2, i64Values.Length);
196+
Assert.True(i64Values.SequenceEqual(sample.i64));
197+
var i32Values = i32.GetValues();
198+
Assert.Equal(2, i32Values.Length);
199+
Assert.True(i32Values.SequenceEqual(sample.i32));
200+
var i16Values = i16.GetValues();
201+
Assert.Equal(2, i16Values.Length);
202+
Assert.True(i16Values.SequenceEqual(sample.i16));
203+
var i8Values = i8.GetValues();
204+
Assert.Equal(2, i8Values.Length);
205+
Assert.True(i8Values.SequenceEqual(sample.i8));
206+
var u64Values = u64.GetValues();
207+
Assert.Equal(2, u64Values.Length);
208+
Assert.True(u64Values.SequenceEqual(sample.u64));
209+
var u32Values = u32.GetValues();
210+
Assert.Equal(2, u32Values.Length);
211+
Assert.True(u32Values.SequenceEqual(sample.u32));
212+
var u16Values = u16.GetValues();
213+
Assert.Equal(2, u16Values.Length);
214+
Assert.True(u16Values.SequenceEqual(sample.u16));
215+
var u8Values = u8.GetValues();
216+
Assert.Equal(2, u8Values.Length);
217+
Assert.True(u8Values.SequenceEqual(sample.u8));
218+
var bValues = b.GetValues();
219+
Assert.Equal(2, bValues.Length);
220+
Assert.True(bValues.SequenceEqual(sample.b));
221+
}
222+
Assert.False(cursor.MoveNext());
223+
}
224+
}
225+
72226
[Fact(Skip = "Model files are not available yet")]
73227
public void TensorFlowTransformObjectDetectionTest()
74228
{

0 commit comments

Comments
 (0)