|
4 | 4 |
|
5 | 5 | using System;
|
6 | 6 | using System.Collections.Generic;
|
| 7 | +using System.Linq; |
7 | 8 | using System.IO;
|
8 | 9 | using Microsoft.ML.Data;
|
9 | 10 | using Microsoft.ML.ImageAnalytics;
|
@@ -69,6 +70,159 @@ public void TensorFlowTransformMatrixMultiplicationTest()
|
69 | 70 | }
|
70 | 71 | }
|
71 | 72 |
|
| 73 | + private class TypesData |
| 74 | + { |
| 75 | + [VectorType(2)] |
| 76 | + public double[] f64; |
| 77 | + [VectorType(2)] |
| 78 | + public float[] f32; |
| 79 | + [VectorType(2)] |
| 80 | + public long[] i64; |
| 81 | + [VectorType(2)] |
| 82 | + public int[] i32; |
| 83 | + [VectorType(2)] |
| 84 | + public short[] i16; |
| 85 | + [VectorType(2)] |
| 86 | + public sbyte[] i8; |
| 87 | + [VectorType(2)] |
| 88 | + public ulong[] u64; |
| 89 | + [VectorType(2)] |
| 90 | + public uint[] u32; |
| 91 | + [VectorType(2)] |
| 92 | + public ushort[] u16; |
| 93 | + [VectorType(2)] |
| 94 | + public byte[] u8; |
| 95 | + [VectorType(2)] |
| 96 | + public bool[] b; |
| 97 | + } |
| 98 | + |
| 99 | + /// <summary> |
| 100 | + /// Test to ensure the supported datatypes can passed to TensorFlow . |
| 101 | + /// </summary> |
| 102 | + [ConditionalFact(typeof(Environment), nameof(Environment.Is64BitProcess))] // TensorFlow is 64-bit only |
| 103 | + public void TensorFlowTransformInputOutputTypesTest() |
| 104 | + { |
| 105 | + // This an identity model which returns the same output as input. |
| 106 | + var model_location = "model_types_test"; |
| 107 | + |
| 108 | + //Data |
| 109 | + var data = new List<TypesData>( |
| 110 | + new TypesData[] { |
| 111 | + new TypesData() { f64 = new[] { -1.0, 2.0 }, |
| 112 | + f32 = new[] { -1.0f, 2.0f }, |
| 113 | + i64 = new[] { -1L, 2 }, |
| 114 | + i32 = new[] { -1, 2 }, |
| 115 | + i16 = new short[] { -1, 2 }, |
| 116 | + i8 = new sbyte[] { -1, 2 }, |
| 117 | + u64 = new ulong[] { 1, 2 }, |
| 118 | + u32 = new uint[] { 1, 2 }, |
| 119 | + u16 = new ushort[] { 1, 2 }, |
| 120 | + u8 = new byte[] { 1, 2 }, |
| 121 | + b = new bool[] { true, true }, |
| 122 | + }, |
| 123 | + new TypesData() { f64 = new[] { -3.0, 4.0 }, |
| 124 | + f32 = new[] { -3.0f, 4.0f }, |
| 125 | + i64 = new[] { -3L, 4 }, |
| 126 | + i32 = new[] { -3, 4 }, |
| 127 | + i16 = new short[] { -3, 4 }, |
| 128 | + i8 = new sbyte[] { -3, 4 }, |
| 129 | + u64 = new ulong[] { 3, 4 }, |
| 130 | + u32 = new uint[] { 3, 4 }, |
| 131 | + u16 = new ushort[] { 3, 4 }, |
| 132 | + u8 = new byte[] { 3, 4 }, |
| 133 | + b = new bool[] { false, false }, |
| 134 | + } }); |
| 135 | + |
| 136 | + var mlContext = new MLContext(seed: 1, conc: 1); |
| 137 | + // Pipeline |
| 138 | + |
| 139 | + var loader = ComponentCreation.CreateDataView(mlContext,data); |
| 140 | + |
| 141 | + var inputs = new string[]{"f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8","b"}; |
| 142 | + var outputs = new string[] { "o_f64", "o_f32", "o_i64", "o_i32", "o_i16", "o_i8", "o_u64", "o_u32", "o_u16", "o_u8", "o_b" }; |
| 143 | + var trans = new TensorFlowTransformer(mlContext, model_location, inputs, outputs).Transform(loader); ; |
| 144 | + |
| 145 | + using (var cursor = trans.GetRowCursor(a => true)) |
| 146 | + { |
| 147 | + var f64getter = cursor.GetGetter<VBuffer<double>>(11); |
| 148 | + var f32getter = cursor.GetGetter<VBuffer<float>>(12); |
| 149 | + var i64getter = cursor.GetGetter<VBuffer<long>>(13); |
| 150 | + var i32getter = cursor.GetGetter<VBuffer<int>>(14); |
| 151 | + var i16getter = cursor.GetGetter<VBuffer<short>>(15); |
| 152 | + var i8getter = cursor.GetGetter<VBuffer<sbyte>>(16); |
| 153 | + var u64getter = cursor.GetGetter<VBuffer<ulong>>(17); |
| 154 | + var u32getter = cursor.GetGetter<VBuffer<uint>>(18); |
| 155 | + var u16getter = cursor.GetGetter<VBuffer<ushort>>(19); |
| 156 | + var u8getter = cursor.GetGetter<VBuffer<byte>>(20); |
| 157 | + var boolgetter = cursor.GetGetter<VBuffer<bool>>(21); |
| 158 | + |
| 159 | + |
| 160 | + VBuffer<double> f64 = default; |
| 161 | + VBuffer<float> f32 = default; |
| 162 | + VBuffer<long> i64 = default; |
| 163 | + VBuffer<int> i32 = default; |
| 164 | + VBuffer<short> i16 = default; |
| 165 | + VBuffer<sbyte> i8 = default; |
| 166 | + VBuffer<ulong> u64 = default; |
| 167 | + VBuffer<uint> u32 = default; |
| 168 | + VBuffer<ushort> u16 = default; |
| 169 | + VBuffer<byte> u8 = default; |
| 170 | + VBuffer<bool> b = default; |
| 171 | + foreach (var sample in data) |
| 172 | + { |
| 173 | + Assert.True(cursor.MoveNext()); |
| 174 | + |
| 175 | + f64getter(ref f64); |
| 176 | + f32getter(ref f32); |
| 177 | + i64getter(ref i64); |
| 178 | + i32getter(ref i32); |
| 179 | + i16getter(ref i16); |
| 180 | + i8getter(ref i8); |
| 181 | + u64getter(ref u64); |
| 182 | + u32getter(ref u32); |
| 183 | + u16getter(ref u16); |
| 184 | + u8getter(ref u8); |
| 185 | + u8getter(ref u8); |
| 186 | + boolgetter(ref b); |
| 187 | + |
| 188 | + var f64Values = f64.GetValues(); |
| 189 | + Assert.Equal(2, f64Values.Length); |
| 190 | + Assert.True(f64Values.SequenceEqual(sample.f64)); |
| 191 | + var f32Values = f32.GetValues(); |
| 192 | + Assert.Equal(2, f32Values.Length); |
| 193 | + Assert.True(f32Values.SequenceEqual(sample.f32)); |
| 194 | + var i64Values = i64.GetValues(); |
| 195 | + Assert.Equal(2, i64Values.Length); |
| 196 | + Assert.True(i64Values.SequenceEqual(sample.i64)); |
| 197 | + var i32Values = i32.GetValues(); |
| 198 | + Assert.Equal(2, i32Values.Length); |
| 199 | + Assert.True(i32Values.SequenceEqual(sample.i32)); |
| 200 | + var i16Values = i16.GetValues(); |
| 201 | + Assert.Equal(2, i16Values.Length); |
| 202 | + Assert.True(i16Values.SequenceEqual(sample.i16)); |
| 203 | + var i8Values = i8.GetValues(); |
| 204 | + Assert.Equal(2, i8Values.Length); |
| 205 | + Assert.True(i8Values.SequenceEqual(sample.i8)); |
| 206 | + var u64Values = u64.GetValues(); |
| 207 | + Assert.Equal(2, u64Values.Length); |
| 208 | + Assert.True(u64Values.SequenceEqual(sample.u64)); |
| 209 | + var u32Values = u32.GetValues(); |
| 210 | + Assert.Equal(2, u32Values.Length); |
| 211 | + Assert.True(u32Values.SequenceEqual(sample.u32)); |
| 212 | + var u16Values = u16.GetValues(); |
| 213 | + Assert.Equal(2, u16Values.Length); |
| 214 | + Assert.True(u16Values.SequenceEqual(sample.u16)); |
| 215 | + var u8Values = u8.GetValues(); |
| 216 | + Assert.Equal(2, u8Values.Length); |
| 217 | + Assert.True(u8Values.SequenceEqual(sample.u8)); |
| 218 | + var bValues = b.GetValues(); |
| 219 | + Assert.Equal(2, bValues.Length); |
| 220 | + Assert.True(bValues.SequenceEqual(sample.b)); |
| 221 | + } |
| 222 | + Assert.False(cursor.MoveNext()); |
| 223 | + } |
| 224 | + } |
| 225 | + |
72 | 226 | [Fact(Skip = "Model files are not available yet")]
|
73 | 227 | public void TensorFlowTransformObjectDetectionTest()
|
74 | 228 | {
|
|
0 commit comments