diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs index 56bf8686eb..993cbda20f 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs @@ -632,22 +632,26 @@ public static Type TypeFromTensorType(TFDataType type) return typeof(float); case TFDataType.Double: return typeof(double); + case TFDataType.Int8: + return typeof(sbyte); + case TFDataType.Int16: + return typeof(short); case TFDataType.Int32: return typeof(int); + case TFDataType.Int64: + return typeof(long); case TFDataType.UInt8: return typeof(byte); - case TFDataType.Int16: - return typeof(short); - case TFDataType.Int8: - return typeof(sbyte); + case TFDataType.UInt16: + return typeof(ushort); + case TFDataType.UInt32: + return typeof(uint); + case TFDataType.UInt64: + return typeof(ulong); case TFDataType.String: throw new NotSupportedException(); - case TFDataType.Int64: - return typeof(long); case TFDataType.Bool: return typeof(bool); - case TFDataType.UInt16: - return typeof(ushort); case TFDataType.Complex128: return typeof(Complex); default: @@ -666,22 +670,26 @@ public static TFDataType TensorTypeFromType(Type type) return TFDataType.Float; if (type == typeof(double)) return TFDataType.Double; + if (type == typeof(sbyte)) + return TFDataType.Int8; + if (type == typeof(short)) + return TFDataType.Int16; if (type == typeof(int)) return TFDataType.Int32; + if (type == typeof(long)) + return TFDataType.Int64; if (type == typeof(byte)) return TFDataType.UInt8; - if (type == typeof(short)) - return TFDataType.Int16; - if (type == typeof(sbyte)) - return TFDataType.Int8; + if (type == typeof(ushort)) + return TFDataType.UInt16; + if (type == typeof(uint)) + return TFDataType.UInt32; + if (type == typeof(ulong)) + return TFDataType.UInt64; if (type == typeof(string)) return TFDataType.String; - if (type == typeof(long)) - return TFDataType.Int64; if (type == typeof(bool)) return TFDataType.Bool; - if (type == typeof(ushort)) - return TFDataType.UInt16; if (type == typeof(Complex)) return TFDataType.Complex128; @@ -696,22 +704,26 @@ private static unsafe object FetchSimple(TFDataType dt, IntPtr data) return *(float*)data; case TFDataType.Double: return *(double*)data; + case TFDataType.Int8: + return *(sbyte*)data; + case TFDataType.Int16: + return *(short*)data; case TFDataType.Int32: return *(int*)data; + case TFDataType.Int64: + return *(long*)data; case TFDataType.UInt8: return *(byte*)data; - case TFDataType.Int16: - return *(short*)data; - case TFDataType.Int8: - return *(sbyte*)data; + case TFDataType.UInt16: + return *(ushort*)data; + case TFDataType.UInt32: + return *(uint*)data; + case TFDataType.UInt64: + return *(ulong*)data; case TFDataType.String: throw new NotImplementedException(); - case TFDataType.Int64: - return *(long*)data; case TFDataType.Bool: return *(bool*)data; - case TFDataType.UInt16: - return *(ushort*)data; case TFDataType.Complex128: return *(Complex*)data; default: diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index 48d6417ce5..b1adc85a17 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -151,20 +151,24 @@ private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type) return NumberType.R4; case TFDataType.Double: return NumberType.R8; - case TFDataType.UInt16: - return NumberType.U2; case TFDataType.UInt8: return NumberType.U1; + case TFDataType.UInt16: + return NumberType.U2; case TFDataType.UInt32: return NumberType.U4; case TFDataType.UInt64: return NumberType.U8; + case TFDataType.Int8: + return NumberType.I1; case TFDataType.Int16: return NumberType.I2; case TFDataType.Int32: return NumberType.I4; case TFDataType.Int64: return NumberType.I8; + case TFDataType.Bool: + return BoolType.Instance; default: return null; } @@ -363,9 +367,11 @@ internal static bool IsTypeSupported(TFDataType tfoutput) case TFDataType.UInt16: case TFDataType.UInt32: case TFDataType.UInt64: + case TFDataType.Int8: case TFDataType.Int16: case TFDataType.Int32: case TFDataType.Int64: + case TFDataType.Bool: return true; default: return false; diff --git a/src/Microsoft.ML.TensorFlow/doc.xml b/src/Microsoft.ML.TensorFlow/doc.xml index 58ff524240..49ab43ea5d 100644 --- a/src/Microsoft.ML.TensorFlow/doc.xml +++ b/src/Microsoft.ML.TensorFlow/doc.xml @@ -48,7 +48,7 @@ The name of each output column should match one of the operations in the TensorFlow graph. - Currently, float, double, int, long, uint, ulong are the acceptable data types for input/output. + Currently, double, float, long, int, short, sbyte, ulong, uint, ushort, byte and bool are the acceptable data types for input/output. Upon success, the transform will introduce a new column in corresponding to each output column specified. diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index c56cfd67f1..e655c73f33 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -46,7 +46,7 @@ - + diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index c4ad74b475..441fcc4fc1 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.IO; using Microsoft.ML.Data; using Microsoft.ML.ImageAnalytics; @@ -69,6 +70,159 @@ public void TensorFlowTransformMatrixMultiplicationTest() } } + private class TypesData + { + [VectorType(2)] + public double[] f64; + [VectorType(2)] + public float[] f32; + [VectorType(2)] + public long[] i64; + [VectorType(2)] + public int[] i32; + [VectorType(2)] + public short[] i16; + [VectorType(2)] + public sbyte[] i8; + [VectorType(2)] + public ulong[] u64; + [VectorType(2)] + public uint[] u32; + [VectorType(2)] + public ushort[] u16; + [VectorType(2)] + public byte[] u8; + [VectorType(2)] + public bool[] b; + } + + /// + /// Test to ensure the supported datatypes can passed to TensorFlow . + /// + [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only + public void TensorFlowTransformInputOutputTypesTest() + { + // This an identity model which returns the same output as input. + var model_location = "model_types_test"; + + //Data + var data = new List( + new TypesData[] { + new TypesData() { f64 = new[] { -1.0, 2.0 }, + f32 = new[] { -1.0f, 2.0f }, + i64 = new[] { -1L, 2 }, + i32 = new[] { -1, 2 }, + i16 = new short[] { -1, 2 }, + i8 = new sbyte[] { -1, 2 }, + u64 = new ulong[] { 1, 2 }, + u32 = new uint[] { 1, 2 }, + u16 = new ushort[] { 1, 2 }, + u8 = new byte[] { 1, 2 }, + b = new bool[] { true, true }, + }, + new TypesData() { f64 = new[] { -3.0, 4.0 }, + f32 = new[] { -3.0f, 4.0f }, + i64 = new[] { -3L, 4 }, + i32 = new[] { -3, 4 }, + i16 = new short[] { -3, 4 }, + i8 = new sbyte[] { -3, 4 }, + u64 = new ulong[] { 3, 4 }, + u32 = new uint[] { 3, 4 }, + u16 = new ushort[] { 3, 4 }, + u8 = new byte[] { 3, 4 }, + b = new bool[] { false, false }, + } }); + + var mlContext = new MLContext(seed: 1, conc: 1); + // Pipeline + + var loader = ComponentCreation.CreateDataView(mlContext,data); + + var inputs = new string[]{"f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8","b"}; + var outputs = new string[] { "o_f64", "o_f32", "o_i64", "o_i32", "o_i16", "o_i8", "o_u64", "o_u32", "o_u16", "o_u8", "o_b" }; + var trans = new TensorFlowTransformer(mlContext, model_location, inputs, outputs).Transform(loader); ; + + using (var cursor = trans.GetRowCursor(a => true)) + { + var f64getter = cursor.GetGetter>(11); + var f32getter = cursor.GetGetter>(12); + var i64getter = cursor.GetGetter>(13); + var i32getter = cursor.GetGetter>(14); + var i16getter = cursor.GetGetter>(15); + var i8getter = cursor.GetGetter>(16); + var u64getter = cursor.GetGetter>(17); + var u32getter = cursor.GetGetter>(18); + var u16getter = cursor.GetGetter>(19); + var u8getter = cursor.GetGetter>(20); + var boolgetter = cursor.GetGetter>(21); + + + VBuffer f64 = default; + VBuffer f32 = default; + VBuffer i64 = default; + VBuffer i32 = default; + VBuffer i16 = default; + VBuffer i8 = default; + VBuffer u64 = default; + VBuffer u32 = default; + VBuffer u16 = default; + VBuffer u8 = default; + VBuffer b = default; + foreach (var sample in data) + { + Assert.True(cursor.MoveNext()); + + f64getter(ref f64); + f32getter(ref f32); + i64getter(ref i64); + i32getter(ref i32); + i16getter(ref i16); + i8getter(ref i8); + u64getter(ref u64); + u32getter(ref u32); + u16getter(ref u16); + u8getter(ref u8); + u8getter(ref u8); + boolgetter(ref b); + + var f64Values = f64.GetValues(); + Assert.Equal(2, f64Values.Length); + Assert.True(f64Values.SequenceEqual(sample.f64)); + var f32Values = f32.GetValues(); + Assert.Equal(2, f32Values.Length); + Assert.True(f32Values.SequenceEqual(sample.f32)); + var i64Values = i64.GetValues(); + Assert.Equal(2, i64Values.Length); + Assert.True(i64Values.SequenceEqual(sample.i64)); + var i32Values = i32.GetValues(); + Assert.Equal(2, i32Values.Length); + Assert.True(i32Values.SequenceEqual(sample.i32)); + var i16Values = i16.GetValues(); + Assert.Equal(2, i16Values.Length); + Assert.True(i16Values.SequenceEqual(sample.i16)); + var i8Values = i8.GetValues(); + Assert.Equal(2, i8Values.Length); + Assert.True(i8Values.SequenceEqual(sample.i8)); + var u64Values = u64.GetValues(); + Assert.Equal(2, u64Values.Length); + Assert.True(u64Values.SequenceEqual(sample.u64)); + var u32Values = u32.GetValues(); + Assert.Equal(2, u32Values.Length); + Assert.True(u32Values.SequenceEqual(sample.u32)); + var u16Values = u16.GetValues(); + Assert.Equal(2, u16Values.Length); + Assert.True(u16Values.SequenceEqual(sample.u16)); + var u8Values = u8.GetValues(); + Assert.Equal(2, u8Values.Length); + Assert.True(u8Values.SequenceEqual(sample.u8)); + var bValues = b.GetValues(); + Assert.Equal(2, bValues.Length); + Assert.True(bValues.SequenceEqual(sample.b)); + } + Assert.False(cursor.MoveNext()); + } + } + [Fact(Skip = "Model files are not available yet")] public void TensorFlowTransformObjectDetectionTest() {