diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs
index 56bf8686eb..993cbda20f 100644
--- a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs
@@ -632,22 +632,26 @@ public static Type TypeFromTensorType(TFDataType type)
return typeof(float);
case TFDataType.Double:
return typeof(double);
+ case TFDataType.Int8:
+ return typeof(sbyte);
+ case TFDataType.Int16:
+ return typeof(short);
case TFDataType.Int32:
return typeof(int);
+ case TFDataType.Int64:
+ return typeof(long);
case TFDataType.UInt8:
return typeof(byte);
- case TFDataType.Int16:
- return typeof(short);
- case TFDataType.Int8:
- return typeof(sbyte);
+ case TFDataType.UInt16:
+ return typeof(ushort);
+ case TFDataType.UInt32:
+ return typeof(uint);
+ case TFDataType.UInt64:
+ return typeof(ulong);
case TFDataType.String:
throw new NotSupportedException();
- case TFDataType.Int64:
- return typeof(long);
case TFDataType.Bool:
return typeof(bool);
- case TFDataType.UInt16:
- return typeof(ushort);
case TFDataType.Complex128:
return typeof(Complex);
default:
@@ -666,22 +670,26 @@ public static TFDataType TensorTypeFromType(Type type)
return TFDataType.Float;
if (type == typeof(double))
return TFDataType.Double;
+ if (type == typeof(sbyte))
+ return TFDataType.Int8;
+ if (type == typeof(short))
+ return TFDataType.Int16;
if (type == typeof(int))
return TFDataType.Int32;
+ if (type == typeof(long))
+ return TFDataType.Int64;
if (type == typeof(byte))
return TFDataType.UInt8;
- if (type == typeof(short))
- return TFDataType.Int16;
- if (type == typeof(sbyte))
- return TFDataType.Int8;
+ if (type == typeof(ushort))
+ return TFDataType.UInt16;
+ if (type == typeof(uint))
+ return TFDataType.UInt32;
+ if (type == typeof(ulong))
+ return TFDataType.UInt64;
if (type == typeof(string))
return TFDataType.String;
- if (type == typeof(long))
- return TFDataType.Int64;
if (type == typeof(bool))
return TFDataType.Bool;
- if (type == typeof(ushort))
- return TFDataType.UInt16;
if (type == typeof(Complex))
return TFDataType.Complex128;
@@ -696,22 +704,26 @@ private static unsafe object FetchSimple(TFDataType dt, IntPtr data)
return *(float*)data;
case TFDataType.Double:
return *(double*)data;
+ case TFDataType.Int8:
+ return *(sbyte*)data;
+ case TFDataType.Int16:
+ return *(short*)data;
case TFDataType.Int32:
return *(int*)data;
+ case TFDataType.Int64:
+ return *(long*)data;
case TFDataType.UInt8:
return *(byte*)data;
- case TFDataType.Int16:
- return *(short*)data;
- case TFDataType.Int8:
- return *(sbyte*)data;
+ case TFDataType.UInt16:
+ return *(ushort*)data;
+ case TFDataType.UInt32:
+ return *(uint*)data;
+ case TFDataType.UInt64:
+ return *(ulong*)data;
case TFDataType.String:
throw new NotImplementedException();
- case TFDataType.Int64:
- return *(long*)data;
case TFDataType.Bool:
return *(bool*)data;
- case TFDataType.UInt16:
- return *(ushort*)data;
case TFDataType.Complex128:
return *(Complex*)data;
default:
diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
index 48d6417ce5..b1adc85a17 100644
--- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
@@ -151,20 +151,24 @@ private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type)
return NumberType.R4;
case TFDataType.Double:
return NumberType.R8;
- case TFDataType.UInt16:
- return NumberType.U2;
case TFDataType.UInt8:
return NumberType.U1;
+ case TFDataType.UInt16:
+ return NumberType.U2;
case TFDataType.UInt32:
return NumberType.U4;
case TFDataType.UInt64:
return NumberType.U8;
+ case TFDataType.Int8:
+ return NumberType.I1;
case TFDataType.Int16:
return NumberType.I2;
case TFDataType.Int32:
return NumberType.I4;
case TFDataType.Int64:
return NumberType.I8;
+ case TFDataType.Bool:
+ return BoolType.Instance;
default:
return null;
}
@@ -363,9 +367,11 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
case TFDataType.UInt16:
case TFDataType.UInt32:
case TFDataType.UInt64:
+ case TFDataType.Int8:
case TFDataType.Int16:
case TFDataType.Int32:
case TFDataType.Int64:
+ case TFDataType.Bool:
return true;
default:
return false;
diff --git a/src/Microsoft.ML.TensorFlow/doc.xml b/src/Microsoft.ML.TensorFlow/doc.xml
index 58ff524240..49ab43ea5d 100644
--- a/src/Microsoft.ML.TensorFlow/doc.xml
+++ b/src/Microsoft.ML.TensorFlow/doc.xml
@@ -48,7 +48,7 @@
The name of each output column should match one of the operations in the TensorFlow graph.
-
- Currently, float, double, int, long, uint, ulong are the acceptable data types for input/output.
+ Currently, double, float, long, int, short, sbyte, ulong, uint, ushort, byte and bool are the acceptable data types for input/output.
-
Upon success, the transform will introduce a new column in corresponding to each output column specified.
diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
index c56cfd67f1..e655c73f33 100644
--- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
+++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
@@ -46,7 +46,7 @@
-
+
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index c4ad74b475..441fcc4fc1 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.ImageAnalytics;
@@ -69,6 +70,159 @@ public void TensorFlowTransformMatrixMultiplicationTest()
}
}
+ private class TypesData
+ {
+ [VectorType(2)]
+ public double[] f64;
+ [VectorType(2)]
+ public float[] f32;
+ [VectorType(2)]
+ public long[] i64;
+ [VectorType(2)]
+ public int[] i32;
+ [VectorType(2)]
+ public short[] i16;
+ [VectorType(2)]
+ public sbyte[] i8;
+ [VectorType(2)]
+ public ulong[] u64;
+ [VectorType(2)]
+ public uint[] u32;
+ [VectorType(2)]
+ public ushort[] u16;
+ [VectorType(2)]
+ public byte[] u8;
+ [VectorType(2)]
+ public bool[] b;
+ }
+
+ ///
+ /// Test to ensure the supported datatypes can passed to TensorFlow .
+ ///
+ [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only
+ public void TensorFlowTransformInputOutputTypesTest()
+ {
+ // This an identity model which returns the same output as input.
+ var model_location = "model_types_test";
+
+ //Data
+ var data = new List(
+ new TypesData[] {
+ new TypesData() { f64 = new[] { -1.0, 2.0 },
+ f32 = new[] { -1.0f, 2.0f },
+ i64 = new[] { -1L, 2 },
+ i32 = new[] { -1, 2 },
+ i16 = new short[] { -1, 2 },
+ i8 = new sbyte[] { -1, 2 },
+ u64 = new ulong[] { 1, 2 },
+ u32 = new uint[] { 1, 2 },
+ u16 = new ushort[] { 1, 2 },
+ u8 = new byte[] { 1, 2 },
+ b = new bool[] { true, true },
+ },
+ new TypesData() { f64 = new[] { -3.0, 4.0 },
+ f32 = new[] { -3.0f, 4.0f },
+ i64 = new[] { -3L, 4 },
+ i32 = new[] { -3, 4 },
+ i16 = new short[] { -3, 4 },
+ i8 = new sbyte[] { -3, 4 },
+ u64 = new ulong[] { 3, 4 },
+ u32 = new uint[] { 3, 4 },
+ u16 = new ushort[] { 3, 4 },
+ u8 = new byte[] { 3, 4 },
+ b = new bool[] { false, false },
+ } });
+
+ var mlContext = new MLContext(seed: 1, conc: 1);
+ // Pipeline
+
+ var loader = ComponentCreation.CreateDataView(mlContext,data);
+
+ var inputs = new string[]{"f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8","b"};
+ var outputs = new string[] { "o_f64", "o_f32", "o_i64", "o_i32", "o_i16", "o_i8", "o_u64", "o_u32", "o_u16", "o_u8", "o_b" };
+ var trans = new TensorFlowTransformer(mlContext, model_location, inputs, outputs).Transform(loader); ;
+
+ using (var cursor = trans.GetRowCursor(a => true))
+ {
+ var f64getter = cursor.GetGetter>(11);
+ var f32getter = cursor.GetGetter>(12);
+ var i64getter = cursor.GetGetter>(13);
+ var i32getter = cursor.GetGetter>(14);
+ var i16getter = cursor.GetGetter>(15);
+ var i8getter = cursor.GetGetter>(16);
+ var u64getter = cursor.GetGetter>(17);
+ var u32getter = cursor.GetGetter>(18);
+ var u16getter = cursor.GetGetter>(19);
+ var u8getter = cursor.GetGetter>(20);
+ var boolgetter = cursor.GetGetter>(21);
+
+
+ VBuffer f64 = default;
+ VBuffer f32 = default;
+ VBuffer i64 = default;
+ VBuffer i32 = default;
+ VBuffer i16 = default;
+ VBuffer i8 = default;
+ VBuffer u64 = default;
+ VBuffer u32 = default;
+ VBuffer u16 = default;
+ VBuffer u8 = default;
+ VBuffer b = default;
+ foreach (var sample in data)
+ {
+ Assert.True(cursor.MoveNext());
+
+ f64getter(ref f64);
+ f32getter(ref f32);
+ i64getter(ref i64);
+ i32getter(ref i32);
+ i16getter(ref i16);
+ i8getter(ref i8);
+ u64getter(ref u64);
+ u32getter(ref u32);
+ u16getter(ref u16);
+ u8getter(ref u8);
+ u8getter(ref u8);
+ boolgetter(ref b);
+
+ var f64Values = f64.GetValues();
+ Assert.Equal(2, f64Values.Length);
+ Assert.True(f64Values.SequenceEqual(sample.f64));
+ var f32Values = f32.GetValues();
+ Assert.Equal(2, f32Values.Length);
+ Assert.True(f32Values.SequenceEqual(sample.f32));
+ var i64Values = i64.GetValues();
+ Assert.Equal(2, i64Values.Length);
+ Assert.True(i64Values.SequenceEqual(sample.i64));
+ var i32Values = i32.GetValues();
+ Assert.Equal(2, i32Values.Length);
+ Assert.True(i32Values.SequenceEqual(sample.i32));
+ var i16Values = i16.GetValues();
+ Assert.Equal(2, i16Values.Length);
+ Assert.True(i16Values.SequenceEqual(sample.i16));
+ var i8Values = i8.GetValues();
+ Assert.Equal(2, i8Values.Length);
+ Assert.True(i8Values.SequenceEqual(sample.i8));
+ var u64Values = u64.GetValues();
+ Assert.Equal(2, u64Values.Length);
+ Assert.True(u64Values.SequenceEqual(sample.u64));
+ var u32Values = u32.GetValues();
+ Assert.Equal(2, u32Values.Length);
+ Assert.True(u32Values.SequenceEqual(sample.u32));
+ var u16Values = u16.GetValues();
+ Assert.Equal(2, u16Values.Length);
+ Assert.True(u16Values.SequenceEqual(sample.u16));
+ var u8Values = u8.GetValues();
+ Assert.Equal(2, u8Values.Length);
+ Assert.True(u8Values.SequenceEqual(sample.u8));
+ var bValues = b.GetValues();
+ Assert.Equal(2, bValues.Length);
+ Assert.True(bValues.SequenceEqual(sample.b));
+ }
+ Assert.False(cursor.MoveNext());
+ }
+ }
+
[Fact(Skip = "Model files are not available yet")]
public void TensorFlowTransformObjectDetectionTest()
{